Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

In progress experimention for supporting JAX Arrays with variable-width strings (i.e., with dtype = StringDType). #25535

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion jax/_src/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@

import numpy as np

try:
import numpy.dtypes as np_dtypes # pylint: disable=g-import-not-at-top
except ImportError:
np_dtypes = None # type: ignore

from jax._src import core
from jax._src import dtypes

Expand Down Expand Up @@ -54,7 +59,8 @@ def masked_array_error(*args, **kwargs):

def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
if not (hasattr(np_dtypes, "StringDType") and isinstance(dtype, np_dtypes.StringDType)): # type: ignore
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))

core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
Expand Down
10 changes: 9 additions & 1 deletion jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@

import numpy as np

try:
import numpy.dtypes as np_dtypes
except ImportError:
np_dtypes = None # type: ignore

from jax._src import core
from jax._src import dtypes
from jax._src.abstract_arrays import numpy_scalar_types
Expand Down Expand Up @@ -614,7 +619,10 @@ def _str_abstractify(x):

def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)

if not (hasattr(np_dtypes, "StringDType") and isinstance(dtype, np_dtypes.StringDType)): # type: ignore
dtypes.check_valid_dtype(dtype)

return ShapedArray(x.shape,
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
_shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify
Expand Down
13 changes: 13 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@
PartitionSpec as P)
from jax.tree_util import tree_flatten, tree_leaves, tree_map
import numpy as np

try:
from numpy import dtypes as np_dtypes
except ImportError:
np_dtypes = None # type: ignore
import opt_einsum

export = set_module('jax.numpy')
Expand Down Expand Up @@ -5572,6 +5577,14 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
# Keep the output uncommitted.
return jax.device_put(object)

# 2DO: Comment.
if isinstance(object, np.ndarray) and (np_dtypes is not None) and (getattr(np_dtypes, "StringDType", None) is not None) and (isinstance(object.dtype, np_dtypes.StringDType)): # type: ignore
if (ndmin > 0) and (ndmin != object.ndim):
raise TypeError(
f"ndmin {ndmin} does not match ndims {object.ndim} of input array"
)
return jax.device_put(x=object, device=device)

# For Python scalar literals, call coerce_to_array to catch any overflow
# errors. We don't use dtypes.is_python_scalar because we don't want this
# triggering for traced values. We do this here because it matters whether or
Expand Down
Loading