Commit 16ca0c2 (1 parent: 9b75ef1)

Author: Mark Lee

Supports flexible input partition specs. (apple#933)

* Supports flexible input partition specs in causal lm.
* Moves the input partitioning to Input.
* Adds missing pytest marker.
* Address review comments.
* Rebase and update golden configs.
* Fixes batch axis names and adds a test.

File tree: 86 files changed (+1068, -328 lines).


CHANGELOG.md

Lines changed: 4 additions & 2 deletions

@@ -3,12 +3,14 @@
 ## 0.1.4

 * Changes
-  * Upgrade Jax from 0.4.33 to 0.4.34.
+  * Upgrade Jax from 0.4.33 to 0.4.34.
+  * Updates the `input_base.Input` API to support configuring input partitioning behavior.
+    * The config fields `batch_axis_names` and `seq_axis_names` in `causal_lm.Model` are now deprecated. Please use `input_base.Input.input_partitioner` instead.

 ## 0.1.3

 * Changes
-  * Upgrade Jax from 0.4.30 to 0.4.33.
+  * Upgrade Jax from 0.4.30 to 0.4.33.

 ## 0.1.2

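For illustration, a hedged migration sketch (not part of this commit): the deprecated model-side fields move to an `input_partitioner` on the input config. The axis names, the use of `Input.default_config()` as a stand-in for a concrete input, and the `config_for_function` wiring are assumptions, not taken from the diff.

```python
# A sketch, not from this commit: migrating from model-side axis names to an input partitioner.
from jax.sharding import PartitionSpec

from axlearn.common import causal_lm, input_base
from axlearn.common.config import config_for_function
from axlearn.common.input_base import PathAndRank, partition_by_path_rank

# Before: sharding constrained inside the model (now deprecated).
old_model_cfg = causal_lm.Model.default_config().set(
    batch_axis_names=("data",),
    seq_axis_names=("seq",),
)

# After: sharding configured on the input and applied in `Input.dispatch_global_batch`.
# In practice this would be a concrete `input_base.Input` subclass config.
new_input_cfg = input_base.Input.default_config().set(
    input_partitioner=config_for_function(partition_by_path_rank).set(
        path_rank_to_partition={
            PathAndRank(".*", 1): PartitionSpec("data"),
            PathAndRank(".*", 2): PartitionSpec("data", "seq"),
        },
    ),
)
```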

axlearn/common/causal_lm.py

Lines changed: 17 additions & 4 deletions

@@ -53,10 +53,13 @@ class Config(BaseModel.Config):
         decoder: Decoder.Config = Decoder.default_config()
         # An auxiliary z-loss scale. If >0 encourages the softmax normalizer to be well behaved.
         z_loss_scale: float = 0.0
-        # Batch mesh axis name(s).
+        # TODO(markblee): Remove `batch_axis_names` and `seq_axis_names`. Input sharding should
+        # happen at `Input.dispatch_global_batch` instead.
+        # Batch mesh axis name(s). (Deprecated.)
         # These will be used to constrain the batch (first) axis of relevant inputs.
-        batch_axis_names: tuple[str] = ("data",)
-        # Sequence-parallel mesh axis name(s).
+        # If None, no batch dim constraints are applied, rather than replicating across batch dim.
+        batch_axis_names: Optional[tuple[str]] = ("data",)
+        # Sequence-parallel mesh axis name(s). (Deprecated.)
         # These will be used to constrain the sequence axis of relevant inputs.
         # If None, no batch sequence dim constraints are applied.
         seq_axis_names: Optional[tuple[str]] = None
@@ -342,8 +345,18 @@ def _constrain_input_batch(self, input_batch: NestedTensor):
         mesh = thread_resources.env.physical_mesh  # type: ignore
         if mesh.empty or mesh.size == 1:
             return
+        cfg: Model.Config = self.config
+        if cfg.batch_axis_names is None and cfg.seq_axis_names is None:
+            return
+
+        logging.log_first_n(
+            logging.WARNING,
+            "cfg.batch_axis_names and cfg.seq_axis_names are deprecated. "
+            "Dispatch inputs using `Input.dispatch_global_batch` instead. "
+            "See `input_base.Input.input_partitioner` for more details.",
+            1,
+        )

-        cfg = self.config
         for k, v in input_batch.items():
             if k in [
                 "input_ids",

axlearn/common/causal_lm_test.py

Lines changed: 9 additions & 4 deletions

@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 """Tests autoregressive models."""
+
 from functools import partial

 import jax
@@ -341,9 +342,13 @@ def test_forward(self):
         self.assertAlmostEqual(loss, ref_outputs["loss"])
         self.assertTrue(jnp.allclose(aux["per_label_loss"], ref_outputs["per_token_loss"]))

+    # TODO(markblee): Add a pytest marker for multi-device tests.
     @pytest.mark.skipif(
         jax.device_count() != 4 or jax.process_count() != 1,
-        reason="Incorrect device & process count for mesh.",
+        reason=(
+            "Incorrect device & process count for mesh.\n"
+            "Use XLA_FLAGS=--xla_force_host_platform_device_count=4 to run locally."
+        ),
     )
     def test_constrain_input_batch(self):
         model = (
@@ -398,10 +403,10 @@ def fn(x):
         # Get stable-hlo representation.
         hlo_text = fn.lower(input_batch).compiler_ir(dialect="hlo").as_hlo_text()

-        # Five (out of six) tensors were sharded.
-        self.assertEqual(hlo_text.count('custom_call_target="Sharding"'), 5)
+        # Seven (out of eight) tensors were sharded.
+        self.assertEqual(hlo_text.count('custom_call_target="Sharding"'), 7)
         # For the [batch, seq_len] tensors.
-        self.assertEqual(hlo_text.count("sharding={devices=[2,2]<=[4]}"), 4)
+        self.assertEqual(hlo_text.count("sharding={devices=[2,2]<=[4]}"), 6)
         # For the [batch,] tensor.
         self.assertEqual(
             hlo_text.count("sharding={devices=[2,2]<=[4] last_tile_dim_replicate}"), 1
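The updated skip reason describes how to run this multi-device test on a single host. A hedged sketch of the same setup done programmatically (the flag must be set before JAX initializes its backends):

```python
# A sketch: force the host (CPU) platform to expose 4 devices, per the skip message above.
import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax  # noqa: E402  # Imported after setting the flag on purpose.

assert jax.device_count() == 4  # Enough devices for the 2x2 mesh used by the test.
```

Equivalently, prefix the pytest invocation with `XLA_FLAGS=--xla_force_host_platform_device_count=4`, as the skip message suggests.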

axlearn/common/input_base.py

Lines changed: 117 additions & 3 deletions

@@ -2,23 +2,121 @@

 """Base Input interface."""

-from typing import Iterable, Iterator, Optional, Sequence, Union
+import re
+from typing import Iterable, Iterator, NamedTuple, Optional, Protocol, Sequence, Union

 import jax
+from absl import logging
+from jax._src.mesh import thread_resources
 from jax.sharding import PartitionSpec

-from axlearn.common.config import config_class
+from axlearn.common.config import ConfigOr, config_class, maybe_instantiate
 from axlearn.common.input_dispatch import InputDispatcher
 from axlearn.common.module import Module
 from axlearn.common.utils import (
     Nested,
     Tensor,
     as_numpy_array,
     dispatch_input_batch,
+    tree_paths,
     with_sharding_constraint,
 )


+class InputPartitionFn(Protocol):
+    """Partitions the input batch."""
+
+    def __call__(self, input_batch: Nested[Tensor]) -> Nested[Tensor]:
+        """Applies sharding constraints to `input_batch` and returns the modified batch.
+
+        Implementations should avoid making in-place updates to `input_batch`.
+        """
+
+
+class PathAndRank(NamedTuple):
+    """A tuple (path, rank) used for matching against inputs in a batch.
+
+    Attributes:
+        path: An optional path or path regex. None means match everything.
+        rank: An optional rank (ndim). None means match everything.
+    """
+
+    path: Optional[Union[str, re.Pattern]]
+    rank: Optional[int]
+
+
+def partition_by_path_rank(
+    path_rank_to_partition: dict[PathAndRank, PartitionSpec],
+) -> InputPartitionFn:
+    """Partitions the keys in the input batch by Tensor path and rank (ndim).
+
+    If not within a mesh, the partition fn is a no-op.
+
+    Args:
+        path_rank_to_partition: A mapping from (path_regex, rank) to partition spec.
+            For each input path, the Tensor will be constrained by the first matching
+            (path_regex, rank) rule, where paths are full-matched against `path_regex` and
+            ranks are matched against `rank`.
+            `path_regex` or `rank` are allowed to be None to match everything.
+            If replication is desired, specify a partition spec of None explicitly.
+            If leaving the input unconstrained is desired, specify a partition spec of
+            `PartitionSpec.UNCONSTRAINED` explicitly.
+
+    Returns:
+        A function that applies sharding constraints to an input batch and returns a new batch.
+
+    Raises:
+        ValueError: If no rules match for a given input, which is likely an oversight. If leaving
+            inputs unconstrained is desired, explicitly specify `PartitionSpec.UNCONSTRAINED`.
+
+    Example:
+        To constrain all rank-1 Tensors by ("data",) and rank-2 by ("data", "seq"):
+        ```
+        partition_by_path_ndim({
+            (".*", 1): PartitionSpec("data"),
+            (".*", 2): PartitionSpec("data", "seq"),
+        })
+        ```
+    """
+    compiled = {}
+    for (regex, rank), spec in path_rank_to_partition.items():
+        if regex is not None:
+            regex = re.compile(regex)
+        compiled[(regex, rank)] = spec
+
+    def fn(input_batch: Nested[Tensor]) -> Nested[Tensor]:
+        mesh = thread_resources.env.physical_mesh  # type: ignore
+        if mesh.empty or mesh.size == 1:
+            return input_batch
+
+        def maybe_constrain(path: str, value: Tensor):
+            for (path_regex, rank), partition_spec in compiled.items():
+                if not (rank is None or value.ndim == rank) or not (
+                    path_regex is None or re.fullmatch(path_regex, path)
+                ):
+                    continue
+                if partition_spec is not PartitionSpec.UNCONSTRAINED:
+                    value = with_sharding_constraint(value, partition_spec)
+                    logging.log_first_n(
+                        logging.INFO,
+                        "Constraining input_batch[%s] with %s.",
+                        len(input_batch),
+                        path,
+                        partition_spec,
+                    )
+                return value
+            # No rules match. We raise as not-constraining is likely an oversight.
+            raise ValueError(
+                f"No rules matched input_batch['{path}']. "
+                "If you intended to leave the input unconstrained, "
+                "specify `PartitionSpec.UNCONSTRAINED` explicitly."
+            )
+
+        return jax.tree_map(maybe_constrain, tree_paths(input_batch), input_batch)
+
+    return fn
+
+
 class Input(Module):
     """A Module to generate input batches.

@@ -53,9 +151,12 @@ class Config(Module.Config):
         Attributes:
             input_dispatcher: If not None, creates an InputDispatcher and uses it for dispatching
                 per-feed batches to global batches.
+            input_partitioner: If not None, applies additional sharding constraints on each input
+                batch during `dispatch_global_batch`.
         """

         input_dispatcher: Optional[InputDispatcher.Config] = None
+        input_partitioner: Optional[ConfigOr[InputPartitionFn]] = None

     def __init__(self, cfg: Config, *, parent: Optional[Module]):
         super().__init__(cfg, parent=parent)
@@ -64,6 +165,9 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
             self.input_dispatcher: InputDispatcher = (  # pytype: disable=annotation-type-mismatch
                 self._add_child("input_dispatcher", cfg.input_dispatcher)
             )
+        self._input_partitioner: Optional[InputPartitionFn] = maybe_instantiate(
+            cfg.input_partitioner
+        )

     def dataset(self) -> Iterable[Nested[Tensor]]:
         """Returns the input dataset, which should produce per-feed logical batches.
@@ -131,6 +235,9 @@ def dispatch_global_batch(
         The leaves of the output logical batch are partitioned across `batch_axis_names` along the
         0th (batch) dimension. This should be invoked from within `pjit` so that the sharding
         constraints can be applied.
+
+        If `cfg.input_partitioner` is not None, it will be applied to each logical batch after
+        constraining `batch_axis_names`.
         """

         def constrain_batch_axis(batch):
@@ -147,7 +254,14 @@ def constrain_batch_axis(batch):
             global_logical_batch = dispatch_input_batch(
                 global_physical_batch, batch_axis_names=batch_axis_names
             )
-        return constrain_batch_axis(global_logical_batch)
+
+        global_logical_batch = constrain_batch_axis(global_logical_batch)
+
+        # Further constrain based on user-configured partitioning rules.
+        if self._input_partitioner is not None:
+            global_logical_batch = self._input_partitioner(global_logical_batch)
+
+        return global_logical_batch

     def element_spec(self) -> Nested[jax.ShapeDtypeStruct]:
         """Returns the per-feed logical batch spec.
