11from argparse import Namespace
2- from collections .abc import Mapping
32from contextlib import nullcontext
43from itertools import accumulate
54
@@ -669,65 +668,20 @@ def update_gpu_params_dict(self, params_dict: dict[str, torch.Tensor]) -> None:
669668
670669 Parameters:
671670 params_dict: Source mapping from parameter names to CPU tensors.
672- """
673- self ._load_cpu_state_dict (params_dict )
674- torch .cuda .synchronize ()
675-
676- def load_ref_model (self , ref_load_path : str | None ) -> None :
677- """Load reference model weights once and cache them on CPU.
678671
679- Parameters :
680- ref_load_path: Path to a directory containing a HF checkpoint. If
681- None, a ValueError is raised .
672+ Note :
673+ This method handles both regular Tensors and DTensors. For DTensors,
674+ it properly distributes the full tensor according to FSDP sharding .
682675 """
683- if ref_load_path is None :
684- raise ValueError ("ref_load_path must be provided when loading reference model" )
685-
686- current_weights = {}
687- self .update_cpu_params_dict (current_weights )
688-
689- try :
690- import os
691-
692- if os .path .isdir (ref_load_path ):
693- temp_ref_model = AutoModelForCausalLM .from_pretrained (
694- ref_load_path ,
695- trust_remote_code = True ,
696- torch_dtype = torch .bfloat16 ,
697- device_map = "cpu" ,
698- )
699-
700- ref_state_dict = temp_ref_model .state_dict ()
701- self .weights ["ref" ] = {}
702-
703- for name , tensor in ref_state_dict .items ():
704- actor_tensor = current_weights .get (name )
705- target_dtype = actor_tensor .dtype if actor_tensor is not None else tensor .dtype
706- cpu_tensor = tensor .detach ().to (device = "cpu" , dtype = target_dtype , copy = True )
707- self .weights ["ref" ][name ] = cpu_tensor .pin_memory ()
708-
709- del temp_ref_model
710- torch .cuda .empty_cache ()
711- else :
712- raise NotImplementedError (f"Loading from checkpoint file { ref_load_path } not yet implemented" )
713-
714- print ("Reference model parameters loaded and stored in CPU memory" )
715-
716- finally :
717- self .update_gpu_params_dict (current_weights )
718-
719- @torch .no_grad ()
720- def _load_cpu_state_dict (self , full_state_dict : Mapping [str , torch .Tensor ]) -> None :
721- """Load a CPU full-state dict into the model, handling DTensor shards."""
722-
676+ # Cache parameter and buffer maps for efficiency
723677 if not hasattr (self , "_fsdp_param_map" ):
724678 self ._fsdp_param_map = dict (self .model .named_parameters ())
725679 self ._fsdp_buffer_map = dict (self .model .named_buffers ())
726680
727681 param_map = self ._fsdp_param_map
728682 buffer_map = self ._fsdp_buffer_map
729683
730- for name , src in full_state_dict .items ():
684+ for name , src in params_dict .items ():
731685 if not torch .is_tensor (src ):
732686 continue
733687
@@ -753,8 +707,50 @@ def _load_cpu_state_dict(self, full_state_dict: Mapping[str, torch.Tensor]) -> N
753707 )
754708 dst_tensor .copy_ (distributed )
755709 else :
710+ # Regular tensor: just move to GPU
756711 dst_tensor .copy_ (src_tensor .to (device = dst_tensor .device , non_blocking = True ))
757712
713+ torch .cuda .synchronize ()
714+
715+ def load_ref_model (self , ref_load_path : str | None ) -> None :
716+ """Load reference model weights once and cache them on CPU.
717+
718+ Parameters:
719+ ref_load_path: Path to a directory containing a HF checkpoint. If
720+ None, a ValueError is raised.
721+ """
722+ if ref_load_path is None :
723+ raise ValueError ("ref_load_path must be provided when loading reference model" )
724+
725+ import os
726+
727+ if os .path .isdir (ref_load_path ):
728+ # Get actor weights for dtype matching
729+ actor_weights = {}
730+ self .update_cpu_params_dict (actor_weights )
731+
732+ temp_ref_model = AutoModelForCausalLM .from_pretrained (
733+ ref_load_path ,
734+ trust_remote_code = True ,
735+ torch_dtype = torch .bfloat16 ,
736+ device_map = "cpu" ,
737+ )
738+ ref_state_dict = temp_ref_model .state_dict ()
739+ self .weights ["ref" ] = {}
740+
741+ for name , tensor in ref_state_dict .items ():
742+ actor_tensor = actor_weights .get (name )
743+ target_dtype = actor_tensor .dtype if actor_tensor is not None else tensor .dtype
744+ cpu_tensor = tensor .detach ().to (device = "cpu" , dtype = target_dtype , copy = True )
745+ self .weights ["ref" ][name ] = cpu_tensor .pin_memory ()
746+
747+ del temp_ref_model
748+ torch .cuda .empty_cache ()
749+ else :
750+ raise NotImplementedError (f"Loading from checkpoint file { ref_load_path } not yet implemented" )
751+
752+ print ("Reference model parameters loaded and stored in CPU memory" )
753+
758754
759755def selective_log_softmax_raw (logits : torch .Tensor , input_ids : torch .Tensor ) -> torch .Tensor :
760756 """Fused version of the common `log_softmax -> gather` operation.
0 commit comments