Skip to content

Commit 783777d

Browse files
committed
Fix DiLoCo with DTensor
1 parent b84c5a6 commit 783777d

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchft/local_sgd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def _restore_parameters(self) -> None:
249249
# we averaged the local version of the tensor so need to copy it back as a DTensor
250250
p.data.copy_(
251251
DTensor.from_local(
252-
self.original_parameters[name], p.device_mesh, p.placements
252+
self.original_parameters[name], p.device_mesh, p.placements, shape=p.shape, stride=p.stride()
253253
),
254254
non_blocking=False,
255255
)

0 commit comments

Comments
 (0)