
Conversation

@yurekami
Contributor

Summary

When parameter offloading is enabled, FSDP requires tensors to be on the GPU before `state_dict()` is accessed. The `get_actor_weights_info()` function was not handling this case, causing:

AssertionError: Expects tensor to be on the compute device cuda:0, was on cpu

Root Cause

The `get_actor_weights_info()` function calls `self._get_actor_params()`, which internally accesses `state_dict()`. When parameter offloading is enabled (`self._is_offload_param = True`), the parameters are on the CPU, but FSDP's `state_dict()` expects them on the GPU.
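
For illustration, a simplified sketch of the pre-fix path, reconstructed from the function body quoted in the review comments below (not the verbatim source):

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def get_actor_weights_info(self):
    assert self._is_actor
    if hasattr(self, "_weights_info"):
        return self._weights_info
    # With self._is_offload_param = True, the FSDP parameters are still on CPU
    # here, so the state_dict() access inside _get_actor_params() trips FSDP's
    # device check:
    #   AssertionError: Expects tensor to be on the compute device cuda:0, was on cpu
    params = self._get_actor_params()
    ret = [(key, tensor.size(), tensor.dtype) for key, tensor in params.items()]
    self._weights_info = ret
    return ret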

Fix

Added a check for `self._is_offload_param` and a call to `load_fsdp_model_to_gpu()` before accessing the parameters; a minimal sketch follows the list below. This matches the pattern already used in:

  • sync_rollout_weights() in the same file
  • get_actor_weights_info() in the Megatron worker implementation
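
A minimal sketch of the added guard (matching the hunk quoted in the review comments below; `load_fsdp_model_to_gpu()` and `self._get_actor_params()` are the existing worker helpers):

# FSDP's state_dict() asserts that parameters sit on the compute device, so
# bring the offloaded module back to GPU before touching the parameters.
if self._is_offload_param:
    load_fsdp_model_to_gpu(self.actor_module_fsdp)

params = self._get_actor_params()  # state_dict() now sees GPU tensors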

Files Changed

  • recipe/fully_async_policy/fsdp_workers.py
  • recipe/one_step_off_policy/fsdp_workers.py

Test Plan

  • Pattern matches existing sync_rollout_weights() implementation
  • Pattern matches Megatron worker implementation
  • Tested with FSDP + parameter offloading enabled

Fixes #4657

🤖 Generated with Claude Code

When parameter offloading is enabled, FSDP requires tensors to be on GPU
before accessing state_dict(). The `get_actor_weights_info()` function
was not handling this case, causing `AssertionError: Expects tensor to
be on the compute device cuda:0, was on cpu`.

This fix adds a check for `self._is_offload_param` and loads parameters
to GPU before accessing them, matching the pattern used in
`sync_rollout_weights()` and the Megatron worker implementation.

Fixed in both:
- recipe/fully_async_policy/fsdp_workers.py
- recipe/one_step_off_policy/fsdp_workers.py

Fixes volcengine#4657

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.

Contributor

@gemini-code-assist bot left a comment

Code Review

This pull request addresses an AssertionError that occurs when using FSDP with parameter offloading by ensuring the model parameters are on the GPU before accessing state_dict() in get_actor_weights_info. The fix is applied consistently across fully_async_policy and one_step_off_policy workers.

The change is correct for the intended workflow. However, I've raised a concern about the robustness of the implementation. By not offloading the model back to the CPU within the same function, it creates a dependency on the calling context to perform the cleanup, which could lead to potential memory leaks if the function is used differently in the future. I have suggested a refactoring to make the behavior more explicit and safer by default, while still allowing for the intended optimization.

Comment on lines +214 to +217
# When parameter offloading is enabled, we need to load params to GPU
# before accessing state_dict, as FSDP requires tensors on GPU
if self._is_offload_param:
    load_fsdp_model_to_gpu(self.actor_module_fsdp)

Severity: high

The current implementation correctly fixes the issue for the intended call sequence by loading the model to the GPU. However, by not offloading it back to the CPU, it makes the function's behavior dependent on the subsequent call to sync_rollout_weights for cleanup. This can be brittle and may lead to unexpected GPU memory usage if get_actor_weights_info is ever called in a different context. This could potentially cause Out-Of-Memory errors.

A more robust approach would be to make this optimization explicit. I recommend adding a keep_on_gpu: bool = False parameter to the function signature. This provides a safe default (offloading back to CPU) while allowing the caller to request the optimized behavior.

Since I cannot suggest changes to the function signature directly, here is an illustration of the proposed change for your consideration:

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def get_actor_weights_info(self, keep_on_gpu: bool = False):
    assert self._is_actor
    if hasattr(self, "_weights_info"):
        return self._weights_info

    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.actor_module_fsdp)

    try:
        if fsdp_version(self.actor_module_fsdp) == 1:
            from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType

            FSDP.set_state_dict_type(
                self.actor_module_fsdp,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(),
            )
        params = self._get_actor_params()
        ret = []
        for key, tensor in params.items():
            ret.append((key, tensor.size(), tensor.dtype))
        self._weights_info = ret
        return ret
    finally:
        if self._is_offload_param and not keep_on_gpu:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)

This would require updating the caller to pass keep_on_gpu=True when this optimization is desired.
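
For example (a hypothetical call site; the `actor_rollout_wg` handle and the surrounding trainer code are assumptions, not taken from this PR):

# Request the optimization explicitly; sync_rollout_weights() is then
# responsible for offloading the module back to CPU afterwards.
weights_info = actor_rollout_wg.get_actor_weights_info(keep_on_gpu=True)
actor_rollout_wg.sync_rollout_weights()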

Comment on lines +149 to +152
# When parameter offloading is enabled, we need to load params to GPU
# before accessing state_dict, as FSDP requires tensors on GPU
if self._is_offload_param:
    load_fsdp_model_to_gpu(self.actor_module_fsdp)

Severity: high

The current implementation correctly fixes the issue for the intended call sequence by loading the model to the GPU. However, by not offloading it back to the CPU, it makes the function's behavior dependent on the subsequent call to sync_rollout_weights for cleanup. This can be brittle and may lead to unexpected GPU memory usage if get_actor_weights_info is ever called in a different context. This could potentially cause Out-Of-Memory errors.

A more robust approach would be to make this optimization explicit. I recommend adding a keep_on_gpu: bool = False parameter to the function signature. This provides a safe default (offloading back to CPU) while allowing the caller to request the optimized behavior.

Since I cannot suggest changes to the function signature directly, here is an illustration of the proposed change for your consideration:

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def get_actor_weights_info(self, keep_on_gpu: bool = False):
    assert self._is_actor
    if hasattr(self, "_weights_info"):
        return self._weights_info

    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.actor_module_fsdp)

    try:
        if fsdp_version(self.actor_module_fsdp) == 1:
            from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType

            FSDP.set_state_dict_type(
                self.actor_module_fsdp,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(),
            )
        params = self._get_actor_params()
        ret = []
        for key, tensor in params.items():
            ret.append((key, tensor.size(), tensor.dtype))
        self._weights_info = ret
        return ret
    finally:
        if self._is_offload_param and not keep_on_gpu:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)

This would require updating the caller to pass keep_on_gpu=True when this optimization is desired.
