Skip to content

jax.numpy: make type stubs consistent with runtime #28721

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 15, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 31 additions & 32 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,8 @@ def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ...
def broadcast_shapes(*shapes: Sequence[int | _core.Tracer]
) -> tuple[int | _core.Tracer, ...]: ...

def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: ...
def broadcast_to(array: ArrayLike, shape: DimSize | Shape, *,
out_sharding: NamedSharding | P | None = None) -> Array: ...
c_: _CClass
can_cast = _np.can_cast
def cbrt(x: ArrayLike, /) -> Array: ...
Expand All @@ -267,6 +268,7 @@ def clip(
/,
min: ArrayLike | None = ...,
max: ArrayLike | None = ...,
*,
a: ArrayLike | DeprecatedArg | None = ...,
a_min: ArrayLike | DeprecatedArg | None = ...,
a_max: ArrayLike | DeprecatedArg | None = ...
Expand All @@ -278,7 +280,7 @@ complex128: Any
complex64: Any
complex_: Any
complexfloating = _np.complexfloating
def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = ...,
def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = ..., *,
size: int | None = ..., fill_value: ArrayLike = ..., out: None = ...) -> Array: ...
def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: ...
def concatenate(
Expand Down Expand Up @@ -314,9 +316,9 @@ def cross(
axis: int | None = ...,
) -> Array: ...
csingle: Any
def cumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ...,
def cumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike | None = ...,
out: None = ...) -> Array: ...
def cumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ...,
def cumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike | None = ...,
out: None = ...) -> Array: ...
def cumulative_prod(x: ArrayLike, /, *, axis: int | None = ...,
dtype: DTypeLike | None = ...,
Expand Down Expand Up @@ -371,7 +373,6 @@ def einsum(
optimize: str | builtins.bool | list[tuple[int, ...]] = ...,
precision: PrecisionLike = ...,
preferred_element_type: DTypeLike | None = ...,
_use_xeinsum: builtins.bool = False,
_dot_general: Callable[..., Array] = ...,
out_sharding: NamedSharding | P | None = ...,
) -> Array: ...
Expand All @@ -385,7 +386,6 @@ def einsum(
optimize: str | builtins.bool | list[tuple[int, ...]] = ...,
precision: PrecisionLike = ...,
preferred_element_type: DTypeLike | None = ...,
_use_xeinsum: builtins.bool = False,
_dot_general: Callable[..., Array] = ...,
out_sharding: NamedSharding | P | None = ...,
) -> Array: ...
Expand All @@ -397,7 +397,6 @@ def einsum(
optimize: str | builtins.bool | list[tuple[int, ...]] = ...,
precision: PrecisionLike = ...,
preferred_element_type: DTypeLike | None = ...,
_use_xeinsum: builtins.bool = ...,
_dot_general: Callable[..., Array] = ...,
out_sharding: NamedSharding | P | None = ...,
) -> Array: ...
Expand All @@ -422,7 +421,7 @@ def einsum_path(
optimize: str | builtins.bool | list[tuple[int, ...]] = ...,
) -> tuple[list[tuple[int, ...]], Any]: ...

def empty(shape: Any, dtype: DTypeLike | None = ...,
def empty(shape: Any, dtype: DTypeLike | None = ..., *,
device: _Device | _Sharding | None = ...) -> Array: ...
def empty_like(prototype: ArrayLike | DuckTypedArray,
dtype: DTypeLike | None = ...,
Expand Down Expand Up @@ -579,17 +578,17 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = .
def invert(x: ArrayLike, /) -> Array: ...
def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = ...,
atol: ArrayLike = ..., equal_nan: builtins.bool = ...) -> Array: ...
def iscomplex(m: ArrayLike) -> Array: ...
def iscomplex(x: ArrayLike) -> Array: ...
def iscomplexobj(x: Any) -> builtins.bool: ...
def isdtype(dtype: DTypeLike, kind: DType | str | tuple[DType | str, ...]) -> builtins.bool: ...
def isfinite(x: ArrayLike, /) -> Array: ...
def isin(element: ArrayLike, test_elements: ArrayLike,
assume_unique: builtins.bool = ..., invert: builtins.bool = ..., method: str = ...) -> Array: ...
def isin(element: ArrayLike, test_elements: ArrayLike, assume_unique: builtins.bool = ...,
invert: builtins.bool = ..., *, method: str = ...) -> Array: ...
def isinf(x: ArrayLike, /) -> Array: ...
def isnan(x: ArrayLike, /) -> Array: ...
def isneginf(x: ArrayLike, /) -> Array: ...
def isposinf(x: ArrayLike, /) -> Array: ...
def isreal(m: ArrayLike) -> Array: ...
def isreal(x: ArrayLike) -> Array: ...
def isrealobj(x: Any) -> builtins.bool: ...
def isscalar(element: Any) -> builtins.bool: ...
def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> builtins.bool: ...
Expand Down Expand Up @@ -644,7 +643,7 @@ def logspace(start: ArrayLike, stop: ArrayLike, num: int = ...,
endpoint: builtins.bool = ..., base: ArrayLike = ...,
dtype: DTypeLike | None = ..., axis: int = ...) -> Array: ...
def mask_indices(
n: int, mask_func: Callable, k: int = ...
n: int, mask_func: Callable, k: int = ..., *, size: int | None = ...
) -> tuple[Array, ...]: ...
def matmul(
a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ...,
Expand All @@ -655,7 +654,7 @@ def max(a: ArrayLike, axis: _Axis = ..., out: None = ...,
keepdims: builtins.bool = ..., initial: ArrayLike | None = ...,
where: ArrayLike | None = ...) -> Array: ...
def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def mean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
def mean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ...,
out: None = ..., keepdims: builtins.bool = ..., *,
where: ArrayLike | None = ...) -> Array: ...
def median(a: ArrayLike, axis: int | tuple[int, ...] | None = ...,
Expand Down Expand Up @@ -689,14 +688,14 @@ def nanargmin(
out: None = ...,
keepdims: builtins.bool | None = ...,
) -> Array: ...
def nancumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ...,
def nancumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike | None = ...,
out: None = ...) -> Array: ...
def nancumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ...,
def nancumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike | None = ...,
out: None = ...) -> Array: ...
def nanmax(a: ArrayLike, axis: _Axis = ..., out: None = ...,
keepdims: builtins.bool = ..., initial: ArrayLike | None = ...,
where: ArrayLike | None = ...) -> Array: ...
def nanmean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
def nanmean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ...,
out: None = ...,
keepdims: builtins.bool = ...,
where: ArrayLike | None = ...) -> Array: ...
Expand All @@ -710,21 +709,21 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: int | tuple[int, ...] | None = ...,
out: None = ..., overwrite_input: builtins.bool = ..., method: str = ...,
keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ...
def nanprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
def nanprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ...,
out: None = ...,
keepdims: builtins.bool = ..., initial: ArrayLike | None = ...,
where: ArrayLike | None = ...) -> Array: ...
def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ...,
out: None = ..., overwrite_input: builtins.bool = ..., method: str = ...,
keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ...
def nanstd(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ...,
ddof: int = ..., keepdims: builtins.bool = ...,
def nanstd(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ...,
out: None = ..., ddof: int = ..., keepdims: builtins.bool = ...,
where: ArrayLike | None = ...) -> Array: ...
def nansum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
def nansum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ...,
out: None = ..., keepdims: builtins.bool = ...,
initial: ArrayLike | None = ...,
where: ArrayLike | None = ...) -> Array: ...
def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ...,
out: None = ...,
ddof: int = 0, keepdims: builtins.bool = False,
where: ArrayLike | None = ...) -> Array: ...
Expand All @@ -740,7 +739,7 @@ def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: ...
number = _np.number
object_ = _np.object_
ogrid: _Ogrid
def ones(shape: Any, dtype: DTypeLike | None = ...,
def ones(shape: Any, dtype: DTypeLike | None = ..., *,
device: _Device | _Sharding | None = ...) -> Array: ...
def ones_like(a: ArrayLike | DuckTypedArray,
dtype: DTypeLike | None = ...,
Expand Down Expand Up @@ -782,7 +781,7 @@ def positive(x: ArrayLike, /) -> Array: ...
def pow(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def power(x: ArrayLike, y: ArrayLike, /) -> Array: ...
printoptions = _np.printoptions
def prod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
def prod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ...,
out: None = ..., keepdims: builtins.bool = ...,
initial: ArrayLike | None = ..., where: ArrayLike | None = ...,
promote_integers: builtins.bool = ...) -> Array: ...
Expand All @@ -805,7 +804,6 @@ def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int],
mode: str = ..., order: str = ...) -> Array: ...
def real(x: ArrayLike, /) -> Array: ...
def reciprocal(x: ArrayLike, /) -> Array: ...
register_jax_array_methods: Any
def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = ..., *,
total_repeat_length: int | None = ...,
Expand Down Expand Up @@ -844,7 +842,8 @@ def setdiff1d(
size: int | None = ...,
fill_value: ArrayLike | None = ...,
) -> Array: ...
def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ...) -> Array: ...
def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ..., *,
size: int | None = ..., fill_value: ArrayLike | None = ...) -> Array: ...
def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]: ...
def sign(x: ArrayLike, /) -> Array: ...
def signbit(x: ArrayLike, /) -> Array: ...
Expand Down Expand Up @@ -882,14 +881,14 @@ def stack(
out: None = ...,
dtype: DTypeLike | None = ...,
) -> Array: ...
def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ...,
out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *,
where: ArrayLike | None = ..., correction: int | float | None = ...) -> Array: ...
subtract: BinaryUfunc
def sum(
a: ArrayLike,
axis: _Axis = ...,
dtype: DTypeLike = ...,
dtype: DTypeLike | None = ...,
out: None = ...,
keepdims: builtins.bool = ...,
initial: ArrayLike | None = ...,
Expand Down Expand Up @@ -927,7 +926,7 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = ...) -> Array: ...
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = ...,
axis: int = ...) -> Array: ...
def tri(
N: int, M: int | None = ..., k: int = ..., dtype: DTypeLike = ...
N: int, M: int | None = ..., k: int = ..., dtype: DTypeLike | None = ...
) -> Array: ...
def tril(m: ArrayLike, k: int = ...) -> Array: ...
def tril_indices(
Expand Down Expand Up @@ -970,7 +969,7 @@ class _UniqueInverseResult(NamedTuple):
def unique(ar: ArrayLike, return_index: builtins.bool = ..., return_inverse: builtins.bool = ...,
return_counts: builtins.bool = ..., axis: int | None = ...,
*, equal_nan: builtins.bool = ..., size: int | None = ...,
fill_value: ArrayLike | None = ...
fill_value: ArrayLike | None = ..., sorted: bool = ...,
): ...
def unique_all(x: ArrayLike, /, *, size: int | None = ...,
fill_value: ArrayLike | None = ...) -> _UniqueAllResult: ...
Expand All @@ -994,7 +993,7 @@ def unwrap(p: ArrayLike, discont: ArrayLike | None = ...,
def vander(
x: ArrayLike, N: int | None = ..., increasing: builtins.bool = ...
) -> Array: ...
def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ...,
out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *,
where: ArrayLike | None = ..., correction: int | float | None = ...) -> Array: ...
def vdot(
Expand Down Expand Up @@ -1029,7 +1028,7 @@ def where(condition: ArrayLike, x: ArrayLike | None = ...,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ...
) -> Array | tuple[Array, ...]: ...

def zeros(shape: Any, dtype: DTypeLike | None = ...,
def zeros(shape: Any, dtype: DTypeLike | None = ..., *,
device: _Device | _Sharding | None = ...) -> Array: ...
def zeros_like(a: ArrayLike | DuckTypedArray,
dtype: DTypeLike | None = ...,
Expand Down
Loading