
Commit 48cebc9

[FSDP] Remove redundant GPU memory restore and improve code style (THUDM#658)
1 parent dfbf5f3 commit 48cebc9
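
In short: the old load_ref_model snapshotted the actor's parameters to CPU and then restored them onto the GPU in a finally block, even though the reference checkpoint is loaded entirely on the CPU (device_map="cpu") and never touches GPU memory. The commit drops that redundant restore and folds _load_cpu_state_dict into update_gpu_params_dict. A hypothetical caller-side sketch of how these methods fit together (not code from the repo; only the method names appear in the diff below, the wrapper function and its arguments are illustrative):

import torch

def compute_with_ref_weights(actor, ref_path: str) -> None:
    """Hypothetical flow; only the method names come from this commit's diff."""
    actor_weights: dict[str, torch.Tensor] = {}
    actor.update_cpu_params_dict(actor_weights)          # snapshot the actor's params to CPU
    actor.load_ref_model(ref_path)                       # cache reference weights in actor.weights["ref"]
    actor.update_gpu_params_dict(actor.weights["ref"])   # shard/copy the reference weights onto the FSDP model
    # ... compute reference log-probs here ...
    actor.update_gpu_params_dict(actor_weights)          # restore the actor's own weights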

2 files changed: +50 -53 lines


slime/backends/fsdp_utils/actor.py

Lines changed: 47 additions & 51 deletions
@@ -1,5 +1,4 @@
 from argparse import Namespace
-from collections.abc import Mapping
 from contextlib import nullcontext
 from itertools import accumulate

@@ -669,65 +668,20 @@ def update_gpu_params_dict(self, params_dict: dict[str, torch.Tensor]) -> None:
 
         Parameters:
             params_dict: Source mapping from parameter names to CPU tensors.
-        """
-        self._load_cpu_state_dict(params_dict)
-        torch.cuda.synchronize()
-
-    def load_ref_model(self, ref_load_path: str | None) -> None:
-        """Load reference model weights once and cache them on CPU.
 
-        Parameters:
-            ref_load_path: Path to a directory containing a HF checkpoint. If
-                None, a ValueError is raised.
+        Note:
+            This method handles both regular Tensors and DTensors. For DTensors,
+            it properly distributes the full tensor according to FSDP sharding.
         """
-        if ref_load_path is None:
-            raise ValueError("ref_load_path must be provided when loading reference model")
-
-        current_weights = {}
-        self.update_cpu_params_dict(current_weights)
-
-        try:
-            import os
-
-            if os.path.isdir(ref_load_path):
-                temp_ref_model = AutoModelForCausalLM.from_pretrained(
-                    ref_load_path,
-                    trust_remote_code=True,
-                    torch_dtype=torch.bfloat16,
-                    device_map="cpu",
-                )
-
-                ref_state_dict = temp_ref_model.state_dict()
-                self.weights["ref"] = {}
-
-                for name, tensor in ref_state_dict.items():
-                    actor_tensor = current_weights.get(name)
-                    target_dtype = actor_tensor.dtype if actor_tensor is not None else tensor.dtype
-                    cpu_tensor = tensor.detach().to(device="cpu", dtype=target_dtype, copy=True)
-                    self.weights["ref"][name] = cpu_tensor.pin_memory()
-
-                del temp_ref_model
-                torch.cuda.empty_cache()
-            else:
-                raise NotImplementedError(f"Loading from checkpoint file {ref_load_path} not yet implemented")
-
-            print("Reference model parameters loaded and stored in CPU memory")
-
-        finally:
-            self.update_gpu_params_dict(current_weights)
-
-    @torch.no_grad()
-    def _load_cpu_state_dict(self, full_state_dict: Mapping[str, torch.Tensor]) -> None:
-        """Load a CPU full-state dict into the model, handling DTensor shards."""
-
+        # Cache parameter and buffer maps for efficiency
         if not hasattr(self, "_fsdp_param_map"):
             self._fsdp_param_map = dict(self.model.named_parameters())
             self._fsdp_buffer_map = dict(self.model.named_buffers())
 
         param_map = self._fsdp_param_map
         buffer_map = self._fsdp_buffer_map
 
-        for name, src in full_state_dict.items():
+        for name, src in params_dict.items():
             if not torch.is_tensor(src):
                 continue
 
@@ -753,8 +707,50 @@ def _load_cpu_state_dict(self, full_state_dict: Mapping[str, torch.Tensor]) -> None:
                 )
                 dst_tensor.copy_(distributed)
             else:
+                # Regular tensor: just move to GPU
                 dst_tensor.copy_(src_tensor.to(device=dst_tensor.device, non_blocking=True))
 
+        torch.cuda.synchronize()
+
+    def load_ref_model(self, ref_load_path: str | None) -> None:
+        """Load reference model weights once and cache them on CPU.
+
+        Parameters:
+            ref_load_path: Path to a directory containing a HF checkpoint. If
+                None, a ValueError is raised.
+        """
+        if ref_load_path is None:
+            raise ValueError("ref_load_path must be provided when loading reference model")
+
+        import os
+
+        if os.path.isdir(ref_load_path):
+            # Get actor weights for dtype matching
+            actor_weights = {}
+            self.update_cpu_params_dict(actor_weights)
+
+            temp_ref_model = AutoModelForCausalLM.from_pretrained(
+                ref_load_path,
+                trust_remote_code=True,
+                torch_dtype=torch.bfloat16,
+                device_map="cpu",
+            )
+            ref_state_dict = temp_ref_model.state_dict()
+            self.weights["ref"] = {}
+
+            for name, tensor in ref_state_dict.items():
+                actor_tensor = actor_weights.get(name)
+                target_dtype = actor_tensor.dtype if actor_tensor is not None else tensor.dtype
+                cpu_tensor = tensor.detach().to(device="cpu", dtype=target_dtype, copy=True)
+                self.weights["ref"][name] = cpu_tensor.pin_memory()
+
+            del temp_ref_model
+            torch.cuda.empty_cache()
+        else:
+            raise NotImplementedError(f"Loading from checkpoint file {ref_load_path} not yet implemented")
+
+        print("Reference model parameters loaded and stored in CPU memory")
+
 
 def selective_log_softmax_raw(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
     """Fused version of the common `log_softmax -> gather` operation.

tests/test_qwen3-0.6B_fsdp_distributed.sh

Lines changed: 3 additions & 2 deletions
@@ -15,6 +15,7 @@ set -ex
 # will prevent ray from buffering stdout/stderr
 export PYTHONBUFFERED=16
 
+
 CKPT_ARGS=(
    --hf-checkpoint /root/Qwen3-0.6B
    --ref-load /root/Qwen3-0.6B
@@ -80,9 +81,9 @@ ray job submit --address="http://127.0.0.1:8265" \
    }' \
    -- python3 train.py \
    --actor-num-nodes 1 \
-   --actor-num-gpus-per-node 4 \
-   --colocate \
+   --actor-num-gpus-per-node 2 \
    --train-backend fsdp \
+   --rollout-num-gpus 2 \
    ${CKPT_ARGS[@]} \
    ${ROLLOUT_ARGS[@]} \
    ${OPTIMIZER_ARGS[@]} \
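
Note on the test change: the script drops --colocate, so training and rollout no longer share GPUs; instead the FSDP actor gets 2 GPUs (1 node x 2 per node) and rollout gets 2 GPUs via --rollout-num-gpus, presumably keeping the total at the same 4 GPUs the previous colocated setup used.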
