Skip to content

Commit eef1f6c

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Support passing PartitionSpecs to ShapeDtypeStruct when there is a mesh in context.
PiperOrigin-RevId: 761322712
1 parent 56d4cb0 commit eef1f6c

File tree

3 files changed

+35
-10
lines changed

3 files changed

+35
-10
lines changed

jax/_src/api.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2826,17 +2826,30 @@ def __init__(self, shape, dtype, *, sharding=None, weak_type=False):
28262826
if dtype is None:
28272827
raise ValueError("ShapeDtypeStruct: dtype must be specified.")
28282828
self.dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype)
2829-
if sharding is not None and not isinstance(sharding, (Sharding, Layout)):
2829+
if sharding is not None and not isinstance(sharding, (Sharding, Layout, P)):
28302830
raise ValueError(
2831-
"sharding should be an instance of `jax.sharding.Sharding` or"
2831+
"sharding should be an instance of `jax.sharding.Sharding`, "
2832+
"`jax.sharding.PartitionSpec` or"
28322833
f" `jax.experimental.layout.Layout`. Got {sharding} of type"
28332834
f" {type(sharding)}.")
28342835
if (isinstance(sharding, Layout) and
28352836
isinstance(sharding.device_local_layout, AutoLayout)):
28362837
raise TypeError(
28372838
"`DeviceLocalLayout.AUTO` cannot be used in place of a device-local"
28382839
f" layout in a `ShapeDtypeStruct`. Got {sharding}")
2839-
self.sharding = sharding.sharding if isinstance(sharding, Layout) else sharding
2840+
if isinstance(sharding, Layout):
2841+
self.sharding = sharding.sharding
2842+
elif isinstance(sharding, P):
2843+
# TODO(yashkatariya): Should this be abstract mesh?
2844+
cur_mesh = get_concrete_mesh()
2845+
if cur_mesh is None:
2846+
raise TypeError(
2847+
"When specifying PartitionSpec to `ShapeDtypeStruct`, the context"
2848+
" mesh cannot be empty. Please use `jax.sharding.use_mesh` to set"
2849+
" the mesh context.")
2850+
self.sharding = NamedSharding(cur_mesh, sharding)
2851+
else:
2852+
self.sharding = sharding
28402853
self._dll = sharding.device_local_layout if isinstance(sharding, Layout) else None
28412854
self.weak_type = weak_type
28422855

tests/api_test.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4525,13 +4525,6 @@ def foo(x):
45254525
with self.assertRaisesRegex(TypeError, "applied to foo"):
45264526
f_vjp(1.0, 1.0)
45274527

4528-
def test_shapedtypestruct_sharding_error(self):
4529-
with self.assertRaisesRegex(
4530-
ValueError,
4531-
"sharding should be an instance of `jax.sharding.Sharding`."):
4532-
jax.ShapeDtypeStruct((8, 2), np.float32,
4533-
sharding=jax.sharding.PartitionSpec('x'))
4534-
45354528
def test_make_jaxpr_weakref(self):
45364529
class Foo(NamedTuple):
45374530
x: int

tests/pjit_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5006,6 +5006,25 @@ def test_sds_update(self):
50065006
with self.assertRaisesRegex(ValueError, "updating ShapeDtypeStruct"):
50075007
s4.update(sharding=NamedSharding(mesh, P('x')))
50085008

5009+
@jtu.with_explicit_mesh((2, 1), ('x', 'y'), axis_types=(AxisType.Auto,) * 2)
5010+
def test_sds_pspec_input(self, mesh):
5011+
inp = jax.ShapeDtypeStruct((2, 2), np.float32, sharding=P('x'))
5012+
lowered = jax.jit(lambda x: x * 2).lower(inp)
5013+
self.assertIn('num_partitions = 2', lowered.as_text())
5014+
5015+
np_inp = np.arange(4, dtype=np.float32).reshape(2, 2)
5016+
arr = jax.device_put(np_inp, P('x'))
5017+
out = lowered.compile()(arr)
5018+
self.assertArraysEqual(out, np_inp * 2)
5019+
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
5020+
5021+
def test_sds_pspec_no_mesh_ctx_error(self):
5022+
with self.assertRaisesRegex(
5023+
TypeError,
5024+
'When specifying PartitionSpec to `ShapeDtypeStruct`, the context mesh'
5025+
' cannot be empty'):
5026+
jax.ShapeDtypeStruct((2, 2), np.float32, sharding=P('x'))
5027+
50095028

50105029
def spec_regex(s):
50115030
return str(s).replace(r"(", r"\(").replace(r")", r"\)")

0 commit comments

Comments
 (0)