Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better docs for jnp.fromfunction #24412

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6182,9 +6182,63 @@ def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None,
from jax.dlpack import from_dlpack # pylint: disable=g-import-not-at-top
return from_dlpack(x, device=device, copy=copy)

@util.implements(np.fromfunction)

def fromfunction(function: Callable[..., Array], shape: Any,
*, dtype: DTypeLike = float, **kwargs) -> Array:
"""Create an array from a function applied over indices.

JAX implementation of :func:`numpy.fromfunction`. The JAX implementation
differs in that it dispatches via :func:`jax.vmap`, and so unlike in NumPy
the function logically operates on scalar inputs, and need not explicitly
handle broadcasted inputs.

Args:
function: a function that takes *N* dynamic scalars and outputs a scalar.
shape: a length-*N* tuple of integers specifying the output shape.
dtype: optionally specify the dtype of the inputs. Defaults to floating-point.
kwargs: additional keyword arguments are passed statically to ``function``.

Returns:
An array of shape ``shape`` if ``function`` returns a scalar, or in general
a pytree of arrays with leading dimensions ``shape``, as determined by the
output of ``function``.

See also:
- :func:`jax.vmap`: the core transformation that the :func:`fromfunction`
API is built on.

Examples:
Generate a multiplication table of a given shape:

>>> jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)
Array([[ 0, 0, 0, 0, 0, 0],
[ 0, 1, 2, 3, 4, 5],
[ 0, 2, 4, 6, 8, 10]], dtype=int32)

When ``function`` returns a non-scalar the output will have leading
dimension of ``shape``:

>>> def f(x):
... return (x + 1) * jnp.arange(3)
>>> jnp.fromfunction(f, shape=(2,))
Array([[0., 1., 2.],
[0., 2., 4.]], dtype=float32)

``function`` may return multiple results, in which case each is mapped
independently:

>>> def f(x, y):
... return x + y, x * y
>>> x_plus_y, x_times_y = jnp.fromfunction(f, shape=(3, 5))
>>> print(x_plus_y)
[[0. 1. 2. 3. 4.]
[1. 2. 3. 4. 5.]
[2. 3. 4. 5. 6.]]
>>> print(x_times_y)
[[0. 0. 0. 0. 0.]
[0. 1. 2. 3. 4.]
[0. 2. 4. 6. 8.]]
"""
shape = core.canonicalize_shape(shape, context="shape argument of jnp.fromfunction()")
for i in range(len(shape)):
in_axes = [0 if i == j else None for j in range(len(shape))]
Expand Down
Loading