Skip to content

Commit 42fd715

Browse files
authored
TRN2 Meshes and Configurations (apple#916)
* TRN2 Meshes and Configurations * Add get_recursive and set_recursive to ConfigBase. * Use loops inside get/set_recursively + address comments * Update partition spec * Use get_recursively inside set * Move trn2 configs to a helper function. + Fix modifier tests * TRN2 partitionspec supports DP over FSDP and TP * Use for loop in get_recursively * Update Golden Configs
1 parent d47d5ce commit 42fd715

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

45 files changed

+4437
-17
lines changed

axlearn/common/config.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class Config(ConfigBase):
7070
from collections import defaultdict
7171
from collections.abc import Collection, Iterable
7272
from functools import cache
73-
from typing import Any, Callable, Generic, Optional, TypeVar, Union
73+
from typing import Any, Callable, Generic, Optional, Sequence, TypeVar, Union
7474

7575
# attr provides similar features as Python dataclass. Unlike
7676
# dataclass, however, it provides a richer set of features to regulate
@@ -394,6 +394,42 @@ def set(self, **kwargs):
394394
setattr(self, k, v)
395395
return self
396396

397+
def get_recursively(self, path: Sequence[str]) -> Any:
    """Returns the value reached by following `path` from this config.

    Args:
        path: A sequence of attribute names to follow, outermost first.

    Raises:
        AttributeError: If any key along the path does not exist.

    Returns:
        The value at `path`, or `self` when `path` is empty.
    """
    node = self
    # TODO(markblee): Maybe use cfg.visit instead of getattr.
    for name in path:
        node = getattr(node, name)
    return node
416+
417+
def set_recursively(self, path: Sequence[str], *, value: Any):
    """Recursively find the target key in the config and set its value.

    Args:
        path: A sequence of keys for indexing to set the target value.
        value: New value to replace the target value.

    Raises:
        ValueError: If path is empty.
        AttributeError: If key in path is not found.
    """
    if not path:
        raise ValueError("Path is empty.")
    # Resolve the parent config, then assign the final attribute on it.
    parent = self.get_recursively(path[:-1])
    setattr(parent, path[-1], value)
432+
397433
def clone(self, **kwargs):
398434
"""Returns a clone of the original config with the optional keyword overrides.
399435

axlearn/common/config_test.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,71 @@ def set(self, **kwargs):
934934
self.assertEqual(123, cfg_clone.a)
935935
self.assertEqual("default", cfg_clone.b)
936936

937+
def test_get_recursively(self):
    class Inner(Configurable):
        @config_class
        class Config(Configurable.Config):
            """A dummy config."""

            value: int = 0

    class Outer(Configurable):
        @config_class
        class Config(Configurable.Config):
            """Another dummy config that has a nested config."""

            nested: Inner.Config = Inner.default_config()
            value: int = 1

    cfg = Outer.default_config()

    # Nested lookup.
    self.assertEqual(0, cfg.get_recursively(["nested", "value"]))
    # Top-level lookup.
    self.assertEqual(1, cfg.get_recursively(["value"]))
    # A missing key raises AttributeError.
    with self.assertRaises(AttributeError):
        cfg.get_recursively(["non_existent"])
    # An empty path returns the config itself.
    self.assertEqual(cfg, cfg.get_recursively([]))
967+
968+
def test_set_recursively(self):
969+
class Nested(Configurable):
970+
@config_class
971+
class Config(Configurable.Config):
972+
"""A dummy config."""
973+
974+
value: int = 0
975+
976+
class Test(Configurable):
977+
@config_class
978+
class Config(Configurable.Config):
979+
"""Another dummy config that has a nested config."""
980+
981+
nested: Nested.Config = Nested.default_config()
982+
value: int = 1
983+
984+
cfg = Test.default_config()
985+
986+
# Test setting nested value.
987+
cfg.set_recursively(["nested", "value"], value=10)
988+
self.assertEqual(cfg.nested.value, 10)
989+
990+
# Test setting top-level value.
991+
cfg.set_recursively(["value"], value=5)
992+
self.assertEqual(cfg.value, 5)
993+
994+
# Test setting non-existent value.
995+
with self.assertRaises(AttributeError):
996+
cfg.set_recursively(["non_existent"], value=20)
997+
998+
# Test setting empty path.
999+
with self.assertRaises(ValueError):
1000+
cfg.set_recursively([], value=20)
1001+
9371002

9381003
if __name__ == "__main__":
9391004
absltest.main()

axlearn/common/trainer_config_modifier.py

Lines changed: 111 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@
1010
REQUIRED,
1111
ConfigModifier,
1212
ConfigOr,
13+
Configurable,
1314
Required,
1415
config_class,
1516
maybe_instantiate,
1617
)
1718
from axlearn.common.gradient_accumulation import with_minibatch_steps
1819
from axlearn.common.metrics import MetricAccumulator
1920
from axlearn.common.trainer import SpmdTrainer
20-
from axlearn.common.utils import HybridMeshShape, MeshShape
21+
from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec
2122

2223

2324
class GradientAccumulationModifier(ConfigModifier):
@@ -100,18 +101,8 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
100101
"""
101102

102103
for module_name, remat_spec in self._remat_policies.items():
103-
# Here we assume x.y.z format.
104-
# One example would be model.decoder.transformer.layer.
105-
target_modules = module_name.split(".")
106-
curr_module = cfg
107-
for target_module in target_modules:
108-
if not hasattr(curr_module, target_module):
109-
raise ValueError(f"{target_module} is not found in {curr_module}.")
110-
curr_module = getattr(curr_module, target_module)
111-
# Here we assume all modules have remat_spec attribute.
112-
if not hasattr(curr_module, "remat_spec"):
113-
raise ValueError(f"{curr_module} does not have remat_spec attribute")
114-
curr_module.remat_spec = remat_spec
104+
cfg.set_recursively(module_name.split(".") + ["remat_spec"], value=remat_spec)
105+
115106
return cfg
116107

117108

@@ -146,6 +137,113 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
146137
return cfg
147138

148139

140+
class ModuleConfigModifier(ConfigModifier):
    """Replaces a target module's config within the trainer config."""

    @config_class
    class Config(ConfigModifier.Config):
        """Configure ModuleConfigModifier.

        Attributes:
            target_config: Target module path
                (e.g. `model.decoder.transformer.layer`) to be modified.
            modification: The new config to replace the target module's config.
        """

        target_config: Required[str] = REQUIRED
        modification: Required[Configurable.Config] = REQUIRED

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self._target_config = self.config.target_config
        self._modification = self.config.modification

    def _merge_configs(
        self, target_cfg: Configurable.Config, found_module: Configurable.Config
    ) -> Configurable.Config:
        """Merge configurations from the config being replaced on a best effort basis.

        Merge Rules:
            - Klass is not changed, use target cfg.
            - If a field exists in both, use the value from the config being replaced.
            - Otherwise keep the value from target_cfg.

        Args:
            target_cfg: Configuration that will replace found_module.
            found_module: Existing configuration whose class will be replaced
                but whose configuration will be merged with target_cfg.

        Returns:
            The modified config.
        """
        for key in target_cfg.keys():
            # `klass` determines the config's class and must come from target_cfg.
            if key == "klass":
                continue
            if hasattr(found_module, key) and hasattr(target_cfg, key):
                setattr(target_cfg, key, getattr(found_module, key))
        return target_cfg

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Overwrite the config of the specified module.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            AttributeError: If the target module path is not found.

        Returns:
            The modified trainer config.
        """
        found_module = cfg.get_recursively(self._target_config.split("."))
        self._modification = self._merge_configs(self._modification, found_module)
        cfg.set_recursively(self._target_config.split("."), value=self._modification)
        return cfg
204+
205+
206+
class PartitionSpecModifier(ConfigModifier):
    """Update the partition spec attribute for the specified modules."""

    @config_class
    class Config(ConfigModifier.Config):
        """Configure PartitionSpecModifier.

        Attributes:
            partition_specs: A nested mapping from module path
                (e.g. `model.decoder.transformer.layer`) to another
                mapping of model attribute to PartitionSpec.
        """

        partition_specs: Required[Dict[str, PartitionSpec]] = REQUIRED

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self._attribute_dicts = self.config.partition_specs

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Update the partition_spec attributes for the specified modules.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            AttributeError: If the target module path or the partition_spec
                attribute is not found.

        Returns:
            The modified trainer config.
        """
        for module_name, partition_spec_dict in self._attribute_dicts.items():
            for partition_spec_name, partition_spec in partition_spec_dict.items():
                cfg.set_recursively(
                    module_name.split(".") + [partition_spec_name], value=partition_spec
                )

        return cfg
245+
246+
149247
class ChainConfigModifier(ConfigModifier):
150248
"""Chain multiple config modifiers together."""
151249

axlearn/common/trainer_config_modifier_test.py

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
import jax
66
from absl.testing import absltest
77

8-
from axlearn.common import test_utils
8+
from axlearn.common import causal_lm, test_utils
9+
from axlearn.common.attention import RepeatedTransformerLayer, StackedTransformerLayer
910
from axlearn.common.base_layer import RematSpec
1011
from axlearn.common.trainer import SpmdTrainer
1112
from axlearn.common.trainer_config_modifier import (
1213
ChainConfigModifier,
1314
GradientAccumulationModifier,
1415
MeshShapeModifier,
16+
ModuleConfigModifier,
17+
PartitionSpecModifier,
1518
RematSpecModifier,
1619
)
1720
from axlearn.common.trainer_test import DummyModel
@@ -61,7 +64,87 @@ def test_remat_policy_override(self):
6164
.instantiate()
6265
)
6366
# Ensure that the exception is working.
64-
with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
67+
with self.assertRaisesRegex(AttributeError, r"unknown \(keys are *"):
68+
_ = cfg_modifier(cfg)
69+
70+
71+
class ModuleConfigModifierTest(test_utils.TestCase):
72+
def test_model_config_override(self):
73+
cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
74+
self.assertTrue(
75+
str(cfg.model.decoder.transformer) == str(StackedTransformerLayer.default_config())
76+
)
77+
78+
cfg_modifier = (
79+
ModuleConfigModifier.default_config()
80+
.set(
81+
target_config="model.decoder.transformer",
82+
modification=RepeatedTransformerLayer.default_config(),
83+
)
84+
.instantiate()
85+
)
86+
87+
cfg = cfg_modifier(cfg)
88+
# The default StackedTransformerLayer should have changed to RepeatedTransformerLayer
89+
self.assertTrue(
90+
str(cfg.model.decoder.transformer) == str(RepeatedTransformerLayer.default_config())
91+
)
92+
cfg_modifier = (
93+
ModuleConfigModifier.default_config()
94+
.set(
95+
target_config="model.decoder.unknown",
96+
modification=RepeatedTransformerLayer.default_config(),
97+
)
98+
.instantiate()
99+
)
100+
# Ensure that the exception is working.
101+
with self.assertRaisesRegex(AttributeError, r"unknown \(keys are *"):
102+
_ = cfg_modifier(cfg)
103+
104+
105+
class PartitionSpecModifierTest(test_utils.TestCase):
    def test_partition_spec_override(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
        cfg_modifier = (
            PartitionSpecModifier.default_config()
            .set(
                partition_specs={
                    "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
                },
            )
            .instantiate()
        )
        cfg = cfg_modifier(cfg)
        # Bug fix: the original `self.assertTrue(str(...), "...")` passed the expected value as
        # the failure *message*, so the assertion could never fail. Compare values instead.
        self.assertEqual(
            ("model", ("expert", "fsdp", "seq")), cfg.model.linear.param_partition_spec
        )

        cfg_modifier = (
            PartitionSpecModifier.default_config()
            .set(
                partition_specs={
                    "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
                    "model.unknown": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
                },
            )
            .instantiate()
        )
        # An unknown module path should raise.
        with self.assertRaisesRegex(AttributeError, r"unknown \(keys are *"):
            _ = cfg_modifier(cfg)

        cfg_modifier = (
            PartitionSpecModifier.default_config()
            .set(
                partition_specs={
                    "model.linear": {
                        "param_partition_spec": ("model", ("expert", "fsdp", "seq")),
                        "unknown_partition_spec": ("model", ("expert", "fsdp", "seq")),
                    },
                },
            )
            .instantiate()
        )
        # An unknown attribute name should raise.
        with self.assertRaisesRegex(AttributeError, "unknown_partition_spec *"):
            _ = cfg_modifier(cfg)
66149

67150

0 commit comments

Comments
 (0)