Commit 95be731

Fix visualize_tensor_sharding function for V2 shardings
1 parent 4ca84d9 commit 95be731

File tree

4 files changed (+150 additions, -33 deletions)


test/spmd/test_spmd_debugging.py

Lines changed: 72 additions & 0 deletions
@@ -17,6 +17,7 @@
 import torch_xla.distributed.spmd as xs
 from torch_xla.distributed.spmd import XLAShardedTensor
 from torch_xla.distributed.spmd import Mesh
+from torch_xla.distributed.spmd.debugging import construct_v1_sharding_str
 
 import test_xla_sharding_base
 
@@ -828,6 +829,77 @@ def test_multi_host_replicated_cpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output
 
+
+class ConvertV2ShardingToV1Test(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+    os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"
+
+  def run_test(self):
+    mesh = self._get_mesh(self.device_mesh_shape)
+    t = torch.randn(self.tensor_shape).to(torch_xla.device())
+    xs.mark_sharding(t, mesh, self.partition_spec)
+    actual_str = construct_v1_sharding_str(t)
+    self.assertEqual(self.expected_str, actual_str)
+
+  def test_tiled_sharding(self):
+    self.device_mesh_shape = (1, self.n_devices)
+    self.tensor_shape = (1, 128)
+    self.partition_spec = (0, 1)
+    self.expected_str = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
+        [str(i) for i in range(self.n_devices)]))
+    self.run_test()
+
+  @unittest.skipIf(xr.global_runtime_device_count() < 2,
+                   f"Requires at least 2 devices.")
+  def test_tupled_tiled_sharding(self):
+    self.device_mesh_shape = (2, self.n_devices // 2)
+    self.tensor_shape = (16,)
+    self.partition_spec = ((0, 1),)
+    self.expected_str = "{devices=[%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    self.run_test()
+
+  def test_replicated_sharding(self):
+    self.device_mesh_shape = (1, self.n_devices)
+    self.tensor_shape = (4, 4)
+    self.partition_spec = (None, None)
+    self.expected_str = '{replicated}'
+    self.run_test()
+
+  @unittest.skipIf(xr.global_runtime_device_count() < 4,
+                   f"Requires at least 4 devices.")
+  def test_partial_replication_sharding(self):
+    self.device_mesh_shape = (2, self.n_devices // 2)
+    self.tensor_shape = (4, 4)
+    self.partition_spec = (0, None)
+    self.expected_str = '{devices=[2,1,%d]%s last_tile_dim_replicate}' % (
+        self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
+    self.run_test()
+
+  @unittest.skipIf(xr.global_runtime_device_count() < 4,
+                   f"Requires at least 4 devices.")
+  def test_tupled_partial_replication_sharding(self):
+    self.device_mesh_shape = (1, 2, self.n_devices // 2)
+    self.tensor_shape = (16, 16)
+    self.partition_spec = ((0, 1), None)
+    self.expected_str = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
+        self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
+    self.run_test()
+
+  def test_tupled_partial_replication_sharding_with_transpose(self):
+    self.device_mesh_shape = (1, 2, self.n_devices // 2)
+    self.tensor_shape = (16, 16)
+    self.partition_spec = (None, (2, 1))
+    device_order = self.device_ids.reshape(self.device_mesh_shape).transpose(
+        (2, 1, 0)).flatten()
+    self.expected_str = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in device_order))
+    self.run_test()
+
+
 if __name__ == '__main__':
   test = unittest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)
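
As a concrete instance of what test_tupled_partial_replication_sharding_with_transpose expects, here is a minimal numpy sketch (not part of the commit), assuming 4 devices with ids 0..3 and the (1, 2, 2) mesh shape used by the test:

import numpy as np

# Assumes 4 devices with ids 0..3 and device_mesh_shape (1, 2, 2), as in the test.
device_ids = np.arange(4)
device_order = device_ids.reshape((1, 2, 2)).transpose((2, 1, 0)).flatten()
print(device_order.tolist())  # [0, 2, 1, 3]
expected_str = "{devices=[1,%d]%s}" % (4, ','.join(str(x) for x in device_order))
print(expected_str)  # {devices=[1,4]0,2,1,3}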

test/spmd/test_xla_sharding.py

Lines changed: 14 additions & 20 deletions
@@ -6,7 +6,6 @@
 import unittest
 from unittest.mock import patch
 import sys
-import os
 
 import torch
 from torch import nn
@@ -27,16 +26,12 @@
 from torch_xla._internal import tpu
 
 
-def should_convert_to_shardy():
-  return os.environ.get("CONVERT_SHLO_TO_SHARDY",
-                        "").lower() in ("1", "true", "yes")
-
-
 class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):
 
   @classmethod
   def setUpClass(cls):
     super().setUpClass()
+    cls.convert_to_shardy = xu.check_env_flag("CONVERT_SHLO_TO_SHARDY")
 
   def test_xla_sharded_tensor(self):
     partition_spec = (0, 1)
@@ -244,7 +239,7 @@ def test_custom_tile_assignment(self):
     if self.n_devices > 1:
       annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
           [str(i) for i in reversed(range(self.n_devices))]))
-      if should_convert_to_shardy():
+      if self.convert_to_shardy:
         annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
     self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))
@@ -260,7 +255,7 @@ def test_mark_sharding_2d(self):
     if self.n_devices > 1:
       annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
           [str(i) for i in range(self.n_devices)]))
-      if should_convert_to_shardy():
+      if self.convert_to_shardy:
         annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
     self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt1))
@@ -281,7 +276,7 @@ def test_mark_sharding_4d(self):
       annotation = '{devices=[1,1,%d,%d]%s}' % (
           z_dim, self.n_devices // z_dim, ','.join(
              [str(i) for i in range(self.n_devices)]))
-      if should_convert_to_shardy():
+      if self.convert_to_shardy:
        annotation = '{devices=[1,1,%d,%d]<=[%d]}' % (z_dim, self.n_devices //
                                                      z_dim, self.n_devices)
     self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))
@@ -418,7 +413,7 @@ def test_tupled_partition_spec(self):
     xs.mark_sharding(t, mesh, ((0, 1),))
     annotation = "{devices=[%d]%s}" % (self.n_devices, ','.join(
         str(x) for x in range(self.n_devices)))
-    if should_convert_to_shardy():
+    if self.convert_to_shardy:
       annotation = "{devices=[%d]<=[%d]}" % (self.n_devices, self.n_devices)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)
@@ -432,7 +427,7 @@ def test_named_partial_tupled_partition_spec(self):
     xs.mark_sharding(t, mesh, (('r', 'b'), None))
     annotation = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
         self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
-    if should_convert_to_shardy():
+    if self.convert_to_shardy:
       annotation = "{devices=[2,1,%d]<=[%d] last_tile_dim_replicate}" % (
           self.n_devices // 2, self.n_devices)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)
@@ -442,7 +437,7 @@ def test_named_partial_tupled_partition_spec(self):
     xs.mark_sharding(u, mesh, (None, ('b', 'm')))
     annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
         str(x) for x in range(self.n_devices)))
-    if should_convert_to_shardy():
+    if self.convert_to_shardy:
       annotation = "{devices=[1,%d]<=[%d]}" % (self.n_devices, self.n_devices)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(u), annotation)
@@ -452,7 +447,7 @@ def test_named_partial_tupled_partition_spec(self):
     device_order = mesh.get_logical_mesh().transpose((0, 2, 1)).flatten()
     annotation = "{devices=[1,%d,2]%s last_tile_dim_replicate}" % (
         self.n_devices // 2, ','.join(str(x) for x in device_order))
-    if should_convert_to_shardy():
+    if self.convert_to_shardy:
       annotation = "{devices=[1,%d,2]<=[2,%d]T(1,0) last_tile_dim_replicate}" % (
           self.n_devices // 2, self.n_devices // 2)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)
@@ -463,7 +458,7 @@ def test_named_partial_tupled_partition_spec(self):
     device_order = mesh.get_logical_mesh().transpose((2, 1, 0)).flatten()
     annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
         str(x) for x in device_order))
-    if should_convert_to_shardy():
+    if self.convert_to_shardy:
       annotation = "{devices=[1,%d]<=[2,%d]T(1,0)}" % (self.n_devices,
                                                        self.n_devices // 2)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)
@@ -478,7 +473,7 @@ def test_multiple_tuples_in_spec(self):
     xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd')))
     annotation = "{devices=[2,%d]%s}" % (self.n_devices // 2, ','.join(
         str(x) for x in range(self.n_devices)))
-    if should_convert_to_shardy():
+    if self.convert_to_shardy:
       annotation = "{devices=[2,%d]<=[%d]}" % (self.n_devices // 2,
                                                self.n_devices)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)
@@ -491,7 +486,7 @@ def test_3d_tensor_2d_mesh(self):
     xs.mark_sharding(t, mesh, (None, 0, 1))
     annotation = '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ','.join(
         str(x) for x in range(self.n_devices)))
-    if should_convert_to_shardy():
+    if self.convert_to_shardy:
       annotation = '{devices=[1,2,%d]<=[%d]}' % (self.n_devices // 2,
                                                  self.n_devices)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)
@@ -1013,8 +1008,7 @@ def test_op_sharding_cache(self):
 
     t = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(t, mesh, (0, 1))
-    counter_name = "CreateIotaOpSharding" if should_convert_to_shardy(
-    ) else "CreateOpSharding"
+    counter_name = "CreateIotaOpSharding" if self.convert_to_shardy else "CreateOpSharding"
     self.assertIn(counter_name, met.counter_names())
     self.assertEqual(met.counter_value(counter_name), 1)
@@ -1435,7 +1429,7 @@ def test_data_loader_with_sharding(self):
     data, _ = iter(train_device_loader).__next__()
     self.assertEqual(data.size(), torch.Size([8, 3, 64, 64]))
     annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
-    if should_convert_to_shardy():
+    if self.convert_to_shardy:
       annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)
@@ -1458,7 +1452,7 @@ def test_data_loader_with_non_batch_size(self):
     data, _ = iter(train_device_loader).__next__()
     self.assertEqual(data.size(), torch.Size([mesh.size() - 1, 3, 64, 64]))
     annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
-    if should_convert_to_shardy():
+    if self.convert_to_shardy:
       annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)
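
The two annotation families these assertions switch between describe the same placement: the V1 form spells out the device order, while the V2 iota form <=[reshape]T(perm) encodes it as a reshaped, transposed iota. A small numpy sketch, assuming 4 devices, expands the <=[2,2]T(1,0) notation used above into its V1 device order:

import numpy as np

# Expand the V2 iota notation <=[2,2]T(1,0) into an explicit device order:
# reshape an iota over 4 devices to (2, 2), transpose with perm (1, 0), flatten.
device_order = np.arange(4).reshape(2, 2).transpose(1, 0).flatten()
print(device_order.tolist())  # [0, 2, 1, 3]
# Hence '{devices=[1,4]<=[2,2]T(1,0)}' corresponds to V1 '{devices=[1,4]0,2,1,3}'.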

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 39 additions & 11 deletions
@@ -82,8 +82,6 @@
 #include "xla/pjrt/distributed/distributed.h"
 #include "xla/python/profiler/internal/traceme_wrapper.h"
 
-#define PYBIND11_DETAILED_ERROR_MESSAGES
-
 namespace torch_xla {
 namespace {
 
@@ -754,6 +752,16 @@ std::string GetTensorsHloGraph(const std::vector<at::Tensor>& tensors,
   return XLAGraphExecutor::Get()->DumpHloComputation(xtensors, mode);
 }
 
+std::optional<xla::OpSharding> GetXLAOpSharding(const at::Tensor& input) {
+  XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+  XLATensor::ShardingSpecPtr sharding_spec =
+      xtensor ? xtensor->sharding_spec() : nullptr;
+  if (sharding_spec != nullptr) {
+    return sharding_spec->sharding;
+  }
+  return std::nullopt;
+}
+
 std::string GetXLAShardingSpec(const XLATensorPtr xtensor) {
   auto sharding_spec = xtensor->sharding_spec();
   if (sharding_spec != nullptr) {
@@ -1519,6 +1527,10 @@ at::Tensor tensor_fromDLPack(PyObject* data) {
 void InitXlaModuleBindings(py::module m) {
   PythonScope<py::module> module(m);
 
+  using TileAssignmentDims = std::vector<int64_t>;
+  using ReshapeDims = std::vector<int64_t>;
+  using TransposePerm = std::vector<int>;
+
   // Define the _XLAC.XlaShardingSpec class.
   PythonScope<py::class_<XLATensor::ShardingSpec, XLATensor::ShardingSpecPtr>>(
       m, "XlaShardingSpec")
@@ -1799,7 +1811,8 @@ void InitXlaModuleBindings(py::module m) {
             }
           })
       .def("_xla_get_runtime_devices",
-           []() { return runtime::GetComputationClientOrDie()->GetLocalDevices(); })
+           []() {
+             return runtime::GetComputationClientOrDie()->GetLocalDevices(); })
       .def("_xla_num_runtime_devices",
           []() -> int64_t {
            return runtime::GetComputationClientOrDie()->GetNumLocalDevices();
@@ -2219,9 +2232,11 @@ void InitXlaModuleBindings(py::module m) {
            return device.ordinal();
          })
       .def("_xla_get_process_index",
-           []() { return runtime::GetComputationClientOrDie()->GetProcessIndex(); })
+           []() {
+             return runtime::GetComputationClientOrDie()->GetProcessIndex(); })
       .def("_xla_get_num_processes",
-           []() { return runtime::GetComputationClientOrDie()->GetNumProcesses(); })
+           []() {
+             return runtime::GetComputationClientOrDie()->GetNumProcesses(); })
       .def("_xla_get_num_cached_compilation_graph",
           []() -> int64_t {
            return XLAGraphExecutor::Get()->GetNumGraphHash();
@@ -2653,13 +2668,26 @@ void InitXlaModuleBindings(py::module m) {
          })
      .def("_get_xla_op_sharding",
           [](const at::Tensor& input) -> std::optional<xla::OpSharding> {
-             XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(input));
-             XLATensor::ShardingSpecPtr sharding_spec =
-                 xtensor ? xtensor->sharding_spec() : nullptr;
-             if (sharding_spec != nullptr) {
-               return sharding_spec->sharding;
+             return GetXLAOpSharding(input);
+           })
+      .def("_get_xla_op_sharding_v2_params",
+           [](const at::Tensor& input) -> std::optional<std::tuple<TileAssignmentDims, ReshapeDims, TransposePerm, bool>> {
+             std::optional<xla::OpSharding> maybe_sharding =
+                 GetXLAOpSharding(input);
+             if (!maybe_sharding) {
+               return std::nullopt;
             }
-             return std::nullopt;
+             const xla::OpSharding& sharding = maybe_sharding.value();
+             TileAssignmentDims tile_assignment_dims(
+                 sharding.tile_assignment_dimensions().begin(),
+                 sharding.tile_assignment_dimensions().end());
+             ReshapeDims reshape_dims(sharding.iota_reshape_dims().begin(),
+                                      sharding.iota_reshape_dims().end());
+             TransposePerm transpose_perm(sharding.iota_transpose_perm().begin(),
+                                          sharding.iota_transpose_perm().end());
+             return std::make_tuple(tile_assignment_dims, reshape_dims,
+                                    transpose_perm,
+                                    sharding.replicate_on_last_tile_dim());
          })
      .def("_get_xla_sharding_specs",
           [](const std::vector<at::Tensor>& tensors)
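
For reference, a rough Python-side sketch (not from the commit) of how the new _get_xla_op_sharding_v2_params binding can be consumed. It assumes an SPMD-enabled process with CONVERT_SHLO_TO_SHARDY set, mirroring the test setup above; the binding returns None when the tensor carries no sharding spec, and the concrete values in the final comment are only illustrative:

import os
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"  # enable the V2 (iota) sharding path, as in the tests

import numpy as np
import torch
import torch_xla
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))
t = torch.randn(1, 128).to(torch_xla.device())
xs.mark_sharding(t, mesh, (0, 1))

params = torch_xla._XLAC._get_xla_op_sharding_v2_params(t)
if params is not None:
  tile_assignment_dims, reshape_dims, transpose_perm, replicate_last_dim = params
  # For this layout one would expect something like:
  #   tile_assignment_dims=[1, num_devices], reshape_dims=[num_devices],
  #   transpose_perm=[0], replicate_last_dim=False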

torch_xla/distributed/spmd/debugging.py

Lines changed: 25 additions & 2 deletions
@@ -2,7 +2,7 @@
 import functools
 import string
 import sys
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, Union, Tuple
 import weakref
 
 import numpy as np
@@ -157,12 +157,35 @@ def visualize_sharding(sharding: str,
   return table
 
 
+def construct_v1_sharding_str(t: torch.Tensor) -> str:
+  """
+  Returns the corresponding HLO V1 sharding string from the tensor
+  """
+  sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
+  if "<=" not in sharding:
+    # This is already in the V1 format
+    return sharding
+  sharding_params = torch_xla._XLAC._get_xla_op_sharding_v2_params(t)
+  assert sharding_params is not None
+  tile_assignment_dims, reshape_dims, transpose_perm, replicate_on_last_dim = sharding_params
+  num_devices = np.prod(reshape_dims)
+  device_list = np.arange(num_devices).reshape(reshape_dims).transpose(
+      transpose_perm).reshape(num_devices)
+
+  tile_assignment_str = ",".join(str(dim) for dim in tile_assignment_dims)
+  device_list_str = ",".join(str(i) for i in device_list)
+  replicate_str = " last_tile_dim_replicate" if replicate_on_last_dim else ""
+  return f"{{devices=[{tile_assignment_str}]{device_list_str}{replicate_str}}}"
+
+
 def visualize_tensor_sharding(t, **kwargs):
   """Visualizes an array's sharding."""
 
   # XLAShardedTensor is-a torch.Tensor
   def maybe_unwrap(t: torch.Tensor) -> torch.Tensor:
     return t.global_tensor if isinstance(t, XLAShardedTensor) else t
 
-  sharding = torch_xla._XLAC._get_xla_sharding_spec(maybe_unwrap(t))
+  t = maybe_unwrap(t)
+  sharding = construct_v1_sharding_str(t)
+
   return visualize_sharding(sharding, **kwargs)
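
As a worked example of the conversion construct_v1_sharding_str performs, here is a minimal standalone sketch with assumed V2 parameters (tile assignment [2,1,2], iota reshape [4], transpose perm [0], last-tile replication), mirroring the logic added above:

import numpy as np

# Assumed V2 sharding parameters for a 4-device, partially replicated layout.
tile_assignment_dims = [2, 1, 2]
reshape_dims = [4]
transpose_perm = [0]
replicate_on_last_dim = True

# Recover the explicit device order from the iota description.
num_devices = int(np.prod(reshape_dims))
device_list = np.arange(num_devices).reshape(reshape_dims).transpose(
    transpose_perm).reshape(num_devices)

tile_str = ",".join(str(d) for d in tile_assignment_dims)
device_str = ",".join(str(i) for i in device_list)
suffix = " last_tile_dim_replicate" if replicate_on_last_dim else ""
print(f"{{devices=[{tile_str}]{device_str}{suffix}}}")
# -> {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}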
