
Commit 795da33

Optimizer offloading through weight-only offload (apple#867)
* Optimizer offloading
* Style fix
* Type fix
1 parent b1a1a5a commit 795da33

6 files changed: +208 -42 lines changed

axlearn/common/factorized_rms_test.py

Lines changed: 4 additions & 3 deletions
@@ -12,11 +12,12 @@
 from axlearn.common import factorized_rms
 from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
 from axlearn.common.optimizer_base import (
-    NestedOptStateSpec,
+    Nested,
     OptParam,
+    OptStateSpec,
     PartitionedGradientTransformation,
 )
-from axlearn.common.optimizers import OptStateSpec, with_partition_fn
+from axlearn.common.optimizers import with_partition_fn
 from axlearn.common.test_utils import TestCase
 from axlearn.common.utils import PartitionSpec, flatten_items

@@ -59,7 +60,7 @@ def testParity(self, factored, dtype):

         # The 'exp' optimizer is partitioned according to the mesh_axes of parameters and
         # factorization spec.
-        exp_partition: NestedOptStateSpec = exp.partition(param_specs)
+        exp_partition: Nested[OptStateSpec] = exp.partition(param_specs)
         # Used for `count`.
         count_spec = OptStateSpec(
             dtype=jnp.int32,
axlearn/common/optimizer_base.py

Lines changed: 3 additions & 5 deletions
@@ -16,14 +16,13 @@
 - weight_decay_scale: control the weight decay rate.
 """
 import dataclasses
-from collections.abc import Sequence
 from typing import Any, Callable, NamedTuple, Optional, Union

 import optax
 import typing_extensions

-from axlearn.common.base_layer import FactorizationSpec, NestedParameterSpec
-from axlearn.common.utils import Tensor, TensorSpec
+from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
+from axlearn.common.utils import Nested, Tensor, TensorSpec


 @dataclasses.dataclass

@@ -66,8 +65,7 @@ def __call__(

 # Specification of an optimizer state array.
 OptStateSpec = TensorSpec
-NestedOptStateSpec = Union[OptStateSpec, dict, Sequence]
-TransformPartitionSpecFn = Callable[[NestedParameterSpec], NestedOptStateSpec]
+TransformPartitionSpecFn = Callable[[Nested[ParameterSpec]], Nested[OptStateSpec]]


 class PartitionedGradientTransformation(NamedTuple):
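With `NestedOptStateSpec` removed, a partition function now maps a `Nested[ParameterSpec]` tree to a `Nested[OptStateSpec]` tree leaf by leaf. The following sketch is illustrative only and not part of this commit; it assumes the modules touched by this diff are importable, and it mirrors what `trace_partition`'s `partition_fn` does in `optimizers.py` below.

```python
# Illustrative sketch of a function satisfying the new alias
# TransformPartitionSpecFn = Callable[[Nested[ParameterSpec]], Nested[OptStateSpec]].
import jax

from axlearn.common.base_layer import ParameterSpec
from axlearn.common.optimizer_base import OptStateSpec
from axlearn.common.utils import Nested


def mirror_partition(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
    # Mirror each parameter's dtype/shape/mesh_axes into a per-leaf optimizer-state spec.
    return jax.tree.map(
        lambda spec: OptStateSpec(dtype=spec.dtype, shape=spec.shape, mesh_axes=spec.mesh_axes),
        param_specs,
    )
```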

axlearn/common/optimizers.py

Lines changed: 140 additions & 12 deletions
@@ -36,10 +36,11 @@
 import typing_extensions
 from absl import logging
 from jax import numpy as jnp
+from jax._src.sharding_impls import TransferToMemoryKind
 from optax._src import numerics

 from axlearn.common import schedule, struct
-from axlearn.common.base_layer import NestedParameterSpec, ParameterSpec, PartitionSpec
+from axlearn.common.base_layer import ParameterSpec, PartitionSpec
 from axlearn.common.config import ConfigOr, maybe_instantiate
 from axlearn.common.factorized_rms import scale_by_factored_rms
 from axlearn.common.module import current_context

@@ -51,8 +52,8 @@
     TransformPartitionSpecFn,
 )
 from axlearn.common.utils import (
+    MemoryKind,
     Nested,
-    NestedPartitionSpec,
     NestedTensor,
     NestedTree,
     Tensor,

@@ -139,19 +140,41 @@ def update_fn(
     return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn)


-def copy_partition(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+def copy_partition(
+    specs: Nested[OptStateSpec],
+    *,
+    pattern: Union[None, str, re.Pattern] = None,
+    memory_kind: Optional[MemoryKind] = None,
+) -> Nested[OptStateSpec]:
+    """Copies OptStateSpec and optionally assigns a different memory kind.
+
+    Args:
+        specs: Nested[OptStateSpec] to copy from.
+        pattern: Regex to match the full path of each spec. Matched specs will have their memory
+            kind replaced with `memory_kind`.
+        memory_kind: New memory kind. Defaults to None.
+
+    Returns:
+        A Nested[OptStateSpec] with possibly a different memory kind.
+    """
     return jax.tree.map(
-        lambda param_spec: OptStateSpec(
-            dtype=param_spec.dtype, shape=param_spec.shape, mesh_axes=param_spec.mesh_axes
+        lambda path, spec: OptStateSpec(
+            dtype=spec.dtype,
+            shape=spec.shape,
+            mesh_axes=spec.mesh_axes,
+            memory_kind=memory_kind
+            if pattern and re.fullmatch(pattern, path)
+            else spec.memory_kind,
         ),
-        param_specs,
+        tree_paths(specs),
+        specs,
     )


 def trace_partition(
     base: optax.GradientTransformation,
 ) -> PartitionedGradientTransformation:
-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
         return optax.TraceState(trace=copy_partition(param_specs))

     return with_partition_fn(base, partition_fn)
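To make the new `copy_partition` behavior concrete, here is a minimal, hedged sketch (it assumes axlearn at this commit, with `tree_paths` joining dict keys by "/"; the toy dict stands in for a real optimizer-state spec tree). Only specs whose full path matches `pattern` under `re.fullmatch` are retargeted to the new memory kind.

```python
# Sketch only: retarget the memory kind of matched specs via copy_partition.
import jax.numpy as jnp

from axlearn.common.optimizers import OptStateSpec, copy_partition
from axlearn.common.utils import PartitionSpec

specs = {
    "mu": {"w": OptStateSpec(dtype=jnp.float32, shape=[4], mesh_axes=PartitionSpec("model"))},
    "count": OptStateSpec(dtype=jnp.int32, shape=[], mesh_axes=PartitionSpec()),
}
# "mu/w" fully matches the pattern and is marked for pinned host memory;
# "count" does not match and keeps whatever memory kind it already had.
offloaded = copy_partition(specs, pattern="mu/.*", memory_kind="pinned_host")
assert offloaded["mu"]["w"].memory_kind == "pinned_host"
```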
@@ -160,7 +183,9 @@ def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
 def adam_partition(base: optax.GradientTransformation) -> PartitionedGradientTransformation:
     state: optax.ScaleByAdamState = base.init({})

-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(
+        param_specs: Nested[ParameterSpec],
+    ) -> Nested[Union[OptStateSpec, optax.ScaleByAdamState]]:
         return optax.ScaleByAdamState(
             count=OptStateSpec(
                 dtype=state.count.dtype, shape=state.count.shape, mesh_axes=PartitionSpec()

@@ -950,7 +975,7 @@ def _update(value: Tensor, ema: Tensor, qstep_size: Tensor, count: Tensor) -> _U
         )
         return updates, new_state

-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[Union[OptStateSpec, EmaState]]:
         def get_ema_partition(param_spec: ParameterSpec) -> OptStateSpec:
             # Store momentum in accumulator_dtype if it is set and p is not scalar.
             if param_spec.shape and accumulator_dtype is not None:

@@ -1412,7 +1437,9 @@ def _is_valid_step(
             drop_stats=new_drop_stats,
         )

-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(
+        param_specs: Nested[ParameterSpec],
+    ) -> Nested[Union[OptStateSpec, SkipClipState]]:
         if use_adaptive_drop_norm:
             one = jnp.ones([], jnp.float32)
             dict_thresholds = drop_norm(count=one, mean=one, stddev=one)

@@ -1571,7 +1598,9 @@ def update_fn(updates, state, params):
         )
         return updates, ParamEmaState(count=count_inc, ema=new_ema)

-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(
+        param_specs: Nested[ParameterSpec],
+    ) -> Nested[Union[OptStateSpec, ParamEmaState]]:
         return ParamEmaState(
             count=OptStateSpec(dtype=jnp.int32, shape=[], mesh_axes=PartitionSpec()),
             ema=copy_partition(param_specs),

@@ -1617,7 +1646,9 @@ def update_fn(updates, state, params=None):
         updates = jax.tree.map(lambda g, m: jnp.sign((1.0 - b1) * g + b1 * m), updates, state.mu)
         return updates, ScaleByLionState(count=count_inc, mu=mu)

-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(
+        param_specs: Nested[ParameterSpec],
+    ) -> Nested[Union[OptStateSpec, ScaleByLionState]]:
         mu_specs = param_specs
         if mu_dtype is not None:
             mu_specs = jax.tree.map(
@@ -1993,3 +2024,100 @@ def _update2(u: Tensor, param: OptParam):
         partition=lambda _: OptStateSpec(shape=[], dtype=jnp.int32, mesh_axes=PartitionSpec()),
     )
     return named_chain(**tx)
+
+
+def offload_optimizer(
+    optimizer: ConfigOr[PartitionedGradientTransformation],
+    *,
+    pattern: Union[str, re.Pattern] = ".*",
+    offload_src: MemoryKind = "device",
+    offload_dst: MemoryKind = "pinned_host",
+) -> PartitionedGradientTransformation:
+    """Offloads the state of the wrapped optimizer that matches `pattern` to `offload_dst`.
+
+    Args:
+        optimizer: The optimizer to offload.
+        pattern: Regex pattern used to match the path of optimizer states. Fully matched states
+            will be offloaded. Defaults to a regex that matches all states.
+        offload_src: Offload-from memory kind. Defaults to "device".
+        offload_dst: Offload-to memory kind. Defaults to "pinned_host".
+
+    Returns:
+        An optimizer whose state lives on `offload_dst` and that performs the same computation as
+        `optimizer`.
+
+    Raises:
+        ValueError: When the `update` function of the returned optimizer is called outside of a
+            jit context.
+
+    This function returns a new `PartitionedGradientTransformation` that
+    1. Puts matched states of the wrapped optimizer on `offload_dst` through the partition function
+       during state initialization in the trainer.
+    2. Copies the matched states to `offload_src` before `optimizer.update` is called.
+    3. Copies the matched updated states to `offload_dst` after `optimizer.update` is called.
+
+    The regex pattern is matched against the full path of each optimizer state. An example full
+    path is optimizer/1/0/mu/decoder/transformer/repeat/layer/feed_forward/linear1_0. If the
+    pattern should not depend on model structure, you can use ".*/mu/.*" to offload all `mu`.
+
+    The `.update` function of the returned `PartitionedGradientTransformation` must be called
+    within a jit function.
+
+    Example usage:
+    ```python
+    your_opt = adamw_optimizer(...)
+    offloaded_opt = offload_optimizer(your_opt)
+    ```
+
+    When using `skip_and_clip_by_global_norm` with this offload optimizer, you must wrap the
+    entire `skip_and_clip_by_global_norm` transformation. Do not wrap only its `inner` optimizer,
+    or you will get errors. Correct example:
+    ```
+    offloaded_opt = offload_optimizer(skip_and_clip_by_global_norm(inner=adamw_optimizer(...)))
+    ```
+    The reason is that `skip_and_clip_by_global_norm` conditionally chooses between the previous
+    optimizer state and the newly updated optimizer state using `jnp.where`, which doesn't support
+    tensors on the `pinned_host` memory space.
+    """
+    optimizer = maybe_instantiate(optimizer)
+    if offload_src is None or offload_dst is None:
+        raise ValueError(
+            "offload_src and offload_dst cannot be None when using optimizer offloading."
+        )
+
+    logging.info("Optimizer offloading from %s to %s enabled.", offload_src, offload_dst)
+
+    def init_fn(params: NestedOptParam):
+        return optimizer.init(params)
+
+    def _move_fn(state: optax.OptState, dst: MemoryKind) -> optax.OptState:
+        # TransferToMemoryKind lets us change the memory kind of tensors without specifying the
+        # full sharding (i.e. jax.sharding.NamedSharding). Although there's no documentation about
+        # it, it's specified in the API signature. Reference:
+        # https://github.com/jax-ml/jax/blob/21f8885a9e104b8828c9a8b721eed0c68b622691/jax/_src/api.py#L2220
+        # Note: device_put doesn't move everything at once. When we pass a pytree of arrays to
+        # device_put, each array in the pytree is moved independently of the others. The exact
+        # order is decided by the latency hiding scheduler. The scheduler will try to overlap the
+        # transfers of each state with the state update on TPU whenever possible. There is some
+        # memory spike due to the temporary state in HBM, but the spike is much less than the full
+        # memory usage of all states. Moreover, when the optimizer is run, all activations are
+        # released, so we have less memory pressure at that point in time.
+        return jax.tree.map(
+            lambda path, tensor: jax.device_put(tensor, TransferToMemoryKind(dst))
+            if re.fullmatch(pattern, path)
+            else tensor,
+            tree_paths(state),
+            state,
+        )
+
+    def update_fn(updates: optax.Updates, state: optax.OptState, params: NestedOptParam):
+        state = _move_fn(state, offload_src)
+        updates, state = optimizer.update(updates, state, params)
+        state = _move_fn(state, offload_dst)
+        return updates, state
+
+    def partition_fn(param_spec: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
+        return copy_partition(
+            optimizer.partition(param_spec), pattern=pattern, memory_kind=offload_dst
+        )
+
+    return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn)
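The path-matching rule described in the docstring can be sanity-checked in isolation. Below is a short, hedged sketch using only the standard library; the example path is copied from the docstring, and the wiring comments restate the docstring's own usage examples rather than prescribing a particular optimizer configuration.

```python
# Sketch: how the full-path regex used by offload_optimizer/copy_partition behaves.
import re

path = "optimizer/1/0/mu/decoder/transformer/repeat/layer/feed_forward/linear1_0"
assert re.fullmatch(".*/mu/.*", path)      # ".*/mu/.*" offloads every `mu` leaf.
assert re.fullmatch(".*", path)            # The default pattern offloads all states.
assert not re.fullmatch(".*/nu/.*", path)  # Non-matching leaves keep their memory kind.

# Wiring, per the docstring's own examples (arguments elided exactly as there):
#   offloaded_opt = offload_optimizer(adamw_optimizer(...), pattern=".*/mu/.*")
#   offloaded_opt = offload_optimizer(skip_and_clip_by_global_norm(inner=adamw_optimizer(...)))
# and offloaded_opt.update(...) must be called from inside a jit-compiled function.
```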

axlearn/common/optimizers_test.py

Lines changed: 45 additions & 14 deletions
@@ -40,6 +40,7 @@
     ema,
     l2_regularizer,
     lion_optimizer,
+    offload_optimizer,
     opt_param_values,
     param_ema,
     per_param_scale_by_path,

@@ -379,12 +380,25 @@ def _check_dtypes(x, y, z):
         jax.tree.map(_check_dtypes, init_state, partition_state, update_state)

     def _test_optimizer(self, optimizer):
-        params = OptParam(
-            value=jnp.asarray([0, 1, 2, -3], dtype=jnp.float32),
-            factorization_spec=None,
-            weight_decay_scale=1.0,
-        )
-        state = optimizer.init(params)
+        self._test_optimizer_helper(optimizer, True)
+        self._test_optimizer_helper(optimizer, False)
+
+    def _test_optimizer_helper(self, optimizer, offload):
+        if offload:
+            optimizer = offload_optimizer(optimizer)
+        params = jnp.asarray([0, 1, 2, -3], dtype=jnp.float32)
+
+        def create_opt_params(x):
+            return jax.tree.map(
+                lambda y: OptParam(
+                    value=y,
+                    factorization_spec=None,
+                    weight_decay_scale=1.0,
+                ),
+                x,
+            )
+
+        state = optimizer.init(create_opt_params(params))

         param_spec = ParameterSpec(shape=[4], mesh_axes=PartitionSpec("model"), factorization=None)
         state_partition_spec = optimizer.partition(param_spec)

@@ -399,13 +413,23 @@ def check_partition_spec(spec: OptStateSpec, tree):

         jax.tree.map(check_partition_spec, state_partition_spec, state)

-        def compute_loss(x):
-            return -jax.nn.log_softmax(x)[1]
+        @jax.jit
+        def jit_fn(params, state):
+            def compute_loss(x):
+                return -jax.nn.log_softmax(x)[1]

-        loss, grads = jax.value_and_grad(compute_loss)(params.value)
-        updates, _ = optimizer.update(grads, state=state, params=params)
-        updated_params = optax.apply_updates(params.value, updates)
-        new_loss = compute_loss(updated_params)
+            params = create_opt_params(params)
+            loss, grads = jax.value_and_grad(compute_loss)(params.value)
+            updates, _ = optimizer.update(grads, state=state, params=params)
+            updated_params = optax.apply_updates(params.value, updates)
+            return loss, compute_loss(updated_params)
+
+        if offload:
+            self.assertIn(
+                "TransferToMemoryKind(memory_kind='pinned_host')",
+                str(jax.make_jaxpr(jit_fn)(params, state)),
+            )
+        loss, new_loss = jit_fn(params, state)
         self.assertLess(new_loss, loss)

     @parameterized.product(

@@ -788,14 +812,17 @@ def loss_fn(x):
             config_for_function(drop_norm_by_grad_norm_ema).set(multipliers=[0.1, 1]),
             config_for_function(drop_norm_by_grad_norm_stddev).set(multipliers=[20, 40]),
         ),
+        offload=(True, False),
     )
-    def test_gradient_skipping_and_clipping(self, max_norm, drop_norm):
+    def test_gradient_skipping_and_clipping(self, max_norm, drop_norm, offload):
         clip = skip_and_clip_by_global_norm(
             inner=_counter(),
             drop_norm=drop_norm,
             max_norm=max_norm,
             grad_norm_ema_decay=0.99,
         )
+        if offload:
+            clip = offload_optimizer(clip)
         params = jnp.asarray([0, 1, 2, -3], dtype=jnp.float32)
         state = clip.init(params)
         init_ema = state.grad_norm_ema

@@ -821,7 +848,11 @@ def loss_fn(x):
         else:
             is_valid_step = drop_norm is None or g_norm < drop_norm

-        updates, state = clip.update(grads, state=state, params=params)
+        @jax.jit
+        def jit_fn(grads, state, params):
+            return clip.update(grads, state=state, params=params)
+
+        updates, state = jit_fn(grads, state, params)
         if is_valid_step:
             if max_norm is None or g_norm < max_norm:
                 np.testing.assert_allclose(updates, grads, atol=1e-6)