"""Base Input interface."""

-from typing import Iterable, Iterator, Optional, Sequence, Union
+import re
+from typing import Iterable, Iterator, NamedTuple, Optional, Protocol, Sequence, Union

import jax
+from absl import logging
+from jax._src.mesh import thread_resources
from jax.sharding import PartitionSpec

-from axlearn.common.config import config_class
+from axlearn.common.config import ConfigOr, config_class, maybe_instantiate
from axlearn.common.input_dispatch import InputDispatcher
from axlearn.common.module import Module
from axlearn.common.utils import (
    Nested,
    Tensor,
    as_numpy_array,
    dispatch_input_batch,
+    tree_paths,
    with_sharding_constraint,
)


+class InputPartitionFn(Protocol):
+    """Partitions the input batch."""
+
+    def __call__(self, input_batch: Nested[Tensor]) -> Nested[Tensor]:
+        """Applies sharding constraints to `input_batch` and returns the modified batch.
+
+        Implementations should avoid making in-place updates to `input_batch`.
+        """
+
+
+class PathAndRank(NamedTuple):
+    """A tuple (path, rank) used for matching against inputs in a batch.
+
+    Attributes:
+        path: An optional path or path regex. None means match everything.
+        rank: An optional rank (ndim). None means match everything.
+    """
+
+    path: Optional[Union[str, re.Pattern]]
+    rank: Optional[int]
+
+
+def partition_by_path_rank(
+    path_rank_to_partition: dict[PathAndRank, PartitionSpec],
+) -> InputPartitionFn:
+    """Partitions the keys in the input batch by Tensor path and rank (ndim).
+
+    If not within a mesh, the partition fn is a no-op.
+
+    Args:
+        path_rank_to_partition: A mapping from (path_regex, rank) to partition spec.
+            For each input path, the Tensor will be constrained by the first matching
+            (path_regex, rank) rule, where paths are full-matched against `path_regex` and
+            ranks are matched against `rank`.
+            `path_regex` or `rank` are allowed to be None to match everything.
+            If replication is desired, specify a partition spec of None explicitly.
+            If leaving the input unconstrained is desired, specify a partition spec of
+            `PartitionSpec.UNCONSTRAINED` explicitly.
+
+    Returns:
+        A function that applies sharding constraints to an input batch and returns a new batch.
+
+    Raises:
+        ValueError: If no rules match for a given input, which is likely an oversight. If
+            leaving inputs unconstrained is desired, explicitly specify
+            `PartitionSpec.UNCONSTRAINED`.
+
+    Example:
+        To constrain all rank-1 Tensors by ("data",) and rank-2 Tensors by ("data", "seq"):
+        ```
+        partition_by_path_rank({
+            (".*", 1): PartitionSpec("data"),
+            (".*", 2): PartitionSpec("data", "seq"),
+        })
+        ```
+    """
+    compiled = {}
+    for (regex, rank), spec in path_rank_to_partition.items():
+        if regex is not None:
+            regex = re.compile(regex)
+        compiled[(regex, rank)] = spec
+
+    def fn(input_batch: Nested[Tensor]) -> Nested[Tensor]:
+        mesh = thread_resources.env.physical_mesh  # type: ignore
+        if mesh.empty or mesh.size == 1:
+            return input_batch
+
+        def maybe_constrain(path: str, value: Tensor):
+            for (path_regex, rank), partition_spec in compiled.items():
+                if not (rank is None or value.ndim == rank) or not (
+                    path_regex is None or re.fullmatch(path_regex, path)
+                ):
+                    continue
+                if partition_spec is not PartitionSpec.UNCONSTRAINED:
+                    value = with_sharding_constraint(value, partition_spec)
+                    logging.log_first_n(
+                        logging.INFO,
+                        "Constraining input_batch[%s] with %s.",
+                        len(input_batch),
+                        path,
+                        partition_spec,
+                    )
+                return value
+            # No rules match. We raise as not-constraining is likely an oversight.
+            raise ValueError(
+                f"No rules matched input_batch['{path}']. "
+                "If you intended to leave the input unconstrained, "
+                "specify `PartitionSpec.UNCONSTRAINED` explicitly."
+            )
+
+        return jax.tree_map(maybe_constrain, tree_paths(input_batch), input_batch)
+
+    return fn
+
+
class Input(Module):
    """A Module to generate input batches.
@@ -53,9 +151,12 @@ class Config(Module.Config):
        Attributes:
            input_dispatcher: If not None, creates an InputDispatcher and uses it for dispatching
                per-feed batches to global batches.
+            input_partitioner: If not None, applies additional sharding constraints on each input
+                batch during `dispatch_global_batch`.
        """

        input_dispatcher: Optional[InputDispatcher.Config] = None
+        input_partitioner: Optional[ConfigOr[InputPartitionFn]] = None

    def __init__(self, cfg: Config, *, parent: Optional[Module]):
        super().__init__(cfg, parent=parent)
@@ -64,6 +165,9 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
        self.input_dispatcher: InputDispatcher = (  # pytype: disable=annotation-type-mismatch
            self._add_child("input_dispatcher", cfg.input_dispatcher)
        )
+        self._input_partitioner: Optional[InputPartitionFn] = maybe_instantiate(
+            cfg.input_partitioner
+        )

    def dataset(self) -> Iterable[Nested[Tensor]]:
        """Returns the input dataset, which should produce per-feed logical batches.
@@ -131,6 +235,9 @@ def dispatch_global_batch(
        The leaves of the output logical batch are partitioned across `batch_axis_names` along the
        0th (batch) dimension. This should be invoked from within `pjit` so that the sharding
        constraints can be applied.
+
+        If `cfg.input_partitioner` is not None, it will be applied to each logical batch after
+        constraining `batch_axis_names`.
        """

        def constrain_batch_axis(batch):
@@ -147,7 +254,14 @@ def constrain_batch_axis(batch):
        global_logical_batch = dispatch_input_batch(
            global_physical_batch, batch_axis_names=batch_axis_names
        )
-        return constrain_batch_axis(global_logical_batch)
+
+        global_logical_batch = constrain_batch_axis(global_logical_batch)
+
+        # Further constrain based on user-configured partitioning rules.
+        if self._input_partitioner is not None:
+            global_logical_batch = self._input_partitioner(global_logical_batch)
+
+        return global_logical_batch

    def element_spec(self) -> Nested[jax.ShapeDtypeStruct]:
        """Returns the per-feed logical batch spec.
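On the config side, one plausible way to wire the new `input_partitioner` field is sketched below. `MyInput` is a hypothetical concrete `Input` subclass, and the use of `config_for_function` with these argument values is an assumption about the surrounding config system, not something the diff shows.

```python
# Hypothetical wiring sketch; MyInput stands in for a concrete Input subclass.
from axlearn.common.config import config_for_function
from jax.sharding import PartitionSpec

input_cfg = MyInput.default_config().set(
    input_partitioner=config_for_function(partition_by_path_rank).set(
        path_rank_to_partition={
            PathAndRank(".*", 1): PartitionSpec("data"),
            PathAndRank(".*", 2): PartitionSpec("data", "seq"),
        },
    ),
)
# Per the diff, the instantiated partitioner runs inside dispatch_global_batch, after the
# batch has been constrained along `batch_axis_names`.
```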