Update shard_map.md's API specification section #28464

Merged
merged 1 commit into from
May 1, 2025
13 changes: 8 additions & 5 deletions docs/notebooks/shard_map.ipynb
@@ -827,17 +827,20 @@
 "Specs = PyTree[PartitionSpec]\n",
 "\n",
 "def shard_map(\n",
-"    f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,\n",
-"    auto: collections.abc.Set[AxisName] = frozenset([]),\n",
+"    f: Callable, /, *, out_specs: Specs, mesh: Mesh | None = None,\n",
+"    in_specs: Specs | None = None,\n",
+"    axis_names: collections.abc.Set[AxisName] = set(),\n",
 "    check_vma: bool = True,\n",
 ") -> Callable:\n",
 "  ...\n",
 "```\n",
 "where:\n",
 "* communication collectives like `psum` in the body of `f` can mention the axis names of `mesh`;\n",
-"* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`;\n",
-"* `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;\n",
-"* `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually;\n",
+"* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`. If `None`, the mesh is inferred from the\n",
+"context, which can be set via the `jax.sharding.use_mesh` context manager;\n",
Collaborator:
"via" -> "via the"

Collaborator Author:
Done

+"* `in_specs` are `PartitionSpec`s which can mention each axis name from `mesh` zero or one times to express slicing/unconcatenation of inputs, with unmentioned names corresponding to replication. If `None`, all mesh axes must be of type `Explicit`, in which case the `in_specs` are inferred from the argument types;\n",
Collaborator:
"affinely" is such a pretentious word to use here lol, sorry

you don't have to fix it here but we should say "zero or one times" instead

Collaborator Author:
Done

+"* `out_specs` are `PartitionSpec`s which can mention each axis name from `mesh` zero or one times to express concatenation of outputs, with unmentioned names corresponding to untiling (assert-replicated-so-give-me-one-copy);\n",
+"* `axis_names` is an optional set of axis names corresponding to the subset of names of `mesh` to treat as manual in the body. If empty, `f` is manual over all axes of the mesh;\n",
 "* `check_vma` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)).\n",
 "\n",
 "The shapes of the arguments passed to `f` have the same ranks as the arguments\n",
13 changes: 8 additions & 5 deletions docs/notebooks/shard_map.md
@@ -554,17 +554,20 @@ from jax.sharding import Mesh
 Specs = PyTree[PartitionSpec]

 def shard_map(
-    f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
-    auto: collections.abc.Set[AxisName] = frozenset([]),
+    f: Callable, /, *, out_specs: Specs, mesh: Mesh | None = None,
+    in_specs: Specs | None = None,
+    axis_names: collections.abc.Set[AxisName] = set(),
     check_vma: bool = True,
 ) -> Callable:
   ...
 ```
 where:
 * communication collectives like `psum` in the body of `f` can mention the axis names of `mesh`;
-* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`;
-* `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;
-* `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually;
+* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`. If `None`, the mesh is inferred from the
+context, which can be set via the `jax.sharding.use_mesh` context manager;
+* `in_specs` are `PartitionSpec`s which can mention each axis name from `mesh` zero or one times to express slicing/unconcatenation of inputs, with unmentioned names corresponding to replication. If `None`, all mesh axes must be of type `Explicit`, in which case the `in_specs` are inferred from the argument types;
+* `out_specs` are `PartitionSpec`s which can mention each axis name from `mesh` zero or one times to express concatenation of outputs, with unmentioned names corresponding to untiling (assert-replicated-so-give-me-one-copy);
+* `axis_names` is an optional set of axis names corresponding to the subset of names of `mesh` to treat as manual in the body. If empty, `f` is manual over all axes of the mesh;
 * `check_vma` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)).

 The shapes of the arguments passed to `f` have the same ranks as the arguments
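The `in_specs` wording in the hunk above (each mesh axis name mentioned zero or one times, unmentioned names meaning replication) can be modeled in plain Python without JAX. This is only a sketch under stated assumptions: `names_of`, `check_spec`, `block_shape`, and the dict-based `mesh_shape` are hypothetical names made up for illustration, and nothing here is JAX's implementation. A spec is modeled as a tuple whose entries are `None`, an axis name, or a tuple of axis names.

```python
from math import prod


def names_of(entry):
    """Normalize one spec entry to a tuple of mesh axis names."""
    if entry is None:
        return ()
    return entry if isinstance(entry, tuple) else (entry,)


def check_spec(spec, mesh_shape):
    """Raise ValueError unless each mesh axis name appears zero or one times."""
    seen = set()
    for entry in spec:
        for name in names_of(entry):
            if name not in mesh_shape:
                raise ValueError(f"unknown mesh axis {name!r}")
            if name in seen:
                raise ValueError(f"mesh axis {name!r} mentioned more than once")
            seen.add(name)


def block_shape(global_shape, spec, mesh_shape):
    """Per-device block shape of an input under this simplified model."""
    check_spec(spec, mesh_shape)
    if len(spec) > len(global_shape):
        raise ValueError("spec is longer than the array rank")
    # Trailing dimensions without an entry count as unmentioned (not sliced).
    padded = tuple(spec) + (None,) * (len(global_shape) - len(spec))
    shape = []
    for size, entry in zip(global_shape, padded):
        # A dimension is sliced by the product of the mesh axes it mentions.
        factor = prod(mesh_shape[name] for name in names_of(entry))
        if size % factor:
            raise ValueError(f"dimension of size {size} is not divisible by {factor}")
        shape.append(size // factor)
    return tuple(shape)


mesh_shape = {"x": 4, "y": 2}  # hypothetical 4x2 device mesh
print(block_shape((8, 16), ("x", "y"), mesh_shape))     # (2, 8)
print(block_shape((8, 16), ("x", None), mesh_shape))    # (2, 16): replicated over "y"
print(block_shape((8, 16), (("x", "y"),), mesh_shape))  # (1, 16)
```

In this model the block keeps the rank of the global array, and a spec such as `("x", "x")` is rejected because it mentions `"x"` twice.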
8 changes: 4 additions & 4 deletions jax/_src/shard_map.py
@@ -97,8 +97,8 @@ def shard_map(f=None, /, *, out_specs: Specs, axis_names: Set[AxisName] = set(),
 the named axes of ``mesh``. In each ``PartitionSpec``, mentioning a
 ``mesh`` axis name at a position expresses sharding the corresponding
 argument array axis along that positional axis; not mentioning an axis
-name expresses replication. If ``None``, all mesh axes must be in explicit
-mode, in which case the in_specs are inferred from the argument types.
+name expresses replication. If ``None``, all mesh axes must be of type
+``Explicit``, in which case the ``in_specs`` are inferred from the argument types.
 out_specs: a pytree with ``PartitionSpec`` instances as leaves, with a tree
 structure that is a tree prefix of the output of ``f``. Each
 ``PartitionSpec`` represents how the corresponding output shards should be
@@ -107,8 +107,8 @@ def shard_map(f=None, /, *, out_specs: Specs, axis_names: Set[AxisName] = set(),
 corresponding positional axis; not mentioning a ``mesh`` axis name
 expresses a promise that the output values are equal along that mesh axis,
 and that rather than concatenating only a single value should be produced.
-axis_names: (optional, default None) set of axis names from ``mesh`` over
-which the function ``f`` is manual. If ``None``, ``f``, is manual
+axis_names: (optional, default set()) set of axis names from ``mesh`` over
+which the function ``f`` is manual. If empty, ``f`` is manual
 over all mesh axes.
 check_vma: (optional) boolean (default True) representing whether to enable
 additional validity checks and automatic differentiation optimizations.
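The `axis_names` default described in the docstring above (an empty set means `f` is manual over all mesh axes) can be illustrated with a small pure-Python sketch. The helper `split_axes` is hypothetical and is not part of JAX; it calls the remaining axes "auto" only because the removed `auto` parameter described axes treated automatically in the body.

```python
def split_axes(mesh_axis_names, axis_names=frozenset()):
    """Return (manual, auto) axis-name sets under the documented default rule."""
    mesh_axes = frozenset(mesh_axis_names)
    requested = frozenset(axis_names)
    unknown = requested - mesh_axes
    if unknown:
        raise ValueError(f"axis_names not in mesh: {sorted(unknown)}")
    # An empty axis_names means manual over every mesh axis.
    manual = requested or mesh_axes
    return manual, mesh_axes - manual


print(split_axes(("x", "y")))         # both axes manual, none auto
print(split_axes(("x", "y"), {"x"}))  # "x" manual, "y" auto
```

One consequence of this convention, as modeled here, is that an empty set cannot express "manual over no axes"; it always selects the whole mesh.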