We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3cb9137 commit 0064095Copy full SHA for 0064095
spmd/tensor/parallel/fsdp.py
@@ -309,6 +309,7 @@ def _pre_load_state_dict(
309
if len(shards) == 1 and type(shards[0].tensor) is ShardedTensor:
310
inner_tensor = cast(ShardedTensor, shards[0].tensor)
311
shards = inner_tensor.local_shards()
312
+ tensor = inner_tensor
313
314
return (tensor, shards if len(shards) > 0 else [])
315
0 commit comments