
Commit dcf83c7

Make Linen and NNX logical rule deduction align
1 parent 771eadb commit dcf83c7

File tree

4 files changed: +130 -125 lines changed


flax/core/spmd.py

Lines changed: 118 additions & 10 deletions
@@ -12,24 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import collections
 import contextlib
 import dataclasses
 import threading
+import typing as tp
 
 import jax
 from jax.sharding import PartitionSpec, NamedSharding
 from flax.core import meta
 from flax.typing import (
   LogicalRules,
-  Sharding,
 )
 
 def get_pspec(sharding_names, sharding_rules = None) -> PartitionSpec:
   """Given an `nnx.Variable`, return its `PartitionSpec`."""
   if get_logical_axis_rules() or sharding_rules:
-    context_rules = get_logical_axis_rules()
-    rules = composite_rules(context_rules, sharding_rules)
-    return PartitionSpec(*from_sharding_rules(sharding_names, rules))
+    sharding_names = logical_to_mesh_axes(sharding_names, sharding_rules)
   return PartitionSpec(*sharding_names)
 
 
@@ -105,10 +104,119 @@ def composite_rules(rule1, rule2):
   return tuple(rules.items())
 
 
-def from_sharding_rules(
-  sharding: Sharding, sharding_rules: LogicalRules
-) -> Sharding:
-  rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
-  return tuple(
-    rules[str(s)] if (s and str(s) in rules) else s for s in sharding
+
+class _UnassignedAxis:
+  """Sentinel class for unassigned logical axis name."""
+
+  def __repr__(self):
+    return 'UnassignedAxis'
+
+  def __bool__(self):
+    return False
+
+
+_unassigned_axis = _UnassignedAxis()
+
+
+def _mesh_assignment_free(new_assignment, existing_assignments):
+  """Determines if a given mesh axis has already been assigned."""
+  new = set(jax.tree_util.tree_leaves(new_assignment))
+  existing = set(jax.tree_util.tree_leaves(existing_assignments))
+  if existing.intersection(new):
+    return False
+  return True
+
+
+def _logical_to_mesh_axes(
+  array_dim_names: tp.Sequence[str | None] | None,
+  rules: LogicalRules | None = None,
+) -> list[_UnassignedAxis | None | str | tuple[str, ...]] | None:
+  """Same as logical_to_mesh_axes, but doesn't fill in _unassigned_axis."""
+  if array_dim_names is None:
+    return None
+  if rules is None:
+    rules = get_logical_axis_rules()
+  axis_name_counts = collections.Counter(array_dim_names)
+  # None and special values such as PartitionSpec.UNCONSTRAINED can appear more
+  # than once.
+  dups = tuple(
+    k for k, v in axis_name_counts.items() if v > 1 and isinstance(k, str)
   )
+  if dups:
+    raise ValueError(
+      f'Unsupported: Dimensions {dups} occur more than once in array names.'
+    )
+  if not isinstance(rules, (tuple, list)):
+    raise ValueError('Unknown axis rule specification type.')
+  # We assign mesh axes using a priority based ruleset over logical axis names.
+  result: list[_UnassignedAxis | None | str | tuple[str, ...]]
+  result = [
+    (_unassigned_axis if isinstance(name, str) else name)
+    for name in array_dim_names
+  ]
+  for rule_model_name, rule_mesh_names in rules:
+    if rule_model_name in array_dim_names:
+      pos = array_dim_names.index(rule_model_name)
+      if (
+        _mesh_assignment_free(rule_mesh_names, result)
+        and result[pos] == _unassigned_axis
+      ):
+        result[pos] = rule_mesh_names
+  return result
+
+
+def logical_to_mesh_axes(
+  array_dim_names: tp.Sequence[str | None] | None,
+  rules: LogicalRules | None = None,
+) -> jax.sharding.PartitionSpec | None:
+  """Compute layout for an array.
+
+  The rules are in order of precedence, and consist of pairs:
+  ``(ArrayDimensionName, MeshDimensionName)``, meaning that the given array
+  dimension (if present and unused) should be sharded across the given
+  mesh dimension (if present and unused).
+
+  A Layout of an Array is expressed as a tuple with one element for each
+  dimension in the Array. The element is either None, or is the name of a
+  mesh-dimension, meaning that this dimension of the array is sharded across
+  this dimension of the mesh.
+
+  For example, given an array with::
+
+    array_dim_names = ('batch', 'length', 'heads', 'features')
+
+  and the layout rules are::
+
+    rules = (('batch', 'X'),
+             ('features', 'X'),
+             ('heads', 'Y'),
+             ('batch', 'Z'))
+
+  then this function will return::
+
+    PartitionSpec('X', None, 'Y', None)
+
+  Args:
+    array_dim_names: Tuple of array dimension names or None.
+    rules: Optional logical to mesh rules override. Defaults to using the
+      rules defined in the dynamic context set from the ``axis_rules`` function.
+
+  Returns:
+    PartitionSpec for the parameter.
+  """
+  result = _logical_to_mesh_axes(array_dim_names, rules)
+  if result is None:
+    return None
+  # We default to None - ie unsharded along the dimension.
+  result = [None if x is _unassigned_axis else x for x in result]
+  return jax.sharding.PartitionSpec(*result)
+
+
+
+# def from_sharding_rules(
+#   sharding_names: Sharding, sharding_rules: LogicalRules
+# ) -> Sharding:
+#   rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
+#   return tuple(
+#     rules[str(s)] if (s and str(s) in rules) else s for s in sharding_names
+#   )

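The docstring example above runs as-is against the relocated helper. A minimal sketch, assuming a Flax build that includes this commit; the later ('features', 'X') and ('batch', 'Z') rules are skipped because their mesh axis or array axis is already assigned:

import jax
from flax.core.spmd import logical_to_mesh_axes

array_dim_names = ('batch', 'length', 'heads', 'features')
rules = (
  ('batch', 'X'),
  ('features', 'X'),  # skipped: mesh axis 'X' is already taken by 'batch'
  ('heads', 'Y'),
  ('batch', 'Z'),     # skipped: 'batch' already has an assignment
)

spec = logical_to_mesh_axes(array_dim_names, rules)
assert spec == jax.sharding.PartitionSpec('X', None, 'Y', None)
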
flax/linen/spmd.py

Lines changed: 4 additions & 106 deletions
@@ -24,20 +24,22 @@
 introducing logical axis metadata into a model's variables.
 """
 
-import collections
 import dataclasses
 import enum
 import functools
 from typing import Any
-from collections.abc import Callable, Sequence
+from collections.abc import Callable
 
 import jax
 from jax import lax
 
 from flax import struct
 from flax.core import meta
 from flax.core.spmd import (
+  _logical_to_mesh_axes,
+  _unassigned_axis,
   get_logical_axis_rules,
+  logical_to_mesh_axes,
 )
 from flax.typing import (
   Array,
@@ -49,111 +51,7 @@
 )
 
 
-class _UnassignedAxis:
-  """Sentinel class for unassigned logical axis name."""
 
-  def __repr__(self):
-    return 'UnassignedAxis'
-
-  def __bool__(self):
-    return False
-
-
-_unassigned_axis = _UnassignedAxis()
-
-
-def _mesh_assignment_free(new_assignment, existing_assignments):
-  """Determines if a given mesh axis has already been assigned."""
-  new = set(jax.tree_util.tree_leaves(new_assignment))
-  existing = set(jax.tree_util.tree_leaves(existing_assignments))
-  if existing.intersection(new):
-    return False
-  return True
-
-
-def _logical_to_mesh_axes(
-  array_dim_names: Sequence[str | None] | None,
-  rules: LogicalRules | None = None,
-) -> list[_UnassignedAxis | None | str | tuple[str, ...]] | None:
-  """Same as logical_to_mesh_axes, but doesn't fill in _unassigned_axis."""
-  if array_dim_names is None:
-    return None
-  if rules is None:
-    rules = get_logical_axis_rules()
-  axis_name_counts = collections.Counter(array_dim_names)
-  # None and special values such as PartitionSpec.UNCONSTRAINED can appear more
-  # then once.
-  dups = tuple(
-    k for k, v in axis_name_counts.items() if v > 1 and isinstance(k, str)
-  )
-  if dups:
-    raise ValueError(
-      f'Unsupported: Dimensions {dups} occur more than once in array names.'
-    )
-  if not isinstance(rules, (tuple, list)):
-    raise ValueError('Unknown axis rule specification type.')
-  # We assign mesh axes using a priority based ruleset over logical axis names.
-  result: list[_UnassignedAxis | None | str | tuple[str, ...]]
-  result = [
-    (_unassigned_axis if isinstance(name, str) else name)
-    for name in array_dim_names
-  ]
-  for rule_model_name, rule_mesh_names in rules:
-    if rule_model_name in array_dim_names:
-      pos = array_dim_names.index(rule_model_name)
-      if (
-        _mesh_assignment_free(rule_mesh_names, result)
-        and result[pos] == _unassigned_axis
-      ):
-        result[pos] = rule_mesh_names
-  return result
-
-
-def logical_to_mesh_axes(
-  array_dim_names: Sequence[str | None] | None,
-  rules: LogicalRules | None = None,
-) -> jax.sharding.PartitionSpec | None:
-  """Compute layout for an array.
-
-  The rules are in order of precedence, and consist of pairs:
-  ``(ArrayDimensionName, MeshDimensionName)``, meaning that the given array
-  dimension (if present and unused) should be sharded across the given
-  mesh dimension (if present and unused).
-
-  A Layout of an Array is expressed as a tuple with one element for each
-  dimension in the Array. The element is either None, or is the name of a
-  mesh-dimension, meaning that this dimension of the array is sharded across
-  this dimension of the mesh.
-
-  For example, given an array with::
-
-    array_dim_names = ('batch', 'length', 'heads', 'features')
-
-  and the layout rules are::
-
-    rules = (('batch', 'X'),
-             ('features', 'X'),
-             ('heads', 'Y'),
-             ('batch', 'Z'))
-
-  then this function will return::
-
-    PartitionSpec('X', None, 'Y', None)
-
-  Args:
-    array_dim_names: Tuple of array dimension names or None.
-    rules: Optional logical to mesh rules override. Defaults to using the
-      rules defined in the dynamic context set from the ``axis_rules`` function.
-
-  Returns:
-    PartitionSpec for the parameter.
-  """
-  result = _logical_to_mesh_axes(array_dim_names, rules)
-  if result is None:
-    return None
-  # We default to None - ie unsharded along the dimension.
-  result = [None if x is _unassigned_axis else x for x in result]
-  return jax.sharding.PartitionSpec(*result)
 
 
 def logical_to_mesh(tree: Any, rules: LogicalRules | None = None) -> Any:

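With the duplicate deleted, Linen's public logical_to_mesh_axes is the flax.core.spmd implementation, so Linen and NNX now share one deduction path. A quick sketch of what that implies, assuming a build that includes this commit:

import jax
import flax.linen as nn
from flax.core import spmd as core_spmd

# The Linen symbol is now the core implementation, not a copy of it.
assert nn.logical_to_mesh_axes is core_spmd.logical_to_mesh_axes

# Rules installed through the Linen context manager feed the shared deduction.
with nn.logical_axis_rules((('batch', 'data'),)):
  spec = nn.logical_to_mesh_axes(('batch', 'features'))
assert spec == jax.sharding.PartitionSpec('data', None)
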
flax/nnx/spmd.py

Lines changed: 4 additions & 6 deletions
@@ -129,13 +129,11 @@ def get_var_pspec(v: variablelib.Variable) -> PartitionSpec | None:
   """Given an `nnx.Variable`, return its `PartitionSpec`."""
   metadata = v.get_metadata()
   if 'sharding_names' in metadata and metadata['sharding_names']:
-    sharding = metadata['sharding_names']
+    sharding_names = metadata['sharding_names']
     if core_spmd.get_logical_axis_rules() or 'sharding_rules' in metadata:
-      context_rules = core_spmd.get_logical_axis_rules()
-      local_rules = metadata.get('sharding_rules', ())
-      rules = core_spmd.composite_rules(context_rules, local_rules)
-      return PartitionSpec(*core_spmd.from_sharding_rules(sharding, rules))
-    return PartitionSpec(*sharding)
+      sharding_names = core_spmd.logical_to_mesh_axes(
+        sharding_names, metadata.get('sharding_rules', None))
+    return PartitionSpec(*sharding_names)
   elif hasattr(v, 'shape'):
     return PartitionSpec()
   return None

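get_var_pspec now defers to the same logical_to_mesh_axes: per-variable sharding_rules are passed through when present, and the dynamic context rules apply otherwise. A minimal sketch of that dispatch, with hypothetical axis names, assuming a build that includes this commit:

import jax
from flax import nnx
from flax.core import spmd as core_spmd

sharding_names = ('embed', 'mlp')

# The variable carries its own rules: they go straight to logical_to_mesh_axes.
spec = core_spmd.logical_to_mesh_axes(sharding_names, (('embed', 'data'),))
assert spec == jax.sharding.PartitionSpec('data', None)

# No per-variable rules: the rules set via nnx.logical_axis_rules are used.
with nnx.logical_axis_rules((('mlp', 'model'),)):
  spec = core_spmd.logical_to_mesh_axes(sharding_names)
assert spec == jax.sharding.PartitionSpec(None, 'model')
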
tests/nnx/spmd_test.py

Lines changed: 4 additions & 3 deletions
@@ -185,20 +185,21 @@ def __init__(self):
           nnx.with_partitioning(
             lambda: jnp.ones((8, 2)),
             sharding=('row-alias', 'col-alias'),
-            sharding_rules=(('row-alias', 'row'),),
           )()
         )
         self.b = nnx.Param(
           nnx.with_partitioning(
-            lambda: jnp.zeros((2,)), sharding=('col-alias',)
+            lambda: jnp.zeros((2,)), sharding=('col-alias2',),
+            sharding_rules=(('col-alias2', 'col'),),
           )()
         )
 
       def __call__(self, x):
         return x @ self.w + self.b
 
     mesh = jax.make_mesh(((1, 2, 2)), ('layers', 'row', 'col'))
-    with jax.set_mesh(mesh), nnx.logical_axis_rules((('col-alias', 'col'),)):
+    global_rule = (('row-alias', 'row'),('col-alias', 'col'),)
+    with jax.set_mesh(mesh), nnx.logical_axis_rules(global_rule):
       model = Foo()
       optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)