@@ -189,7 +189,7 @@ def offload_fsdp_optimizer(optimizer):
         for param in param_group["params"]:
             state = optimizer.state[param]
             for key, value in state.items():
-                if isinstance(value, (torch.Tensor, DTensor)):
+                if isinstance(value, torch.Tensor):
                     state[key] = value.to("cpu", non_blocking=True)


@@ -201,7 +201,7 @@ def load_fsdp_optimizer(optimizer, device_id):
         for param in param_group["params"]:
             state = optimizer.state[param]
             for key, value in state.items():
-                if isinstance(value, (torch.Tensor, DTensor)):
+                if isinstance(value, torch.Tensor):
                     state[key] = value.to(device_id, non_blocking=True)


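Both optimizer hunks above rely on DTensor being a torch.Tensor subclass: isinstance(value, torch.Tensor) already matches the sharded optimizer state that FSDP2 stores as DTensor, so the explicit DTensor branch is redundant. A minimal sketch of the shared pattern, assuming the helper name _move_optimizer_state and a PyTorch build where DTensor lives under torch.distributed.tensor:

import torch
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older PyTorch

# DTensor subclasses torch.Tensor, so a plain torch.Tensor check also
# covers DTensor-valued optimizer state produced by FSDP2.
assert issubclass(DTensor, torch.Tensor)

def _move_optimizer_state(optimizer, device):
    # Hypothetical helper mirroring offload_fsdp_optimizer / load_fsdp_optimizer:
    # move every tensor-valued state entry to the target device.
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            state = optimizer.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device, non_blocking=True)
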
@@ -427,7 +427,7 @@ def fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_

     # rotary_emb is not in state_dict, so we need to broadcast it manually
     for name, buf in model.named_buffers():
-        dist.broadcast(buf, src=0, group=device_mesh.get_group())
+        dist.broadcast(buf, src=0)

     if cpu_offload:
         model.to('cpu', non_blocking=True)
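Dropping the group argument makes the buffer broadcast use the default (world) process group, so src=0 refers to global rank 0, presumably the rank that holds the full state dict, rather than rank 0 of a device-mesh subgroup. A minimal sketch of the same pattern, with broadcast_buffers as an assumed helper name:

import torch
import torch.distributed as dist

def broadcast_buffers(model: torch.nn.Module, src: int = 0):
    # With no group argument, dist.broadcast falls back to the default
    # process group, so src is interpreted as a global rank.
    for _, buf in model.named_buffers():
        dist.broadcast(buf, src=src)
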
@@ -451,7 +451,8 @@ def apply_fsdp2(model, fsdp_kwargs, config):

     modules = []
     for name, module in model.named_modules():
-        if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or isinstance(module, nn.Embedding):
+        if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or \
+            (isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings):
             modules.append(module)

     for idx, module in enumerate(modules):
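The added condition skips giving nn.Embedding its own FSDP2 wrap when the model ties the input embedding to the output projection (model.config.tie_word_embeddings); presumably wrapping only one side of a tied weight pair separately would shard the shared parameter inconsistently, so tied embeddings are left to the enclosing wrap. A minimal sketch of the selection logic, with collect_fsdp2_wrap_modules as an assumed helper name:

import torch.nn as nn

def collect_fsdp2_wrap_modules(model, layer_cls_names, tie_word_embeddings):
    # Wrap transformer blocks by class name; wrap nn.Embedding only when
    # its weight is not tied to the LM head (tied weights stay with the
    # root/enclosing wrap).
    modules = []
    for _, module in model.named_modules():
        is_layer = module.__class__.__name__ in layer_cls_names
        is_untied_embedding = isinstance(module, nn.Embedding) and not tie_word_embeddings
        if is_layer or is_untied_embedding:
            modules.append(module)
    return modules
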