7 changes: 5 additions & 2 deletions verl/workers/engine/megatron/transformer_impl.py
@@ -536,9 +536,12 @@ def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forw
         else:
             return {}
 
-    def get_per_tensor_param(self):
+    def get_per_tensor_param(self, **kwargs):
         load_megatron_model_to_gpu(self.module, load_grad=False)
-        per_tensor_param = self.bridge.export_weights(self.module)
+        if self.vanilla_bridge:
+            per_tensor_param = self.bridge.export_weights(self.module)
+        else:
+            per_tensor_param = self.bridge.export_hf_weights(self.module)
         # TODO: support megatron LoRA
         return per_tensor_param, None

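For readers skimming the diff, the point of the **kwargs signature is that the shared worker code can pass engine-specific flags without branching at the call site. The sketch below illustrates that pattern with hypothetical stand-in classes; FSDPEngineSketch and MegatronEngineSketch are not the real verl engines, and the FSDP-side signature shown is an assumption inferred from the flags the worker passes in the next file.

# Sketch only: hypothetical stand-ins, not the actual verl engine classes.
class FSDPEngineSketch:
    def get_per_tensor_param(self, layered_summon=False, base_sync_done=False):
        # Assumed FSDP-side signature: it consumes the two flags the worker passes.
        return iter([]), None  # (per-tensor weight iterator, peft_config)

class MegatronEngineSketch:
    def get_per_tensor_param(self, **kwargs):
        # The Megatron side accepts and ignores engine-specific keyword arguments,
        # mirroring the **kwargs change in this diff.
        return iter([]), None

# One call site then works for both engines, as in engine_workers.py below:
for engine in (FSDPEngineSketch(), MegatronEngineSketch()):
    per_tensor_param, peft_config = engine.get_per_tensor_param(
        layered_summon=True, base_sync_done=False
    )
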
4 changes: 3 additions & 1 deletion verl/workers/engine_workers.py
@@ -563,7 +563,9 @@ async def wake_up(self):
         set_expandable_segments(False)
 
         # 1. get per tensor generator from engine, this will load model to gpu
-        per_tensor_param, peft_config = self.actor.engine.get_per_tensor_param()
+        per_tensor_param, peft_config = self.actor.engine.get_per_tensor_param(
+            layered_summon=self.layered_summon, base_sync_done=self.base_sync_done
+        )
 
         # 2. resume weights and update weights
         if self.config.rollout.free_cache_engine:

Review thread on the new get_per_tensor_param call:

Collaborator:
@JacobHelwig Megatron engine get_per_tensor_param doesn't accept layered_summon and base_sync_done, please fix it.
https://github.com/volcengine/verl/blob/main/verl/workers/engine/megatron/transformer_impl.py#L539

JacobHelwig (Contributor, author), Jan 8, 2026:
Good catch, thank you. I fixed it using kwargs in the Megatron engine and added a fix for LoRA with the Megatron engine (please see updated description).
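The exchange above comes down to a signature contract: every engine's get_per_tensor_param must tolerate the keyword arguments that engine_workers.py passes, even if it ignores them. A small, hypothetical check along these lines (accepts_worker_kwargs is not part of verl, just an illustration of the contract the reviewer is pointing at) could guard against this kind of mismatch:

import inspect

def accepts_worker_kwargs(fn) -> bool:
    """Hypothetical helper: does fn accept layered_summon/base_sync_done,
    either explicitly or via **kwargs?"""
    params = inspect.signature(fn).parameters.values()
    has_var_keyword = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params)
    names = {p.name for p in params}
    return has_var_keyword or {"layered_summon", "base_sync_done"} <= names

# Example against the stand-in classes from the earlier sketch:
# assert accepts_worker_kwargs(MegatronEngineSketch().get_per_tensor_param)
# assert accepts_worker_kwargs(FSDPEngineSketch().get_per_tensor_param)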