24 changes: 11 additions & 13 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1468,24 +1468,22 @@ void InitXlaModuleBindings(py::module m) {
const py::list& replication_groups, int sharding_type,
bool minibatch) {
xla::Shape global_shape =
CreateComputationShapeFromTensor(tensor, nullptr);
if (minibatch) {
int num_local_devices =
runtime::GetComputationClientOrDie()->GetLocalDevices().size();
int num_global_devices =
runtime::GetComputationClientOrDie()->GetAllDevices().size();
XLA_CHECK(tile_assignment.size() == num_global_devices)
<< "Minibatch sharding only supports sharding along the batch "
"dimension";
int batch_dim_shape =
tensor.sizes()[0] * num_global_devices / num_local_devices;
global_shape.set_dimensions(0, batch_dim_shape);
}
ShardingUtil::GetAdjustedGlobalShape(tensor, minibatch);
return std::make_shared<XLATensor::ShardingSpec>(
ShardingUtil::CreateOpSharding(
tile_assignment, group_assignment, replication_groups,
ShardingUtil::ShardingType(sharding_type)),
global_shape, minibatch);
})
.def_init([](at::Tensor tensor, const py::list& dims,
const py::list& reshape_dims, const py::list& transpose_perm,
bool minibatch) {
xla::Shape global_shape =
ShardingUtil::GetAdjustedGlobalShape(tensor, minibatch);
return std::make_shared<XLATensor::ShardingSpec>(
ShardingUtil::CreateIotaOpSharding(dims, reshape_dims,
transpose_perm),
global_shape, minibatch);
});

// Define the _XLAC.IrValue class.
16 changes: 16 additions & 0 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -882,4 +882,20 @@ bool ShardingUtil::GetAutoSharding() {
}
return use_auto_sharding;
}

xla::Shape ShardingUtil::GetAdjustedGlobalShape(const at::Tensor& tensor,
bool minibatch) {
xla::Shape global_shape = CreateComputationShapeFromTensor(tensor, nullptr);
if (minibatch) {
int num_local_devices =
runtime::GetComputationClientOrDie()->GetLocalDevices().size();
int num_global_devices =
runtime::GetComputationClientOrDie()->GetAllDevices().size();
int batch_dim_shape =
tensor.sizes()[0] * num_global_devices / num_local_devices;
global_shape.set_dimensions(0, batch_dim_shape);
}
return global_shape;
}

} // namespace torch_xla
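
The new `ShardingUtil::GetAdjustedGlobalShape` helper factors out the shape logic previously inlined in the binding: it rescales only the leading (batch) dimension, and only when `minibatch` is true. A minimal Python sketch of that arithmetic, with hypothetical device counts that are not taken from this PR:

```python
# Sketch of the minibatch adjustment performed by GetAdjustedGlobalShape above.
# Assumed (hypothetical) setup: 4 addressable devices on this host out of 16
# devices globally, and a per-host tensor whose batch dimension is 8.
def adjusted_global_batch(local_batch: int, num_local_devices: int,
                          num_global_devices: int) -> int:
  # Mirrors: tensor.sizes()[0] * num_global_devices / num_local_devices
  return local_batch * num_global_devices // num_local_devices

assert adjusted_global_batch(8, 4, 16) == 32  # global batch spans all hosts
```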
3 changes: 3 additions & 0 deletions torch_xla/csrc/xla_sharding_util.h
@@ -155,6 +155,9 @@ class ShardingUtil {

static void SetAutoSharding();
static bool GetAutoSharding();

static xla::Shape GetAdjustedGlobalShape(const at::Tensor& tensor,
bool minibatch);
};

} // namespace torch_xla
35 changes: 28 additions & 7 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -154,12 +154,10 @@ def _get_op_sharding_args(self, partition_spec: PartitionSpec):

@functools.lru_cache(maxsize=None)
def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec):
"""
Returns the appropriate dims, reshape_dims, and transpose_perm for the given partition spec.
"""
partition_spec = _translate_named_partition_spec(self, partition_spec)
self._validate_translated_partition_spec(partition_spec)

# 1. Calculate the initial part of dims based on the partition_spec.
dims = []
used_axes = OrderedDict()
for axis in partition_spec:
@@ -175,14 +173,22 @@ def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec):
dims.append(self.mesh_shape[axis])
used_axes[axis] = True
else:
# Replicated mesh axis
dims.append(1)

transpose_perm = [k for k in used_axes.keys()]
# 2. If the product of dims is less than the total number of devices,
# append the sizes of the unused mesh axes.
if math.prod(dims) < math.prod(self.mesh_shape):
for i in range(len(self.mesh_shape)):
if i not in used_axes:
dims.append(self.mesh_shape[i])

# 3. Calculate transpose_perm (sharded axes first, then unused axes).
transpose_perm = list(used_axes.keys())
for i in range(len(self.mesh_shape)):
if i not in used_axes:
dims.append(self.mesh_shape[i])
transpose_perm.append(i)

# 4. reshape_dims is always the physical mesh shape.
reshape_dims = list(self.mesh_shape)

return dims, reshape_dims, transpose_perm
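
For readers following steps 1-4 above, a hedged walk-through on an assumed `(2, 4)` mesh with `partition_spec = (0, None)` (values chosen for illustration, not taken from this PR); the tuple-axis handling elided by the hunk above is ignored:

```python
import math
from collections import OrderedDict

mesh_shape = (2, 4)          # assumed physical mesh
partition_spec = (0, None)   # shard tensor dim 0 on mesh axis 0, replicate dim 1

# Step 1: initial dims from the partition spec.
dims, used_axes = [], OrderedDict()
for axis in partition_spec:
  if axis is None:
    dims.append(1)                       # replicated tensor dimension
  else:
    dims.append(mesh_shape[axis])
    used_axes[axis] = True

# Step 2: fold unused mesh axes into a replicated tail.
if math.prod(dims) < math.prod(mesh_shape):
  dims += [mesh_shape[i] for i in range(len(mesh_shape)) if i not in used_axes]

# Step 3: sharded mesh axes first, then unused axes.
transpose_perm = list(used_axes) + [
    i for i in range(len(mesh_shape)) if i not in used_axes
]

# Step 4: reshape_dims is always the physical mesh shape.
reshape_dims = list(mesh_shape)

assert (dims, reshape_dims, transpose_perm) == ([2, 1, 4], [2, 4], [0, 1])
```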
@@ -591,6 +597,11 @@ def _mark_manual_sharding(
return wrap_as_sharded_tensor(t)


def _use_shlo_to_shardy() -> bool:
return os.environ.get("CONVERT_SHLO_TO_SHARDY",
"").lower() in ("1", "true", "yes")


def enable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor],
partition_spec: PartitionSpec,
*,
@@ -710,7 +721,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
t.shard_(NamedSharding(jmesh, P(*partition_spec)))
return t

if os.environ.get('CONVERT_SHLO_TO_SHARDY', False):
if _use_shlo_to_shardy():
op_sharding = mesh.get_op_sharding_v2(partition_spec)
else:
op_sharding = mesh.get_op_sharding(partition_spec)
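
Note that with the new helper the Shardy path is taken only when `CONVERT_SHLO_TO_SHARDY` is set to `1`, `true`, or `yes` (case-insensitive); previously any non-empty string, including `'0'`, was treated as truthy. A minimal usage sketch, assuming an SPMD-initialized process (mesh shape and tensor sizes are illustrative, not from this PR):

```python
import os
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"  # route mark_sharding through get_op_sharding_v2

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n), (n, 1), ("data", "model"))

t = torch.zeros(16, 128).to(xm.xla_device())
xs.mark_sharding(t, mesh, ("data", None))   # shards the batch dim across "data"
```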
@@ -869,6 +880,9 @@ def __post_init__(self):
self._group_assignment, self._replication_groups = _get_group_assignment(
self._sharding_type, tile_assignment, len(partition_spec),
replicate_dims)
if _use_shlo_to_shardy():
self.dims, self.reshape_dims, self.transpose_dims = mesh._get_op_sharding_args_v2(
partition_spec)

def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
"""
@@ -877,6 +891,13 @@ def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
"""
if not self.can_apply(t):
return None

if _use_shlo_to_shardy():
# Convert to Shardy spec if the environment variable is set.
return torch_xla._XLAC.XlaShardingSpec(t, self.dims, self.reshape_dims,
self.transpose_dims,
self.minibatch)

return torch_xla._XLAC.XlaShardingSpec(t, self._tile_assignment,
self._group_assignment,
self._replication_groups,
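
Taken together with the `__post_init__` change above, a single `ShardingSpec` now yields either flavor of `_XLAC.XlaShardingSpec` depending on the flag: the tile-assignment form by default, or the `(dims, reshape_dims, transpose_dims, minibatch)` form cached from `mesh._get_op_sharding_args_v2(partition_spec)` when `CONVERT_SHLO_TO_SHARDY` is enabled. A hedged sketch of the user-level entry point (mesh and sizes assumed, not from this PR):

```python
# Continuing from the previous sketch (mesh and xm already defined there).
import torch
import torch_xla.distributed.spmd as xs

spec = xs.ShardingSpec(mesh, ("data", None))   # the dataclass modified above
xla_spec = spec.xla_spec(torch.zeros(16, 128).to(xm.xla_device()))
# xla_spec is an _XLAC.XlaShardingSpec built from (dims, reshape_dims,
# transpose_dims) when CONVERT_SHLO_TO_SHARDY is set, otherwise from the
# tile/group/replication assignments; it is None if the spec cannot apply.
```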