Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion jax/_src/numpy/array_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
object = xc._xla.cuda_array_interface_to_buffer(
cai=cai, gpu_backend=backend, device_id=device_id)

leaves, treedef = tree_util.tree_flatten(object, is_leaf=lambda x: x is None)
# To handle nested lists & tuples, flatten the tree and process each leaf.
leaves, treedef = tree_util.tree_flatten(
object, is_leaf=lambda x: not isinstance(x, (list, tuple)))
if any(leaf is None for leaf in leaves):
raise ValueError("None is not a valid value for jnp.array")
leaves = [
Expand Down
49 changes: 44 additions & 5 deletions tests/array_extensibility_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import functools
from typing import Any, NamedTuple
from collections.abc import Callable
import dataclasses

from absl.testing import absltest
from absl.testing import parameterized
Expand All @@ -30,19 +31,38 @@
config.parse_flags_with_absl()


@functools.partial(jax.tree_util.register_dataclass,
data_fields=['x'],
meta_fields=[])
@jax.tree_util.register_dataclass
@dataclasses.dataclass
class JaxArrayWrapper:
"""Class that provides a __jax_array__ method."""
x: ArrayLike

def __init__(self, x: ArrayLike):
self.x = x
def __jax_array__(self) -> jax.Array:
return jnp.asarray(self.x)


@jax.tree_util.register_dataclass
@dataclasses.dataclass
class NumpyArrayWrapper:
"""Pytree that provides an __array__ method."""
x: ArrayLike

def __array__(self, dtype=None, copy=None) -> jax.Array:
return np.asarray(self.x, dtype=dtype, copy=copy)


@jax.tree_util.register_dataclass
@dataclasses.dataclass
class JaxArrayWrapperWithErroringNumpyArray:
"""Pytree that provides an __array__ method which fails."""
x: ArrayLike

def __jax_array__(self) -> jax.Array:
return jnp.asarray(self.x)

def __array__(self, dtype=None, copy=None) -> jax.Array:
raise ValueError("__array__ method should not be called.")


class DuckTypedArrayWithErroringJaxArray:
"""Duck-typed array that provides a __jax_array__ method which fails."""
Expand Down Expand Up @@ -533,6 +553,25 @@ def test_numpy_api_supports_jax_array(self, api):

self.assertAllClose(wrapped, expected, atol=0, rtol=0)

@jtu.sample_product(
api=['array', 'asarray'],
test_class=[
JaxArrayWrapper,
NumpyArrayWrapper,
JaxArrayWrapperWithErroringNumpyArray,
],
)
def test_array_creation(self, api, test_class):
"""Test pytrees with __jax_array__ and/or __array__ methods."""
fun = getattr(jnp, api)
x = np.arange(5, dtype='float32')

expected = fun(x)
actual = fun(test_class(x))

self.assertIsInstance(actual, jax.Array)
self.assertAllClose(actual, expected, atol=0, rtol=0)

@parameterized.named_parameters(
{'testcase_name': func.__name__, 'func': func}
for func in [jnp.zeros_like, jnp.ones_like, jnp.empty_like, jnp.full_like]
Expand Down
Loading