Merged
36 changes: 18 additions & 18 deletions autoparallel/apply_sharding.py
@@ -9,31 +9,31 @@
 from torch.distributed.tensor import DTensor
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
-from torch.distributed.tensor.placement_types import Partial, Replicate, Shard
+from torch.distributed.tensor.placement_types import Partial, Replicate, Shard # noqa
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.utils._pytree import tree_flatten, tree_map_only


 def my_redistribute_local_tensor(arg, curr_spec, tgt_spec):
-    if curr_spec.placements == (Shard(0), Shard(0)) and tgt_spec.placements == (
-        Replicate(),
-        Shard(0),
-    ):
-        # TODO: double-check in which cases this is valid
-        x = curr_spec.placements[0]._to_replicate_tensor(
-            arg, curr_spec.mesh, 0, curr_spec.shape
-        )
-    elif curr_spec.placements == (Partial(), Shard(0)) and tgt_spec.placements == (
-        Shard(0),
-        Shard(0),
-    ):
-        x = curr_spec.placements[0]._reduce_shard_value(
-            arg, curr_spec.mesh, 0, tgt_spec.placements[0]
-        )
+    # if curr_spec.placements == (Shard(0), Shard(0)) and tgt_spec.placements == (
+    #     Replicate(),
+    #     Shard(0),
+    # ):
+    #     # TODO: double-check in which cases this is valid
+    #     x = curr_spec.placements[0]._to_replicate_tensor(
+    #         arg, curr_spec.mesh, 0, curr_spec.shape
+    #     )
+    # elif curr_spec.placements == (Partial(), Shard(0)) and tgt_spec.placements == (
+    #     Shard(0),
+    #     Shard(0),
+    # ):
+    #     x = curr_spec.placements[0]._reduce_shard_value(
+    #         arg, curr_spec.mesh, 0, tgt_spec.placements[0]
+    #     )
     # elif curr_spec.placements == (Partial(), Shard(1)) and tgt_spec.placements == (Replicate(), Shard(1)):
     #     from IPython import embed; embed(); sys.sdf
-    else:
-        x = redistribute_local_tensor(arg, curr_spec, tgt_spec)
+    # else:
+    x = redistribute_local_tensor(arg, curr_spec, tgt_spec)
     return x


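Net effect of this hunk: the hand-rolled fast paths for (Shard(0), Shard(0)) -> (Replicate(), Shard(0)) and (Partial(), Shard(0)) -> (Shard(0), Shard(0)) are commented out, so every placement transition now goes through DTensor's generic redistribute_local_tensor. A minimal sketch of what the helper reduces to after this change (names taken from the hunk above; not a verbatim copy of the final module):

from torch.distributed.tensor._redistribute import redistribute_local_tensor


def my_redistribute_local_tensor(arg, curr_spec, tgt_spec):
    # arg: the rank-local tensor backing a DTensor
    # curr_spec / tgt_spec: DTensorSpec objects describing the current and
    # target mesh placements; the generic path handles all transitions.
    return redistribute_local_tensor(arg, curr_spec, tgt_spec)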