fix(fsdp): handle param offloading in get_actor_weights_info #4726
Conversation
When parameter offloading is enabled, FSDP requires tensors to be on GPU before accessing `state_dict()`. The `get_actor_weights_info()` function was not handling this case, causing `AssertionError: Expects tensor to be on the compute device cuda:0, was on cpu`.

This fix adds a check for `self._is_offload_param` and loads parameters to GPU before accessing them, matching the pattern used in `sync_rollout_weights()` and the Megatron worker implementation.

Fixed in both:
- recipe/fully_async_policy/fsdp_workers.py
- recipe/one_step_off_policy/fsdp_workers.py

Fixes volcengine#4657

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
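For quick orientation, below is a minimal sketch of the fix as described above, without the reviewer's later `keep_on_gpu` proposal. It is a simplified fragment of the worker method (decorator, imports, and the FSDP-version-specific state-dict configuration shown in the review thread are omitted); `self._is_offload_param`, `load_fsdp_model_to_gpu()`, and `self._get_actor_params()` are names taken from this PR, the rest is illustrative.

```python
# Simplified sketch of get_actor_weights_info() on the FSDP actor worker class.
def get_actor_weights_info(self):
    assert self._is_actor
    if hasattr(self, "_weights_info"):
        return self._weights_info

    # The added guard: with parameter offloading enabled, FSDP's state_dict()
    # expects tensors on the compute device, so load the module to GPU first.
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.actor_module_fsdp)

    # _get_actor_params() accesses state_dict() internally; this is where the
    # "Expects tensor to be on the compute device cuda:0" assertion was raised.
    params = self._get_actor_params()
    ret = [(key, tensor.size(), tensor.dtype) for key, tensor in params.items()]
    self._weights_info = ret
    return ret
```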
Code Review
This pull request addresses an AssertionError that occurs when using FSDP with parameter offloading by ensuring the model parameters are on the GPU before accessing state_dict() in get_actor_weights_info. The fix is applied consistently across fully_async_policy and one_step_off_policy workers.
The change is correct for the intended workflow. However, I've raised a concern about the robustness of the implementation: because the model is not offloaded back to the CPU within the same function, cleanup depends on the calling context, which could lead to memory leaks if the function is used differently in the future. I have suggested a refactoring to make the behavior more explicit and safer by default, while still allowing for the intended optimization.
# When parameter offloading is enabled, we need to load params to GPU
# before accessing state_dict, as FSDP requires tensors on GPU
if self._is_offload_param:
    load_fsdp_model_to_gpu(self.actor_module_fsdp)
The current implementation correctly fixes the issue for the intended call sequence by loading the model to the GPU. However, by not offloading it back to the CPU, the function's behavior becomes dependent on the subsequent call to sync_rollout_weights for cleanup. This can be brittle and may lead to unexpected GPU memory usage, and potentially out-of-memory errors, if get_actor_weights_info is ever called in a different context.
A more robust approach would be to make this optimization explicit. I recommend adding a keep_on_gpu: bool = False parameter to the function signature. This provides a safe default (offloading back to CPU) while allowing the caller to request the optimized behavior.
Since I cannot suggest changes to the function signature directly, here is an illustration of the proposed change for your consideration:
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def get_actor_weights_info(self, keep_on_gpu: bool = False):
    assert self._is_actor
    if hasattr(self, "_weights_info"):
        return self._weights_info

    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.actor_module_fsdp)

    try:
        if fsdp_version(self.actor_module_fsdp) == 1:
            from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType

            FSDP.set_state_dict_type(
                self.actor_module_fsdp,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(),
            )

        params = self._get_actor_params()
        ret = []
        for key, tensor in params.items():
            ret.append((key, tensor.size(), tensor.dtype))
        self._weights_info = ret
        return ret
    finally:
        if self._is_offload_param and not keep_on_gpu:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)

This would require updating the caller to pass keep_on_gpu=True when this optimization is desired.
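As a follow-on illustration, here is a hypothetical caller-side sketch of that suggestion. Only `get_actor_weights_info()` and `sync_rollout_weights()` come from this PR; the worker-group handle `actor_wg` and the exact call sequence are assumptions for illustration.

```python
# Hypothetical caller-side usage of the proposed keep_on_gpu parameter.
# `actor_wg` (an actor worker-group handle) is an assumed name, not from this PR.

# Intended optimized flow: keep the params on GPU because sync_rollout_weights()
# runs immediately afterwards and is expected to handle the offload itself.
weights_info = actor_wg.get_actor_weights_info(keep_on_gpu=True)
actor_wg.sync_rollout_weights()

# Any other call site that only needs the weight metadata uses the safe default,
# so the parameters are offloaded back to CPU before the call returns.
weights_info = actor_wg.get_actor_weights_info()
```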
(The same added hunk and the same review comment, including the suggested keep_on_gpu refactoring, were left on the identical change in the other worker file.)
Summary

When parameter offloading is enabled, FSDP requires tensors to be on GPU before accessing `state_dict()`. The `get_actor_weights_info()` function was not handling this case, causing:

AssertionError: Expects tensor to be on the compute device cuda:0, was on cpu

Root Cause

The `get_actor_weights_info()` function calls `self._get_actor_params()`, which internally accesses `state_dict()`. When parameter offloading is enabled (`self._is_offload_param = True`), the parameters are on CPU, but FSDP's `state_dict()` expects them on GPU.

Fix

Added a check for `self._is_offload_param` and a call to `load_fsdp_model_to_gpu()` before accessing parameters. This matches the pattern already used in:

- `sync_rollout_weights()` in the same file
- `get_actor_weights_info()` in the Megatron worker implementation

Files Changed

- recipe/fully_async_policy/fsdp_workers.py
- recipe/one_step_off_policy/fsdp_workers.py

Test Plan

- Follows the same pattern as the existing `sync_rollout_weights()` implementation

Fixes #4657
🤖 Generated with Claude Code