diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index 69e2adc4dc77..a8305e8bd772 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -123,6 +123,18 @@ def build(self) -> sdy.DimensionShardingAttr:
         is_closed=self.is_closed,
         priority=self.priority)
 
+  def __repr__(self):
+    axes_repr = ', '.join(f"'{a}'" for a in self.axes)
+    open_repr = '' if self.is_closed else ', ?'
+    priority_repr = '' if self.priority is None else f'p{self.priority}'
+    return f'SdyDimSharding({{{axes_repr}{open_repr}}}{priority_repr})'
+
+  def _custom_repr(self):
+    axes_repr = ', '.join(f"'{a}'" for a in self.axes)
+    open_repr = '' if self.is_closed else ', ?'
+    priority_repr = '' if self.priority is None else f'p{self.priority}'
+    return f'{{{axes_repr}{open_repr}}}{priority_repr}'
+
 
 @dataclasses.dataclass
 class SdyArraySharding:
@@ -144,6 +156,14 @@ def build(self) -> sdy.TensorShardingAttr:
         mesh_attr,
         [dim_sharding.build() for dim_sharding in self.dimension_shardings])
 
+  def __repr__(self):
+    dim_sharding_repr = ', '.join(dim_sharding._custom_repr() for dim_sharding
+                                  in self.dimension_shardings)
+    device_id_repr = ''
+    if self.logical_device_ids is not None:
+      device_id_repr = f', device_ids={self.logical_device_ids}'
+    return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr})"
+
 
 @util.cache(max_size=4096, trace_context_in_key=False)
 def named_sharding_to_xla_hlo_sharding(
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 3fcc5c81ad91..59f2233caeb0 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -6625,6 +6625,30 @@ def f(x, y):
     lowered_str = jax.jit(f, in_shardings=[AUTO(mesh), AUTO(mesh)]).lower(x, x).as_text()
     self.assertIn('sdy.mesh @mesh = <["x"=8]>', lowered_str)
 
+  def test_array_sharding_repr_with_priority(self):
+    sharding = sharding_impls.SdyArraySharding(
+        mesh_shape=(('data', 4), ('model', 8), ('expert', 2)),
+        dimension_shardings=[
+            sharding_impls.SdyDimSharding(axes=['data', 'expert'], is_closed=True),
+            sharding_impls.SdyDimSharding(axes=['model'], is_closed=False, priority=2)])
+    self.assertEqual(repr(sharding), "SdyArraySharding([{'data', 'expert'}, {'model', ?}p2])")
+
+  def test_array_sharding_repr_with_logical_ids(self):
+    sharding = sharding_impls.SdyArraySharding(
+        mesh_shape=(('data', 4), ('model', 8)),
+        dimension_shardings=[
+            sharding_impls.SdyDimSharding(axes=['data'], is_closed=True)],
+        logical_device_ids=[4, 5, 6, 7, 0, 1, 2, 3])
+    self.assertEqual(repr(sharding),
+                     "SdyArraySharding([{'data'}], "
+                     "device_ids=[4, 5, 6, 7, 0, 1, 2, 3])")
+
+  def test_dimension_sharding_repr(self):
+    dim_sharding = sharding_impls.SdyDimSharding(
+        axes=['data', 'model'], is_closed=False, priority=2)
+    self.assertEqual(repr(dim_sharding),
+                     "SdyDimSharding({'data', 'model', ?}p2)")
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
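
For illustration only, not part of the diff: a minimal sketch of what the new reprs look like when printed, assuming a JAX checkout where the internal module jax._src.sharding_impls exposes SdyDimSharding and SdyArraySharding with the constructor arguments used in the tests above; these are private classes and may change between versions. The expected strings mirror the assertions in pjit_test.py.

    # Sketch: exercising the __repr__ methods added by this change.
    # Assumes jax._src.sharding_impls (internal API) is importable.
    from jax._src import sharding_impls

    dim = sharding_impls.SdyDimSharding(
        axes=['data', 'model'], is_closed=False, priority=2)
    print(repr(dim))  # SdyDimSharding({'data', 'model', ?}p2)

    arr = sharding_impls.SdyArraySharding(
        mesh_shape=(('data', 4), ('model', 8)),
        dimension_shardings=[
            sharding_impls.SdyDimSharding(axes=['data'], is_closed=True),
            sharding_impls.SdyDimSharding(axes=['model'], is_closed=False, priority=2)])
    print(repr(arr))  # SdyArraySharding([{'data'}, {'model', ?}p2])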