Skip to content

Commit 39b57e1

Browse files
sharadmvGoogle-ML-Automation
authored andcommitted
fix-forward for pallas tpu memory spaces test
PiperOrigin-RevId: 770929681
1 parent e04cc28 commit 39b57e1

File tree

3 files changed

+8
-1
lines changed

3 files changed

+8
-1
lines changed

jax/_src/pallas/mosaic/pallas_call_registration.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def _get_memory_space_from_aval(
8383
return None
8484
case tpu_core.MemorySpace.ANY:
8585
return None
86+
case tpu_core.MemorySpace.HBM:
87+
return tpu_custom_call.MemorySpace.HBM
8688
case tpu_core.MemorySpace.VMEM:
8789
return tpu_custom_call.MemorySpace.VMEM
8890
case tpu_core.MemorySpace.SMEM:

jax/_src/pallas/mosaic/primitives.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,8 @@ def with_memory_space_constraint(
802802
Returns:
803803
The array `x` with the memory space constraint.
804804
"""
805+
if memory_space in {tpu_core.MemorySpace.ANY, pl_core.MemorySpace.ANY}:
806+
return x
805807
if memory_space not in {tpu_core.MemorySpace.HBM, tpu_core.MemorySpace.VMEM}:
806808
raise NotImplementedError(
807809
"with_memory_space_constraint only supports HBM and VMEM."

tests/pallas/tpu_pallas_memory_space_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@ def g(x):
5555
@jax.jit
5656
def f(x):
5757
x = pltpu.with_memory_space_constraint(x, memory_space=memory_space)
58-
self.assertEqual(pltpu.get_memory_space(x), memory_space)
58+
if color is None:
59+
self.assertIsNone(pltpu.get_memory_space(x))
60+
else:
61+
self.assertEqual(pltpu.get_memory_space(x), memory_space)
5962
x = g(x)
6063
return x
6164

0 commit comments

Comments
 (0)