Commit db5e2d2
#sdy add repr for Sdy ArraySharding and DimSharding
PiperOrigin-RevId: 713333974
bartchr808 authored and Google-ML-Automation committed Jan 8, 2025
1 parent 5511949 commit db5e2d2
Showing 2 changed files with 44 additions and 0 deletions.
20 changes: 20 additions & 0 deletions jax/_src/sharding_impls.py
@@ -123,6 +123,18 @@ def build(self) -> sdy.DimensionShardingAttr:
        is_closed=self.is_closed,
        priority=self.priority)

  def __repr__(self):
    axes_repr = ', '.join(f"'{a}'" for a in self.axes)
    open_repr = '' if self.is_closed else ', ?'
    priority_repr = '' if self.priority is None else f'p{self.priority}'
    return f'SdyDimSharding({{{axes_repr}{open_repr}}}{priority_repr})'

  def _custom_repr(self):
    axes_repr = ', '.join(f"'{a}'" for a in self.axes)
    open_repr = '' if self.is_closed else ', ?'
    priority_repr = '' if self.priority is None else f'p{self.priority}'
    return f'{{{axes_repr}{open_repr}}}{priority_repr}'


@dataclasses.dataclass
class SdyArraySharding:
@@ -144,6 +156,14 @@ def build(self) -> sdy.TensorShardingAttr:
        mesh_attr,
        [dim_sharding.build() for dim_sharding in self.dimension_shardings])

  def __repr__(self):
    dim_sharding_repr = ', '.join(dim_sharding._custom_repr() for dim_sharding
                                  in self.dimension_shardings)
    device_id_repr = ''
    if self.logical_device_ids is not None:
      device_id_repr = f', device_ids={self.logical_device_ids}'
    return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr})"


@util.cache(max_size=4096, trace_context_in_key=False)
def named_sharding_to_xla_hlo_sharding(
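For illustration, a minimal sketch of what the new reprs produce. The constructor calls mirror the tests added below in tests/pjit_test.py, and the expected strings follow directly from the __repr__ methods above:

    from jax._src import sharding_impls

    dim = sharding_impls.SdyDimSharding(
        axes=['data', 'model'], is_closed=False, priority=2)
    print(repr(dim))  # SdyDimSharding({'data', 'model', ?}p2)

    arr = sharding_impls.SdyArraySharding(
        mesh_shape=(('data', 4), ('model', 8)),
        dimension_shardings=[
            sharding_impls.SdyDimSharding(axes=['data'], is_closed=True)],
        logical_device_ids=[4, 5, 6, 7, 0, 1, 2, 3])
    print(repr(arr))  # SdyArraySharding([{'data'}], device_ids=[4, 5, 6, 7, 0, 1, 2, 3])

Since SdyArraySharding is a @dataclasses.dataclass, the handwritten __repr__ takes precedence over the generated field-by-field repr, and the separate _custom_repr lets it embed each dimension's sharding without repeating the SdyDimSharding class name.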
24 changes: 24 additions & 0 deletions tests/pjit_test.py
@@ -6625,6 +6625,30 @@ def f(x, y):
    lowered_str = jax.jit(f, in_shardings=[AUTO(mesh), AUTO(mesh)]).lower(x, x).as_text()
    self.assertIn('sdy.mesh @mesh = <["x"=8]>', lowered_str)

  def test_array_sharding_repr_with_priority(self):
    sharding = sharding_impls.SdyArraySharding(
        mesh_shape=(('data', 4), ('model', 8), ('expert', 2)),
        dimension_shardings=[
            sharding_impls.SdyDimSharding(axes=['data', 'expert'], is_closed=True),
            sharding_impls.SdyDimSharding(axes=['model'], is_closed=False, priority=2)])
    self.assertEqual(repr(sharding), "SdyArraySharding([{'data', 'expert'}, {'model', ?}p2])")

  def test_array_sharding_repr_with_logical_ids(self):
    sharding = sharding_impls.SdyArraySharding(
        mesh_shape=(('data', 4), ('model', 8)),
        dimension_shardings=[
            sharding_impls.SdyDimSharding(axes=['data'], is_closed=True)],
        logical_device_ids=[4, 5, 6, 7, 0, 1, 2, 3])
    self.assertEqual(repr(sharding),
                     "SdyArraySharding([{'data'}], "
                     "device_ids=[4, 5, 6, 7, 0, 1, 2, 3])")

  def test_dimension_sharding_repr(self):
    dim_sharding = sharding_impls.SdyDimSharding(
        axes=['data', 'model'], is_closed=False, priority=2)
    self.assertEqual(repr(dim_sharding),
                     "SdyDimSharding({'data', 'model', ?}p2)")


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
