
Commit 33faad6

lxg2015 authored and lixiaoguang12 committed
remove DTensor check for fsdp2
1 parent daf49a8 commit 33faad6

File tree

1 file changed: +5 -4 lines changed


verl/utils/fsdp_utils.py

Lines changed: 5 additions & 4 deletions
@@ -189,7 +189,7 @@ def offload_fsdp_optimizer(optimizer):
         for param in param_group["params"]:
             state = optimizer.state[param]
             for key, value in state.items():
-                if isinstance(value, (torch.Tensor, DTensor)):
+                if isinstance(value, torch.Tensor):
                     state[key] = value.to("cpu", non_blocking=True)
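The narrowed isinstance check still covers FSDP2's sharded optimizer state: DTensor is a subclass of torch.Tensor, so testing against torch.Tensor alone matches both plain and distributed tensors, and the separate DTensor check becomes redundant. A minimal sketch, not part of this commit, assuming a PyTorch release where DTensor lives under torch.distributed.tensor:

import torch
from torch.distributed.tensor import DTensor  # older releases expose it under torch.distributed._tensor

# DTensor subclasses torch.Tensor, so isinstance(value, torch.Tensor) already
# matches sharded optimizer state such as Adam's exp_avg / exp_avg_sq.
print(issubclass(DTensor, torch.Tensor))  # True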
@@ -201,7 +201,7 @@ def load_fsdp_optimizer(optimizer, device_id):
         for param in param_group["params"]:
             state = optimizer.state[param]
             for key, value in state.items():
-                if isinstance(value, (torch.Tensor, DTensor)):
+                if isinstance(value, torch.Tensor):
                     state[key] = value.to(device_id, non_blocking=True)
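For context, a hypothetical usage sketch pairing the two helpers above; the toy Linear module and AdamW optimizer are assumptions standing in for the FSDP2-wrapped model, not part of the commit:

import torch
from verl.utils.fsdp_utils import offload_fsdp_optimizer, load_fsdp_optimizer

model = torch.nn.Linear(8, 8).cuda()                      # stand-in for the FSDP2 model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model(torch.randn(2, 8, device="cuda")).sum().backward()
optimizer.step()                                          # populate exp_avg / exp_avg_sq state

offload_fsdp_optimizer(optimizer)                                        # optimizer state -> CPU
load_fsdp_optimizer(optimizer, device_id=torch.cuda.current_device())    # optimizer state -> GPU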
@@ -427,7 +427,7 @@ def fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_
 
     # rotary_emb is not in state_dict, so we need to broadcast it manually
     for name, buf in model.named_buffers():
-        dist.broadcast(buf, src=0, group=device_mesh.get_group())
+        dist.broadcast(buf, src=0)
 
     if cpu_offload:
         model.to('cpu', non_blocking=True)
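With the group argument dropped, the buffer sync now runs over the default process group, with src=0 meaning global rank 0. A standalone sketch of the same pattern, assuming torch.distributed has already been initialized (e.g. via torchrun); the helper name is illustrative, not from verl:

import torch
import torch.distributed as dist

def broadcast_buffers(model: torch.nn.Module) -> None:
    # Buffers (e.g. rotary embedding frequencies) are not in the state_dict,
    # so copy rank 0's values to every other rank explicitly.
    for _, buf in model.named_buffers():
        dist.broadcast(buf, src=0)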
@@ -451,7 +451,8 @@ def apply_fsdp2(model, fsdp_kwargs, config):
 
     modules = []
     for name, module in model.named_modules():
-        if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or isinstance(module, nn.Embedding):
+        if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or \
+            (isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings):
             modules.append(module)
 
     for idx, module in enumerate(modules):
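The added condition stops giving nn.Embedding its own FSDP2 unit when the model ties its input embedding to the output head (model.config.tie_word_embeddings), presumably so the shared weight is not sharded separately from the module that reuses it. A sketch of the updated selection rule, assuming a Hugging Face-style config; the function name is illustrative, not the verl API:

import torch.nn as nn

def select_modules_to_wrap(model, fsdp_transformer_layer_cls_to_wrap):
    # Collect transformer blocks, plus the embedding only when it is untied.
    modules = []
    for _, module in model.named_modules():
        is_layer = module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap
        untied_embedding = isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings
        if is_layer or untied_embedding:
            modules.append(module)
    return modules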
