Optimizer offloading through weight-only offload (#867)
* Optimizer offloading

* Style fix

* Type fix
hanzhi713 authored Jan 30, 2025
1 parent b1a1a5a commit 795da33
Showing 6 changed files with 208 additions and 42 deletions.
7 changes: 4 additions & 3 deletions axlearn/common/factorized_rms_test.py
@@ -12,11 +12,12 @@
from axlearn.common import factorized_rms
from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
from axlearn.common.optimizer_base import (
-    NestedOptStateSpec,
+    Nested,
    OptParam,
+    OptStateSpec,
    PartitionedGradientTransformation,
)
-from axlearn.common.optimizers import OptStateSpec, with_partition_fn
+from axlearn.common.optimizers import with_partition_fn
from axlearn.common.test_utils import TestCase
from axlearn.common.utils import PartitionSpec, flatten_items

@@ -59,7 +60,7 @@ def testParity(self, factored, dtype):

# The 'exp' optimizer is partitioned according to the mesh_axes of parameters and
# factorization spec.
-        exp_partition: NestedOptStateSpec = exp.partition(param_specs)
+        exp_partition: Nested[OptStateSpec] = exp.partition(param_specs)
# Used for `count`.
count_spec = OptStateSpec(
dtype=jnp.int32,
8 changes: 3 additions & 5 deletions axlearn/common/optimizer_base.py
@@ -16,14 +16,13 @@
- weight_decay_scale: control the weight decay rate.
"""
import dataclasses
-from collections.abc import Sequence
from typing import Any, Callable, NamedTuple, Optional, Union

import optax
import typing_extensions

-from axlearn.common.base_layer import FactorizationSpec, NestedParameterSpec
-from axlearn.common.utils import Tensor, TensorSpec
+from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
+from axlearn.common.utils import Nested, Tensor, TensorSpec


@dataclasses.dataclass
@@ -66,8 +65,7 @@ def __call__(

# Specification of an optimizer state array.
OptStateSpec = TensorSpec
-NestedOptStateSpec = Union[OptStateSpec, dict, Sequence]
-TransformPartitionSpecFn = Callable[[NestedParameterSpec], NestedOptStateSpec]
+TransformPartitionSpecFn = Callable[[Nested[ParameterSpec]], Nested[OptStateSpec]]


class PartitionedGradientTransformation(NamedTuple):
152 changes: 140 additions & 12 deletions axlearn/common/optimizers.py
@@ -36,10 +36,11 @@
import typing_extensions
from absl import logging
from jax import numpy as jnp
+from jax._src.sharding_impls import TransferToMemoryKind
from optax._src import numerics

from axlearn.common import schedule, struct
-from axlearn.common.base_layer import NestedParameterSpec, ParameterSpec, PartitionSpec
+from axlearn.common.base_layer import ParameterSpec, PartitionSpec
from axlearn.common.config import ConfigOr, maybe_instantiate
from axlearn.common.factorized_rms import scale_by_factored_rms
from axlearn.common.module import current_context
@@ -51,8 +52,8 @@
TransformPartitionSpecFn,
)
from axlearn.common.utils import (
+    MemoryKind,
    Nested,
-    NestedPartitionSpec,
NestedTensor,
NestedTree,
Tensor,
@@ -139,19 +140,41 @@ def update_fn(
return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn)


-def copy_partition(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+def copy_partition(
+    specs: Nested[OptStateSpec],
+    *,
+    pattern: Union[None, str, re.Pattern] = None,
+    memory_kind: Optional[MemoryKind] = None,
+) -> Nested[OptStateSpec]:
+    """Copies OptStateSpec and optionally assigns a different memory kind.
+
+    Args:
+        specs: Nested[OptStateSpec] to copy from.
+        pattern: Regex to match the full path of each spec. Matched specs will have their memory
+            kind replaced with `memory_kind`.
+        memory_kind: New memory kind. Defaults to None.
+
+    Returns:
+        A Nested[OptStateSpec] with possibly a different memory kind.
+    """
    return jax.tree.map(
-        lambda param_spec: OptStateSpec(
-            dtype=param_spec.dtype, shape=param_spec.shape, mesh_axes=param_spec.mesh_axes
+        lambda path, spec: OptStateSpec(
+            dtype=spec.dtype,
+            shape=spec.shape,
+            mesh_axes=spec.mesh_axes,
+            memory_kind=memory_kind
+            if pattern and re.fullmatch(pattern, path)
+            else spec.memory_kind,
        ),
-        param_specs,
+        tree_paths(specs),
+        specs,
    )


def trace_partition(
base: optax.GradientTransformation,
) -> PartitionedGradientTransformation:
-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return optax.TraceState(trace=copy_partition(param_specs))

return with_partition_fn(base, partition_fn)
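
For illustration, here is a minimal usage sketch of the new `copy_partition` signature (not part of this diff). The two-entry spec tree and its paths are hypothetical; only the regex and memory-kind behavior is taken from the function above.

```python
# Minimal sketch: a hypothetical spec tree; real optimizer state trees mirror the model's
# parameter hierarchy.
from jax import numpy as jnp

from axlearn.common.optimizer_base import OptStateSpec
from axlearn.common.optimizers import copy_partition
from axlearn.common.utils import PartitionSpec

specs = {
    "mu": {"linear1": OptStateSpec(dtype=jnp.float32, shape=[8], mesh_axes=PartitionSpec("model"))},
    "nu": {"linear1": OptStateSpec(dtype=jnp.float32, shape=[8], mesh_axes=PartitionSpec("model"))},
}

# Only specs whose full tree path (e.g. "mu/linear1") fully matches the regex are re-tagged with
# the new memory kind; every other spec keeps its original memory kind.
offloaded = copy_partition(specs, pattern="mu/.*", memory_kind="pinned_host")
assert offloaded["mu"]["linear1"].memory_kind == "pinned_host"
assert offloaded["nu"]["linear1"].memory_kind != "pinned_host"
```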
@@ -160,7 +183,9 @@ def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def adam_partition(base: optax.GradientTransformation) -> PartitionedGradientTransformation:
state: optax.ScaleByAdamState = base.init({})

-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(
+        param_specs: Nested[ParameterSpec],
+    ) -> Nested[Union[OptStateSpec, optax.ScaleByAdamState]]:
return optax.ScaleByAdamState(
count=OptStateSpec(
dtype=state.count.dtype, shape=state.count.shape, mesh_axes=PartitionSpec()
@@ -950,7 +975,7 @@ def _update(value: Tensor, ema: Tensor, qstep_size: Tensor, count: Tensor) -> _U
)
return updates, new_state

-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[Union[OptStateSpec, EmaState]]:
def get_ema_partition(param_spec: ParameterSpec) -> OptStateSpec:
# Store momentum in accumulator_dtype if it is set and p is not scalar.
if param_spec.shape and accumulator_dtype is not None:
@@ -1412,7 +1437,9 @@ def _is_valid_step(
drop_stats=new_drop_stats,
)

-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(
+        param_specs: Nested[ParameterSpec],
+    ) -> Nested[Union[OptStateSpec, SkipClipState]]:
if use_adaptive_drop_norm:
one = jnp.ones([], jnp.float32)
dict_thresholds = drop_norm(count=one, mean=one, stddev=one)
@@ -1571,7 +1598,9 @@ def update_fn(updates, state, params):
)
return updates, ParamEmaState(count=count_inc, ema=new_ema)

-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(
+        param_specs: Nested[ParameterSpec],
+    ) -> Nested[Union[OptStateSpec, ParamEmaState]]:
return ParamEmaState(
count=OptStateSpec(dtype=jnp.int32, shape=[], mesh_axes=PartitionSpec()),
ema=copy_partition(param_specs),
@@ -1617,7 +1646,9 @@ def update_fn(updates, state, params=None):
updates = jax.tree.map(lambda g, m: jnp.sign((1.0 - b1) * g + b1 * m), updates, state.mu)
return updates, ScaleByLionState(count=count_inc, mu=mu)

-    def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
+    def partition_fn(
+        param_specs: Nested[ParameterSpec],
+    ) -> Nested[Union[OptStateSpec, ScaleByLionState]]:
mu_specs = param_specs
if mu_dtype is not None:
mu_specs = jax.tree.map(
@@ -1993,3 +2024,100 @@ def _update2(u: Tensor, param: OptParam):
partition=lambda _: OptStateSpec(shape=[], dtype=jnp.int32, mesh_axes=PartitionSpec()),
)
return named_chain(**tx)


def offload_optimizer(
optimizer: ConfigOr[PartitionedGradientTransformation],
*,
pattern: Union[str, re.Pattern] = ".*",
offload_src: MemoryKind = "device",
offload_dst: MemoryKind = "pinned_host",
) -> PartitionedGradientTransformation:
"""Offload the state of the wrapped optimizer that matches `pattern` to `offload_dst`.
Args:
optimizer: The optimizer to offload.
pattern: Regex pattern used to match the path of optimizer states. Fully matched states
will be offloaded. Default to regex that matches all states.
offload_src: Offload-from memory kind. Default to "device".
offload_dst: Offload-to memory kind. Default to "pinned_host".
Returns:
A optimizer whose state is on `offload_dst` and does the same computation as `optimizer`.
Raises:
ValueError: when the `update` function of the returned optimizer is called outside of jit
context.
This function returns a new `PartitionedGradientTransformation` that
1. Puts matched states of the wrapped optimizer on `offload_dst` through the partition function
during state initialization in the trainer.
2. Copies the matched states to `offload_src` before `optimizer.update` is called.
3. Copies the matched updated states to `offload_dst` after `optimizer.update` is called.
The regex pattern is matched against the full path of each optimizer state. An example full
path is optimizer/1/0/mu/decoder/transformer/repeat/layer/feed_forward/linear1_0. If the
pattern should not depend on model structure, you can use ".*/mu/.*" to offload all `mu`.
The .update function of the returned `PartitionedGradientTransformation` must be called within
a jit function.
Example usage:
```python
your_opt = adamw_optimizer(...)
offloaded_opt = offload_optimizer(your_opt)
```
When using `skip_and_clip_by_global_norm` with this offload optimizer, you must wrap the entire
`skip_and_clip_by_global_norm` inside. Do not wrap the inner of `skip_and_clip_by_global_norm`
or you will get errors. Correct example:
```
offloaded_opt = offload_optimizer(skip_and_clip_by_global_norm(inner=adamw_optimizer(...)))
```
The reason is that `skip_and_clip_by_global_norm` conditionally chooses the previous optimizer
state and the updated new optimizer state using `jnp.where`, which doesn't support tensors on
`pinned_host` memory space.
"""
optimizer = maybe_instantiate(optimizer)
if offload_src is None or offload_dst is None:
raise ValueError(
"offload_src and offload_dst cannot be None when using optimizer offloading."
)

logging.info("Optimizer offloading from %s to %s enabled.", offload_src, offload_dst)

def init_fn(params: NestedOptParam):
return optimizer.init(params)

def _move_fn(state: optax.OptState, dst: MemoryKind) -> optax.OptState:
        # TransferToMemoryKind lets us change the memory kind of tensors without specifying the
        # full sharding (i.e. jax.sharding.NamedSharding). Although there's no documentation about
        # it, it's specified in the API signature. Reference:
        # https://github.com/jax-ml/jax/blob/21f8885a9e104b8828c9a8b721eed0c68b622691/jax/_src/api.py#L2220
        # Note: device_put doesn't move everything at once. When we pass a pytree of arrays to
        # device_put, each array in the pytree is moved independently of the others. The exact
        # order is decided by the latency hiding scheduler, which will try to overlap the transfer
        # of each state with the state update on TPU whenever possible. There is some memory
        # spike due to the temporary states in HBM, but the spike is much smaller than the full
        # memory usage of all states. Moreover, when the optimizer runs, all activations have been
        # released, so there is less memory pressure at that point in time.
return jax.tree.map(
lambda path, tensor: jax.device_put(tensor, TransferToMemoryKind(dst))
if re.fullmatch(pattern, path)
else tensor,
tree_paths(state),
state,
)

def update_fn(updates: optax.Updates, state: optax.OptState, params: NestedOptParam):
state = _move_fn(state, offload_src)
updates, state = optimizer.update(updates, state, params)
state = _move_fn(state, offload_dst)
return updates, state

def partition_fn(param_spec: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return copy_partition(
optimizer.partition(param_spec), pattern=pattern, memory_kind=offload_dst
)

return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn)
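
To connect the pieces, the sketch below (not part of this diff) shows one way the new transformation might be wired around an existing optimizer. The `skip_and_clip_by_global_norm` arguments mirror the test file below; the keyword names passed to `adamw_optimizer` are assumptions rather than something this commit specifies.

```python
# Usage sketch under the assumptions stated above: wrap the *entire* skip/clip transformation
# (per the docstring) and offload only the first-moment ("mu") states to pinned host memory.
from axlearn.common.optimizers import (
    adamw_optimizer,
    offload_optimizer,
    skip_and_clip_by_global_norm,
)

# Assumed keyword names for adamw_optimizer; consult optimizers.py for the exact signature.
inner = adamw_optimizer(learning_rate=1e-3, b1=0.9, b2=0.95, eps=1e-8, weight_decay=0.01)

offloaded_opt = offload_optimizer(
    skip_and_clip_by_global_norm(
        inner=inner, drop_norm=None, max_norm=1.0, grad_norm_ema_decay=0.99
    ),
    # Full state paths look like "optimizer/1/0/mu/decoder/...", so this regex matches every mu.
    pattern=".*/mu/.*",
    offload_dst="pinned_host",  # The default; spelled out for clarity.
)

# offloaded_opt.update(...) must be called under jit; the compiled program then contains the
# device <-> pinned_host transfers (see the TransferToMemoryKind assertion in the test below).
```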
59 changes: 45 additions & 14 deletions axlearn/common/optimizers_test.py
Expand Up @@ -40,6 +40,7 @@
ema,
l2_regularizer,
lion_optimizer,
+    offload_optimizer,
opt_param_values,
param_ema,
per_param_scale_by_path,
@@ -379,12 +380,25 @@ def _check_dtypes(x, y, z):
jax.tree.map(_check_dtypes, init_state, partition_state, update_state)

def _test_optimizer(self, optimizer):
-        params = OptParam(
-            value=jnp.asarray([0, 1, 2, -3], dtype=jnp.float32),
-            factorization_spec=None,
-            weight_decay_scale=1.0,
-        )
-        state = optimizer.init(params)
+        self._test_optimizer_helper(optimizer, True)
+        self._test_optimizer_helper(optimizer, False)
+
+    def _test_optimizer_helper(self, optimizer, offload):
+        if offload:
+            optimizer = offload_optimizer(optimizer)
+        params = jnp.asarray([0, 1, 2, -3], dtype=jnp.float32)
+
+        def create_opt_params(x):
+            return jax.tree.map(
+                lambda y: OptParam(
+                    value=y,
+                    factorization_spec=None,
+                    weight_decay_scale=1.0,
+                ),
+                x,
+            )
+
+        state = optimizer.init(create_opt_params(params))

param_spec = ParameterSpec(shape=[4], mesh_axes=PartitionSpec("model"), factorization=None)
state_partition_spec = optimizer.partition(param_spec)
@@ -399,13 +413,23 @@ def check_partition_spec(spec: OptStateSpec, tree):

jax.tree.map(check_partition_spec, state_partition_spec, state)

-        def compute_loss(x):
-            return -jax.nn.log_softmax(x)[1]
+        @jax.jit
+        def jit_fn(params, state):
+            def compute_loss(x):
+                return -jax.nn.log_softmax(x)[1]

-        loss, grads = jax.value_and_grad(compute_loss)(params.value)
-        updates, _ = optimizer.update(grads, state=state, params=params)
-        updated_params = optax.apply_updates(params.value, updates)
-        new_loss = compute_loss(updated_params)
+            params = create_opt_params(params)
+            loss, grads = jax.value_and_grad(compute_loss)(params.value)
+            updates, _ = optimizer.update(grads, state=state, params=params)
+            updated_params = optax.apply_updates(params.value, updates)
+            return loss, compute_loss(updated_params)

+        if offload:
+            self.assertIn(
+                "TransferToMemoryKind(memory_kind='pinned_host')",
+                str(jax.make_jaxpr(jit_fn)(params, state)),
+            )
+        loss, new_loss = jit_fn(params, state)
self.assertLess(new_loss, loss)

@parameterized.product(
@@ -788,14 +812,17 @@ def loss_fn(x):
config_for_function(drop_norm_by_grad_norm_ema).set(multipliers=[0.1, 1]),
config_for_function(drop_norm_by_grad_norm_stddev).set(multipliers=[20, 40]),
),
+        offload=(True, False),
)
-    def test_gradient_skipping_and_clipping(self, max_norm, drop_norm):
+    def test_gradient_skipping_and_clipping(self, max_norm, drop_norm, offload):
clip = skip_and_clip_by_global_norm(
inner=_counter(),
drop_norm=drop_norm,
max_norm=max_norm,
grad_norm_ema_decay=0.99,
)
+        if offload:
+            clip = offload_optimizer(clip)
params = jnp.asarray([0, 1, 2, -3], dtype=jnp.float32)
state = clip.init(params)
init_ema = state.grad_norm_ema
Expand All @@ -821,7 +848,11 @@ def loss_fn(x):
else:
is_valid_step = drop_norm is None or g_norm < drop_norm

-        updates, state = clip.update(grads, state=state, params=params)
+        @jax.jit
+        def jit_fn(grads, state, params):
+            return clip.update(grads, state=state, params=params)
+
+        updates, state = jit_fn(grads, state, params)
if is_valid_step:
if max_norm is None or g_norm < max_norm:
np.testing.assert_allclose(updates, grads, atol=1e-6)