import typing_extensions
from absl import logging
from jax import numpy as jnp
+ from jax._src.sharding_impls import TransferToMemoryKind
from optax._src import numerics

from axlearn.common import schedule, struct
- from axlearn.common.base_layer import NestedParameterSpec, ParameterSpec, PartitionSpec
+ from axlearn.common.base_layer import ParameterSpec, PartitionSpec
from axlearn.common.config import ConfigOr, maybe_instantiate
from axlearn.common.factorized_rms import scale_by_factored_rms
from axlearn.common.module import current_context
    TransformPartitionSpecFn,
)
from axlearn.common.utils import (
+     MemoryKind,
    Nested,
-     NestedPartitionSpec,
    NestedTensor,
    NestedTree,
    Tensor,
@@ -139,19 +140,41 @@ def update_fn(
    return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn)


- def copy_partition(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+ def copy_partition(
+     specs: Nested[OptStateSpec],
+     *,
+     pattern: Union[None, str, re.Pattern] = None,
+     memory_kind: Optional[MemoryKind] = None,
+ ) -> Nested[OptStateSpec]:
+     """Copies OptStateSpecs and optionally assigns a different memory kind.
+
+     Args:
+         specs: Nested[OptStateSpec] to copy from.
+         pattern: Regex to match the full path of each spec. Matched specs will have their memory
+             kind replaced with `memory_kind`.
+         memory_kind: New memory kind. Defaults to None.
+
+     Returns:
+         A Nested[OptStateSpec] with possibly a different memory kind.
+     """
    return jax.tree.map(
-         lambda param_spec: OptStateSpec(
-             dtype=param_spec.dtype, shape=param_spec.shape, mesh_axes=param_spec.mesh_axes
+         lambda path, spec: OptStateSpec(
+             dtype=spec.dtype,
+             shape=spec.shape,
+             mesh_axes=spec.mesh_axes,
+             memory_kind=memory_kind
+             if pattern and re.fullmatch(pattern, path)
+             else spec.memory_kind,
        ),
-         param_specs,
+         tree_paths(specs),
+         specs,
    )


def trace_partition(
    base: optax.GradientTransformation,
) -> PartitionedGradientTransformation:
-     def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+     def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
        return optax.TraceState(trace=copy_partition(param_specs))

    return with_partition_fn(base, partition_fn)
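As an aside from the diff itself: the snippet below is a minimal, self-contained sketch of the path-matching rule that `copy_partition` (and, further down, `offload_optimizer`) applies, namely that a spec is affected only when `pattern` fully matches its tree path. The first path is the example path quoted in the `offload_optimizer` docstring; the other two paths and the printed labels are hypothetical.

```python
import re

# Hypothetical optimizer-state paths; the first mirrors the example path from the
# offload_optimizer docstring added later in this diff.
paths = [
    "optimizer/1/0/mu/decoder/transformer/repeat/layer/feed_forward/linear1_0",
    "optimizer/1/0/nu/decoder/transformer/repeat/layer/feed_forward/linear1_0",
    "count",
]
pattern = ".*/mu/.*"  # Target only the first-moment ("mu") states.

for path in paths:
    # copy_partition / offload_optimizer only act on specs whose full path matches the pattern;
    # everything else keeps its original memory kind.
    if re.fullmatch(pattern, path):
        print(f"{path} -> offloaded (memory_kind='pinned_host')")
    else:
        print(f"{path} -> unchanged")
```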
@@ -160,7 +183,9 @@ def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def adam_partition(base: optax.GradientTransformation) -> PartitionedGradientTransformation:
    state: optax.ScaleByAdamState = base.init({})

-     def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+     def partition_fn(
+         param_specs: Nested[ParameterSpec],
+     ) -> Nested[Union[OptStateSpec, optax.ScaleByAdamState]]:
        return optax.ScaleByAdamState(
            count=OptStateSpec(
                dtype=state.count.dtype, shape=state.count.shape, mesh_axes=PartitionSpec()
@@ -950,7 +975,7 @@ def _update(value: Tensor, ema: Tensor, qstep_size: Tensor, count: Tensor) -> _U
        )
        return updates, new_state

-     def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+     def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[Union[OptStateSpec, EmaState]]:
        def get_ema_partition(param_spec: ParameterSpec) -> OptStateSpec:
            # Store momentum in accumulator_dtype if it is set and p is not scalar.
            if param_spec.shape and accumulator_dtype is not None:
@@ -1412,7 +1437,9 @@ def _is_valid_step(
            drop_stats=new_drop_stats,
        )

-     def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+     def partition_fn(
+         param_specs: Nested[ParameterSpec],
+     ) -> Nested[Union[OptStateSpec, SkipClipState]]:
        if use_adaptive_drop_norm:
            one = jnp.ones([], jnp.float32)
            dict_thresholds = drop_norm(count=one, mean=one, stddev=one)
@@ -1571,7 +1598,9 @@ def update_fn(updates, state, params):
        )
        return updates, ParamEmaState(count=count_inc, ema=new_ema)

-     def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+     def partition_fn(
+         param_specs: Nested[ParameterSpec],
+     ) -> Nested[Union[OptStateSpec, ParamEmaState]]:
        return ParamEmaState(
            count=OptStateSpec(dtype=jnp.int32, shape=[], mesh_axes=PartitionSpec()),
            ema=copy_partition(param_specs),
@@ -1617,7 +1646,9 @@ def update_fn(updates, state, params=None):
        updates = jax.tree.map(lambda g, m: jnp.sign((1.0 - b1) * g + b1 * m), updates, state.mu)
        return updates, ScaleByLionState(count=count_inc, mu=mu)

-     def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+     def partition_fn(
+         param_specs: Nested[ParameterSpec],
+     ) -> Nested[Union[OptStateSpec, ScaleByLionState]]:
        mu_specs = param_specs
        if mu_dtype is not None:
            mu_specs = jax.tree.map(
@@ -1993,3 +2024,100 @@ def _update2(u: Tensor, param: OptParam):
        partition=lambda _: OptStateSpec(shape=[], dtype=jnp.int32, mesh_axes=PartitionSpec()),
    )
    return named_chain(**tx)
+
+
+ def offload_optimizer(
+     optimizer: ConfigOr[PartitionedGradientTransformation],
+     *,
+     pattern: Union[str, re.Pattern] = ".*",
+     offload_src: MemoryKind = "device",
+     offload_dst: MemoryKind = "pinned_host",
+ ) -> PartitionedGradientTransformation:
+     """Offloads the states of the wrapped optimizer that match `pattern` to `offload_dst`.
+
+     Args:
+         optimizer: The optimizer to offload.
+         pattern: Regex pattern used to match the path of optimizer states. Fully matched states
+             will be offloaded. Defaults to a regex that matches all states.
+         offload_src: Offload-from memory kind. Defaults to "device".
+         offload_dst: Offload-to memory kind. Defaults to "pinned_host".
+
+     Returns:
+         An optimizer whose state is on `offload_dst` and does the same computation as `optimizer`.
+
+     Raises:
+         ValueError: When the `update` function of the returned optimizer is called outside of a
+             jit context.
+
+     This function returns a new `PartitionedGradientTransformation` that
+     1. Puts matched states of the wrapped optimizer on `offload_dst` through the partition function
+        during state initialization in the trainer.
+     2. Copies the matched states to `offload_src` before `optimizer.update` is called.
+     3. Copies the matched updated states to `offload_dst` after `optimizer.update` is called.
+
+     The regex pattern is matched against the full path of each optimizer state. An example full
+     path is optimizer/1/0/mu/decoder/transformer/repeat/layer/feed_forward/linear1_0. If the
+     pattern should not depend on model structure, you can use ".*/mu/.*" to offload all `mu`.
+
+     The .update function of the returned `PartitionedGradientTransformation` must be called within
+     a jit function.
+
+     Example usage:
+     ```python
+     your_opt = adamw_optimizer(...)
+     offloaded_opt = offload_optimizer(your_opt)
+     ```
+
+     When using `skip_and_clip_by_global_norm` with this offload optimizer, you must wrap the
+     entire `skip_and_clip_by_global_norm` inside `offload_optimizer`. Do not wrap only the inner
+     optimizer of `skip_and_clip_by_global_norm`, or you will get errors. Correct example:
+     ```
+     offloaded_opt = offload_optimizer(skip_and_clip_by_global_norm(inner=adamw_optimizer(...)))
+     ```
+     The reason is that `skip_and_clip_by_global_norm` conditionally chooses between the previous
+     optimizer state and the updated new optimizer state using `jnp.where`, which doesn't support
+     tensors on the `pinned_host` memory space.
+     """
+     optimizer = maybe_instantiate(optimizer)
+     if offload_src is None or offload_dst is None:
+         raise ValueError(
+             "offload_src and offload_dst cannot be None when using optimizer offloading."
+         )
+
+     logging.info("Optimizer offloading from %s to %s enabled.", offload_src, offload_dst)
+
+     def init_fn(params: NestedOptParam):
+         return optimizer.init(params)
+
+     def _move_fn(state: optax.OptState, dst: MemoryKind) -> optax.OptState:
+         # TransferToMemoryKind lets us change the memory kind of tensors without specifying the
+         # full sharding (i.e. jax.sharding.NamedSharding). Although there's no documentation
+         # about it, it's specified in the API signature. Reference:
+         # https://github.com/jax-ml/jax/blob/21f8885a9e104b8828c9a8b721eed0c68b622691/jax/_src/api.py#L2220
+         # Note: device_put doesn't move everything at once. When we pass a pytree of arrays to
+         # device_put, each array in the pytree is moved independently of one another. The exact
+         # order is decided by the latency hiding scheduler. The scheduler will try to overlap the
+         # transfers of each state with the state update on TPU whenever possible. There is some
+         # memory spike due to the temporary state in HBM, but the spike is much less than the
+         # full memory usage of all states. Moreover, when the optimizer is run, all activations
+         # are released, so we have less memory pressure at that point in time.
+         return jax.tree.map(
+             lambda path, tensor: jax.device_put(tensor, TransferToMemoryKind(dst))
+             if re.fullmatch(pattern, path)
+             else tensor,
+             tree_paths(state),
+             state,
+         )
+
+     def update_fn(updates: optax.Updates, state: optax.OptState, params: NestedOptParam):
+         state = _move_fn(state, offload_src)
+         updates, state = optimizer.update(updates, state, params)
+         state = _move_fn(state, offload_dst)
+         return updates, state
+
+     def partition_fn(param_spec: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
+         return copy_partition(
+             optimizer.partition(param_spec), pattern=pattern, memory_kind=offload_dst
+         )
+
+     return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn)
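For reference, here is a short usage sketch that combines the two docstring examples above with the new `pattern` argument. It assumes these functions are exported from `axlearn.common.optimizers` (the module this diff appears to touch); the `adamw_optimizer` argument lists are left elided exactly as in the original examples.

```python
# A usage sketch, not part of the diff. The import path and elided arguments are assumptions.
from axlearn.common.optimizers import (
    adamw_optimizer,
    offload_optimizer,
    skip_and_clip_by_global_norm,
)

# Offload all optimizer states. skip_and_clip_by_global_norm is wrapped as a whole, as the
# docstring requires (jnp.where cannot select between tensors on pinned_host memory).
offloaded_opt = offload_optimizer(skip_and_clip_by_global_norm(inner=adamw_optimizer(...)))

# Offload only the Adam first-moment ("mu") states, keeping everything else in HBM.
mu_offloaded_opt = offload_optimizer(adamw_optimizer(...), pattern=".*/mu/.*")
```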