Skip to content

Commit 0064095

Browse files
authored
Fix the FSDP extension to make load_state_dict works for 2D. (#570)
1 parent 3cb9137 commit 0064095

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

spmd/tensor/parallel/fsdp.py

+1
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ def _pre_load_state_dict(
309309
if len(shards) == 1 and type(shards[0].tensor) is ShardedTensor:
310310
inner_tensor = cast(ShardedTensor, shards[0].tensor)
311311
shards = inner_tensor.local_shards()
312+
tensor = inner_tensor
312313

313314
return (tensor, shards if len(shards) > 0 else [])
314315

0 commit comments

Comments
 (0)