Skip to content

Commit

Permalink
actor ref lora 2in1
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiayi-Pan committed Feb 5, 2025
1 parent a0be6d9 commit cbc9e5d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
13 changes: 10 additions & 3 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,10 @@ def __init__(self,
self.use_rm = Role.RewardModel in role_worker_mapping
self.ray_worker_group_cls = ray_worker_group_cls

# if ref_in_actor is True, the reference policy will be the actor with LoRA disabled (no separate ref worker)
self.ref_in_actor = config.actor_rollout_ref.model.get('lora_rank', 0) > 0


# define KL control
if self.use_reference_policy:
if config.algorithm.kl_ctrl.type == 'fixed':
Expand Down Expand Up @@ -474,7 +478,7 @@ def init_workers(self):
raise NotImplementedError

# create reference policy if needed
if self.use_reference_policy:
if self.use_reference_policy and not self.ref_in_actor:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy],
config=self.config.actor_rollout_ref,
Expand Down Expand Up @@ -506,7 +510,7 @@ def init_workers(self):
self.critic_wg = all_wg['critic']
self.critic_wg.init_model()

if self.use_reference_policy:
if self.use_reference_policy and not self.ref_in_actor:
self.ref_policy_wg = all_wg['ref']
self.ref_policy_wg.init_model()

Expand Down Expand Up @@ -614,7 +618,10 @@ def fit(self):
if self.use_reference_policy:
# compute reference log_prob
with _timer('ref', timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
if not self.ref_in_actor:
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
else:
ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)

# compute values
Expand Down
9 changes: 5 additions & 4 deletions verl/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,12 +525,13 @@ def compute_log_prob(self, data: DataProto, no_lora=False):

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
assert self._is_ref
if self._is_lora:
# TODO
pass
# if _is_lora, the actor without LoRA applied serves as the reference policy
# return self.compute_log_prob(data, no_lora=True)
data = self.compute_log_prob(data, no_lora=True)
# the 'old_log_probs' returned here is actually the ref_log_prob (computed with LoRA disabled)
data = DataProto.from_dict(tensors={'ref_log_prob': data.batch['old_log_probs']})
return data
assert self._is_ref
# else:
# otherwise, this worker holds a standalone reference model
data = data.to('cuda')
Expand Down

0 comments on commit cbc9e5d

Please sign in to comment.