Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions tests/hijax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import jax
import jax.numpy as jnp
from jax import typeof
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

from jax._src import config
from jax._src import core
Expand All @@ -42,6 +43,9 @@

config.parse_flags_with_absl()

if not jax._src.xla_bridge.backends_are_initialized():
jax.config.update('jax_num_cpu_devices', 8)

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

Expand Down Expand Up @@ -1140,6 +1144,32 @@ def f(box):
compiled(box)
self.assertAllClose(box.get(), 4.)

def test_jit_box_with_sharded_array(self):
# Create a mesh and sharded array
mesh = jax.make_mesh((4, 2), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))

# Create a sharded array
x = jnp.arange(16.).reshape(8, 2)
sharded_x = jax.device_put(x, sharding)

# Put sharded array in a Box
box = Box(sharded_x)

@jax.jit
def f(box, y):
val = box.get()
result = val + y
box.set(result)
return box.get()

# Test with Box argument containing sharded array
y = jnp.ones((8, 2))
result = f(box, y)

self.assertAllClose(result, x + y)
self.assertAllClose(box.get(), x + y)


class RefTest(jtu.JaxTestCase):

Expand Down