Conversation

@JacobHelwig (Contributor) commented Jan 8, 2026

What does this PR do?

Updating vLLM LoRA weights currently raises an error for both the FSDP and Megatron workers.

FSDP: fails because the layered-summon arguments (layered_summon, base_sync_done) are not passed to get_per_tensor_param, so base-model weights are exported alongside the LoRA tensors.

Megatron: fails because the wrong weight-export method is called on the AutoBridge class (export_weights instead of export_hf_weights).
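
For reference, a minimal sketch of the FSDP-side change; the argument and method names come from the traceback and review comments below, but the surrounding verl worker code is simplified into a plain async function, so treat this as an illustration rather than the implementation.

async def wake_up(engine, rollout, layered_summon: bool, base_sync_done: bool):
    # Simplified stand-in for the worker's wake_up: the fix forwards the
    # layered-summon flags so that, once the base weights are already synced
    # to vLLM, the engine exports only the LoRA adapter tensors.
    per_tensor_param, peft_config = engine.get_per_tensor_param(
        layered_summon=layered_summon, base_sync_done=base_sync_done
    )
    await rollout.update_weights(
        per_tensor_param, peft_config=peft_config, base_sync_done=base_sync_done
    )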

Test

FSDP

Script:

# GRPO + LoRA on GSM8K with the FSDP worker (2 GPUs)
LEGACY_MODE='disable'
export CUDA_VISIBLE_DEVICES=0,1

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$DATA_PATH/gsm8k/train.parquet \
    data.val_files=$DATA_PATH/gsm8k/test.parquet \
    data.train_batch_size=2 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.shuffle=False \
    actor_rollout_ref.rollout.agent.num_workers=2 \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
    actor_rollout_ref.model.lora_rank=64 \
    actor_rollout_ref.model.lora_alpha=32 \
    actor_rollout_ref.actor.optim.lr=3e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=2 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.rollout.load_format=safetensors \
    actor_rollout_ref.rollout.layered_summon=True \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console"]' \
    trainer.project_name='verl_grpo_example_gsm8k' \
    trainer.experiment_name='qwen2.5_3b_grpo_lora' \
    trainer.n_gpus_per_node=2 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.use_legacy_worker_impl=$LEGACY_MODE \
    trainer.total_epochs=15 \
    actor_rollout_ref.actor.use_torch_compile=False \
    actor_rollout_ref.actor.fsdp_config.use_torch_compile=False \
    trainer.val_before_train=False \
    actor_rollout_ref.rollout.enforce_eager=True \
    actor_rollout_ref.ref.fsdp_config.use_torch_compile=False

Error:

  File "/home/jacob.a.helwig/verl/verl/workers/engine_workers.py", line 572, in wake_up
    await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
  File "/home/jacob.a.helwig/verl/verl/workers/rollout/vllm_rollout/vllm_rollout.py", line 252, in update_weights
    self.inference_engine.worker.add_lora(lora_request)
  File "/home/jacob.a.helwig/miniconda3/envs/verl/lib/python3.12/site-packages/vllm/v1/worker/gpu_worker.py", line 494, in add_lora
    return self.model_runner.add_lora(lora_request)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jacob.a.helwig/miniconda3/envs/verl/lib/python3.12/site-packages/vllm/v1/worker/lora_model_runner_mixin.py", line 171, in add_lora
    return self.lora_manager.add_adapter(lora_request)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jacob.a.helwig/miniconda3/envs/verl/lib/python3.12/site-packages/vllm/lora/worker_manager.py", line 251, in add_adapter
    lora = self._load_adapter(lora_request)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jacob.a.helwig/verl/verl/utils/vllm/utils.py", line 100, in hijack__load_adapter
    lora = self._lora_model_cls.from_lora_tensors(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jacob.a.helwig/miniconda3/envs/verl/lib/python3.12/site-packages/vllm/lora/models.py", line 135, in from_lora_tensors
    module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jacob.a.helwig/miniconda3/envs/verl/lib/python3.12/site-packages/vllm/lora/utils.py", line 150, in parse_fine_tuned_lora_name
    raise ValueError(f"{name} is unsupported LoRA weight")
ValueError: model.embed_tokens.weight is unsupported LoRA weight
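
The ValueError above shows that a base-model tensor reached vLLM's LoRA loader. Once the base weights have already been synced, only the LoRA adapter tensors should be handed to add_lora. A small, self-contained illustration of that filtering, assuming the usual PEFT lora_A/lora_B naming convention (this is not the verl code, just the idea):

def lora_only(named_tensors):
    # Keep only tensors whose names identify LoRA adapter weights; base
    # weights such as model.embed_tokens.weight carry no lora_A/lora_B
    # component and fail vLLM's parse_fine_tuned_lora_name.
    for name, tensor in named_tensors:
        if ".lora_A." in name or ".lora_B." in name:
            yield name, tensor

exported = [
    ("model.embed_tokens.weight", None),                      # base weight -> filtered out
    ("model.layers.0.self_attn.q_proj.lora_A.weight", None),  # LoRA delta -> kept
    ("model.layers.0.self_attn.q_proj.lora_B.weight", None),  # LoRA delta -> kept
]
print([name for name, _ in lora_only(exported)])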

Megatron

Script:

############################ Quick Config ############################

rollout_name="vllm" # sglang or vllm
project_name='verl_grpo_example_gsm8k_math'
exp_name='qwen2_7b_megatron_lora'

adv_estimator=grpo

max_prompt_length=1024
max_response_length=1024
train_prompt_bsz=2

############################ Paths ############################

gsm8k_train_path=$DATA_PATH/gsm8k/train.parquet
gsm8k_test_path=$DATA_PATH/gsm8k/test.parquet

train_files="['$gsm8k_train_path']"
test_files="['$gsm8k_test_path']"

############################ Parameter Groups ############################

DATA=(
    data.train_files="$train_files"
    data.val_files="$test_files"
    data.max_prompt_length=$max_prompt_length
    data.max_response_length=$max_response_length
    data.train_batch_size=$train_prompt_bsz
    data.filter_overlong_prompts=True
    data.truncation='error'
    data.shuffle=False
)

MODEL=(
    actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct
    actor_rollout_ref.model.lora.rank=64
    actor_rollout_ref.model.lora.alpha=32
    actor_rollout_ref.model.lora.lora_A_init_method=kaiming
    # # Optional: Use canonical LoRA
    # actor_rollout_ref.model.lora.type="canonical_lora"
    # actor_rollout_ref.model.lora.target_modules='["linear_q","linear_k","linear_v","linear_proj","linear_fc1_up","linear_fc1_gate","linear_fc2"]'

    # # Optional: Add dropout to LoRA layers
    # actor_rollout_ref.model.lora.dropout=0.05
    # actor_rollout_ref.model.lora.dropout_position=pre
)

ACTOR=(
    actor_rollout_ref.actor.optim.lr=1e-6
    actor_rollout_ref.actor.ppo_mini_batch_size=2
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=5
    actor_rollout_ref.actor.use_dynamic_bsz=True
    actor_rollout_ref.actor.megatron.use_mbridge=True
    actor_rollout_ref.actor.megatron.vanilla_mbridge=False
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=1
    actor_rollout_ref.actor.megatron.sequence_parallel=False
    actor_rollout_ref.actor.use_kl_loss=True
    actor_rollout_ref.actor.kl_loss_coef=0.001
    actor_rollout_ref.actor.kl_loss_type=low_var_kl
    actor_rollout_ref.actor.entropy_coeff=0
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1
)

ROLLOUT=(
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=5
    actor_rollout_ref.rollout.tensor_model_parallel_size=1
    actor_rollout_ref.rollout.name=$rollout_name
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6
    actor_rollout_ref.rollout.n=4
)

REF=(
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=5
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=1
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=1
    actor_rollout_ref.ref.megatron.sequence_parallel=False
)

ALGORITHM=(
    algorithm.adv_estimator=$adv_estimator
    algorithm.use_kl_in_reward=False
)

TRAINER=(
    trainer.logger='["console"]'
    trainer.project_name=$project_name
    trainer.experiment_name=$exp_name
    trainer.n_gpus_per_node=2
    trainer.nnodes=1
    trainer.save_freq=20
    trainer.test_freq=5
    trainer.total_epochs=15
    trainer.val_before_train=False
    trainer.use_legacy_worker_impl=disable
)

############################ Launch ############################

python3 -m verl.trainer.main_ppo \
    --config-path=config \
    --config-name='ppo_megatron_trainer.yaml' \
    "${DATA[@]}" \
    "${ALGORITHM[@]}" \
    "${MODEL[@]}" \
    "${ROLLOUT[@]}" \
    "${ACTOR[@]}" \
    "${REF[@]}" \
    "${TRAINER[@]}" \
    "$@"

Error:

  File "/home/jacob.a.helwig/verl/verl/trainer/ppo/ray_trainer.py", line 1409, in fit
    gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jacob.a.helwig/verl/verl/experimental/agent_loop/agent_loop.py", line 949, in generate_sequences
    self.wake_up()
  File "/home/jacob.a.helwig/verl/verl/experimental/agent_loop/agent_loop.py", line 997, in wake_up
    self._run_all([replica.wake_up() for replica in self.rollout_replicas])
  File "/home/jacob.a.helwig/verl/verl/experimental/agent_loop/agent_loop.py", line 1011, in _run_all
    asyncio.run(run_all())
  File "/home/jacob.a.helwig/miniconda3/envs/verlMega/lib/python3.12/asyncio/runners.py", line 195, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/home/jacob.a.helwig/miniconda3/envs/verlMega/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
  File "/home/jacob.a.helwig/verl/verl/experimental/agent_loop/agent_loop.py", line 1009, in run_all
    await asyncio.gather(*tasks)
  File "/home/jacob.a.helwig/verl/verl/workers/rollout/replica.py", line 200, in wake_up
    await asyncio.gather(*[server.wake_up.remote() for server in self.servers])
  File "/home/jacob.a.helwig/miniconda3/envs/verlMega/lib/python3.12/asyncio/tasks.py", line 684, in _wrap_awaitable
    return await awaitable
           ^^^^^^^^^^^^^^^
ray.exceptions.RayTaskError(AttributeError): ray::vLLMHttpServer.wake_up() (pid=2329228, ip=10.55.149.115, actor_id=7c53a3aba97589c515d254ba01000000, repr=<verl.workers.rollout.vllm_rollout.vllm_async_server.vLLMHttpServer object at 0x75b0d61d8e00>)
  File "/home/jacob.a.helwig/miniconda3/envs/verlMega/lib/python3.12/concurrent/futures/_base.py", line 449, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/home/jacob.a.helwig/miniconda3/envs/verlMega/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jacob.a.helwig/verl/verl/workers/rollout/vllm_rollout/vllm_async_server.py", line 556, in wake_up
    await asyncio.gather(*[worker.wake_up.remote() for worker in self.workers])
  File "/home/jacob.a.helwig/miniconda3/envs/verlMega/lib/python3.12/asyncio/tasks.py", line 684, in _wrap_awaitable
    return await awaitable
           ^^^^^^^^^^^^^^^
ray.exceptions.RayTaskError(AttributeError): ray::WorkerDict.wake_up() (pid=2320299, ip=10.55.149.115, actor_id=bccd2231d5e9ef6dce81c33c01000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x75eadfc7be00>)
  File "/home/jacob.a.helwig/miniconda3/envs/verlMega/lib/python3.12/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/home/jacob.a.helwig/miniconda3/envs/verlMega/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jacob.a.helwig/verl/verl/single_controller/ray/base.py", line 848, in async_func
    return await getattr(self.worker_dict[key], name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jacob.a.helwig/verl/verl/single_controller/base/decorator.py", line 462, in async_inner
    return await func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jacob.a.helwig/verl/verl/utils/transferqueue_utils.py", line 319, in dummy_async_inner
    output = await func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jacob.a.helwig/verl/verl/workers/engine_workers.py", line 566, in wake_up
    per_tensor_param, peft_config = self.actor.engine.get_per_tensor_param()
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jacob.a.helwig/verl/verl/workers/engine/megatron/transformer_impl.py", line 541, in get_per_tensor_param
    per_tensor_param = self.bridge.export_weights(self.module)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'AutoBridge' object has no attribute 'export_weights'. Did you mean: 'export_hf_weights'?. Did you mean: '_return_value'?
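
The AttributeError comes from calling export_weights on Megatron-Bridge's AutoBridge, which only exposes export_hf_weights (as the error message itself suggests). Below is a hedged sketch of the dispatch, with stub classes standing in for the real bridge objects; per the review discussion, the actual fix in verl/workers/engine/megatron/transformer_impl.py branches on the vanilla_mbridge setting.

class VanillaBridge:
    # Stub: the vanilla mbridge exposes export_weights.
    def export_weights(self, module):
        return iter([("dummy.weight", None)])

class AutoBridge:
    # Stub: Megatron-Bridge's AutoBridge exposes export_hf_weights instead.
    def export_hf_weights(self, module):
        return iter([("dummy.weight", None)])

def export_per_tensor(bridge, module, vanilla_mbridge: bool):
    # Call whichever export method the configured bridge actually provides.
    if vanilla_mbridge:
        return bridge.export_weights(module)
    return bridge.export_hf_weights(module)

print(list(export_per_tensor(AutoBridge(), None, vanilla_mbridge=False)))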

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request addresses a bug that occurs when updating vLLM LoRA weights with the FSDP Engine, which was caused by a missing layered_summon argument. The fix involves passing the layered_summon and base_sync_done arguments to the get_per_tensor_param method within the wake_up function. The change appears correct and directly solves the issue described. I have included one suggestion to improve code maintainability by using an existing class attribute.

@wuxibin89 (Collaborator)

We may need a more comprehensive review of LoRA support for both FSDP and Megatron.


  # 1. get per tensor generator from engine, this will load model to gpu
- per_tensor_param, peft_config = self.actor.engine.get_per_tensor_param()
+ per_tensor_param, peft_config = self.actor.engine.get_per_tensor_param(
Collaborator review comment:

@JacobHelwig Megatron engine get_per_tensor_param doesn't accept layered_summon and base_sync_done, please fix it.
https://github.com/volcengine/verl/blob/main/verl/workers/engine/megatron/transformer_impl.py#L539

@JacobHelwig (Contributor, Author) replied Jan 8, 2026:

Good catch, thank you. I fixed it using kwargs in the Megatron engine and added a fix for LoRA with the Megatron engine (please see updated description).
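
A sketch of what "using kwargs" could look like for the Megatron engine's get_per_tensor_param; the signature below is hypothetical and only illustrates that FSDP-specific keyword arguments can be accepted and ignored, so the shared wake_up path can call either engine the same way.

def get_per_tensor_param(**kwargs):
    # Hypothetical Megatron-engine signature: absorb FSDP-specific flags
    # (layered_summon, base_sync_done) without raising a TypeError.
    per_tensor_param = iter([])  # placeholder for the bridge weight export
    peft_config = None           # placeholder for the LoRA/PEFT config
    return per_tensor_param, peft_config

# The caller can now pass FSDP-style arguments uniformly:
per_tensor_param, peft_config = get_per_tensor_param(layered_summon=True, base_sync_done=True)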

@JacobHelwig changed the title from "[worker] fix: FSDP Engine LoRA layered summon" to "[FSDP worker, Megatron worker] fix: Engine Rollout Worker LoRA Parameter Update" on Jan 8, 2026
@JacobHelwig requested a review from wuxibin89 on January 8, 2026 at 22:04
@JacobHelwig (Contributor, Author) commented:

/gemini review

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request effectively addresses two critical bugs related to LoRA parameter updates for FSDP and Megatron workers. For the Megatron worker, the fix correctly distinguishes between vanilla_bridge and the standard Megatron-Bridge to call the appropriate weight export method (export_weights vs. export_hf_weights), resolving an AttributeError. For the FSDP worker, passing layered_summon and base_sync_done to get_per_tensor_param ensures that only LoRA-specific parameters are passed to vLLM, fixing a ValueError. The changes are well-targeted and correctly resolve the underlying issues. The implementation is clean and I have no further suggestions.

@JacobHelwig changed the title from "[FSDP worker, Megatron worker] fix: Engine Rollout Worker LoRA Parameter Update" to "[fsdp, megatron] fix: Engine Rollout Worker LoRA Parameter Update" on Jan 8, 2026
@wuxibin89 merged commit 94f4654 into volcengine:main on Jan 9, 2026
83 of 86 checks passed