
Commit 6f67efb

Add deepspeed support (ai-safety-foundation#186)
1 parent e10d2a2 commit 6f67efb

16 files changed (+656, -327 lines)

.vscode/cspell.json

Lines changed: 4 additions & 1 deletion
@@ -51,6 +51,7 @@
     "jaxtyping",
     "kaiming",
     "keepdim",
+    "logit",
     "lognormal",
     "loguniform",
     "loguniformvalues",
@@ -73,6 +74,7 @@
     "neox",
     "nonlinerity",
     "numel",
+    "onebit",
     "openwebtext",
     "optim",
     "penality",
@@ -116,6 +118,7 @@
     "venv",
     "virtualenv",
     "virtualenvs",
-    "wandb"
+    "wandb",
+    "zoadam"
   ]
 }

poetry.lock

Lines changed: 320 additions & 195 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 7 additions & 1 deletion
@@ -7,17 +7,20 @@
 readme="README.md"
 version="0.0.0"

+# Note: Zstandard is required for downloading datasets such as The Pile
 [tool.poetry.dependencies]
 datasets=">=2.15.0"
+deepspeed={version=">=0.12.6", extras=["deepspeed"], optional=false}
 einops=">=0.6"
+mpi4py={version=">=3.1.5", extras=["deepspeed"], optional=true}
 pydantic=">=2.5.2"
 python=">=3.10, <3.12"
 strenum=">=0.4.15"
 tokenizers=">=0.15.0"
 torch=">=2.1.1"
 transformers=">=4.35.2"
 wandb=">=0.16.1"
-zstandard=">=0.22.0" # Required for downloading datasets such as The Pile
+zstandard=">=0.22.0"

 [tool.poetry.group]
 [tool.poetry.group.dev.dependencies]
@@ -54,6 +57,9 @@
 pymdown-extensions=">=10.5"
 pytkdocs-tweaks=">=0.0.7"

+[tool.poetry.extras]
+deepspeed=["deepspeed", "mpi4py"]
+
 [tool.poetry.scripts]
 join-sae-sweep='sparse_autoencoder.train.join_sweep:run'
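The new `[tool.poetry.extras]` table groups `deepspeed` and `mpi4py` under a `deepspeed` extra (note that only `mpi4py` is marked `optional=true` above). Assuming the standard Poetry workflow, the extra would be installed with something like `poetry install --extras deepspeed`; the exact command and the published package name are assumptions, not part of this commit.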

sparse_autoencoder/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
 from sparse_autoencoder.metrics.train.capacity import CapacityMetric
 from sparse_autoencoder.metrics.train.feature_density import TrainBatchFeatureDensityMetric
 from sparse_autoencoder.optimizer.adam_with_reset import AdamWithReset
+from sparse_autoencoder.optimizer.deepspeed_adam_with_reset import ZeroOneAdamWithReset
 from sparse_autoencoder.source_data.pretokenized_dataset import PreTokenizedDataset
 from sparse_autoencoder.source_data.text_dataset import TextDataset
 from sparse_autoencoder.train.pipeline import Pipeline
@@ -83,4 +84,5 @@
     "TensorActivationStore",
     "TextDataset",
     "TrainBatchFeatureDensityMetric",
+    "ZeroOneAdamWithReset",
 ]

sparse_autoencoder/activation_resampler/activation_resampler.py

Lines changed: 5 additions & 4 deletions
@@ -2,11 +2,13 @@
 from dataclasses import dataclass
 from typing import Annotated, NamedTuple

+from deepspeed import DeepSpeedEngine
 from einops import rearrange
 from jaxtyping import Bool, Float, Int64
 from pydantic import Field, NonNegativeInt, PositiveInt, validate_call
 import torch
 from torch import Tensor
+from torch.nn.parallel import DataParallel
 from torch.utils.data import DataLoader

 from sparse_autoencoder.activation_resampler.utils.component_slice_tensor import (
@@ -17,7 +19,6 @@
 from sparse_autoencoder.loss.abstract_loss import AbstractLoss
 from sparse_autoencoder.tensor_types import Axis
 from sparse_autoencoder.train.utils.get_model_device import get_model_device
-from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


 @dataclass
@@ -207,7 +208,7 @@ def _get_dead_neuron_indices(
     def compute_loss_and_get_activations(
         self,
         store: ActivationStore,
-        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
+        autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder] | DeepSpeedEngine,
         loss_fn: AbstractLoss,
         train_batch_size: int,
     ) -> LossInputActivationsTuple:
@@ -440,7 +441,7 @@ def renormalize_and_scale(
     def resample_dead_neurons(
         self,
         activation_store: ActivationStore,
-        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
+        autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder] | DeepSpeedEngine,
         loss_fn: AbstractLoss,
         train_batch_size: int,
     ) -> list[ParameterUpdateResults]:
@@ -530,7 +531,7 @@ def step_resampler(
         self,
         batch_neuron_activity: Int64[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)],
         activation_store: ActivationStore,
-        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
+        autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder] | DeepSpeedEngine,
         loss_fn: AbstractLoss,
         train_batch_size: int,
     ) -> list[ParameterUpdateResults] | None:

sparse_autoencoder/autoencoder/model.py

Lines changed: 1 addition & 1 deletion
@@ -276,7 +276,7 @@ def load(
             The loaded model.
         """
         # Load the file
-        serialized_state = torch.load(file_path)
+        serialized_state = torch.load(file_path, map_location=torch.device("cpu"))
         state = SparseAutoencoderState.model_validate(serialized_state)

         # Initialise the model
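This change pins checkpoint loading to CPU, so a state dict saved on a GPU machine can be loaded on a host without CUDA (and moved to a device afterwards). A minimal sketch of the same pattern outside the library, with an illustrative file name:

import torch

# Load a checkpoint that may have been saved from a CUDA device onto a CPU-only
# host. Without map_location, torch.load tries to restore tensors to their
# original device and fails when that device is unavailable.
state = torch.load("sae_checkpoint.pt", map_location=torch.device("cpu"))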

sparse_autoencoder/optimizer/adam_with_reset.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def __init__( # (extending existing implementation)
         lr: float | Float[Tensor, Axis.names(Axis.SINGLE_ITEM)] = 1e-3,
         betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
-        weight_decay: float = 0,
+        weight_decay: float = 0.0,
         *,
         amsgrad: bool = False,
         foreach: bool | None = None,
sparse_autoencoder/optimizer/deepspeed_adam_with_reset.py

Lines changed: 194 additions & 0 deletions

@@ -0,0 +1,194 @@
+"""Deepspeed Zero One Adam Optimizer with a reset method.
+
+This reset method is useful when resampling dead neurons during training.
+"""
+from collections.abc import Iterator
+from typing import final
+
+from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
+from jaxtyping import Int
+from torch import Tensor
+from torch.nn.parameter import Parameter
+from torch.optim.optimizer import params_t
+
+from sparse_autoencoder.optimizer.abstract_optimizer import AbstractOptimizerWithReset
+from sparse_autoencoder.tensor_types import Axis
+
+
+@final
+class ZeroOneAdamWithReset(ZeroOneAdam, AbstractOptimizerWithReset):
+    """Deepspeed Zero One Adam Optimizer with a reset method.
+
+    https://deepspeed.readthedocs.io/en/latest/optimizers.html#zerooneadam-gpu
+
+    The :meth:`reset_state_all_parameters` and :meth:`reset_neurons_state` methods are useful when
+    manually editing the model parameters during training (e.g. when resampling dead neurons). This
+    is because Adam maintains running averages of the gradients and the squares of gradients, which
+    will be incorrect if the parameters are changed.
+
+    Otherwise this is the same as the standard ZeroOneAdam optimizer.
+
+    Warning:
+        Requires a distributed torch backend.
+    """
+
+    parameter_names: list[str]
+    """Parameter Names.
+
+    The names of the parameters, so that we can find them later when resetting the state.
+    """
+
+    _has_components_dim: bool
+    """Whether the parameters have a components dimension."""
+
+    def __init__(
+        self,
+        params: params_t,
+        lr: float = 1e-3,
+        betas: tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 0.0,
+        *,
+        named_parameters: Iterator[tuple[str, Parameter]],
+        has_components_dim: bool,
+    ) -> None:
+        """Initialize the optimizer.
+
+        Warning:
+            Named parameters must be with default settings (remove duplicates and not recursive).
+
+        Args:
+            params: Iterable of parameters to optimize or dicts defining parameter groups.
+            lr: Learning rate. A Tensor LR is not yet fully supported for all implementations. Use
+                a float LR unless specifying fused=True or capturable=True.
+            betas: Coefficients used for computing running averages of gradient and its square.
+            eps: Term added to the denominator to improve numerical stability.
+            weight_decay: Weight decay (L2 penalty).
+            named_parameters: An iterator over the named parameters of the model. This is used to
+                find the parameters when resetting their state. You should set this as
+                `model.named_parameters()`.
+            has_components_dim: If the parameters have a components dimension (i.e. if you are
+                training an SAE on more than one component).
+
+        Raises:
+            ValueError: If the number of parameter names does not match the number of parameters.
+        """
+        # Initialise the parent class (note we repeat the parameter names so that type hints work).
+        super().__init__(
+            params=params,
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+        )
+
+        self._has_components_dim = has_components_dim
+
+        # Store the names of the parameters, so that we can find them later when resetting the
+        # state.
+        self.parameter_names = [name for name, _value in named_parameters]
+
+        if len(self.parameter_names) != len(self.param_groups[0]["params"]):
+            error_message = (
+                "The number of parameter names does not match the number of parameters. "
+                "If using model.named_parameters() make sure remove_duplicates is True "
+                "and recursive is False (the default settings)."
+            )
+            raise ValueError(error_message)
+
+    def reset_state_all_parameters(self) -> None:
+        """Reset the state for all parameters.
+
+        Iterates over all parameters and resets both the running averages of the gradients and the
+        squares of gradients.
+        """
+        # Iterate over every parameter
+        for group in self.param_groups:
+            for parameter in group["params"]:
+                # Get the state
+                state = self.state[parameter]
+
+                # Check if state is initialized
+                if len(state) == 0:
+                    continue
+
+                # Reset running averages
+                exp_avg: Tensor = state["exp_avg"]
+                exp_avg.zero_()
+                exp_avg_sq: Tensor = state["exp_avg_sq"]
+                exp_avg_sq.zero_()
+
+                # If AdamW is used (weight decay fix), also reset the max exp_avg_sq
+                if "max_exp_avg_sq" in state:
+                    max_exp_avg_sq: Tensor = state["max_exp_avg_sq"]
+                    max_exp_avg_sq.zero_()
+
+    def reset_neurons_state(
+        self,
+        parameter: Parameter,
+        neuron_indices: Int[Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX)],
+        axis: int,
+        component_idx: int = 0,
+    ) -> None:
+        """Reset the state for specific neurons, on a specific parameter.
+
+        Args:
+            parameter: The parameter to be reset. Examples from the standard sparse autoencoder
+                implementation include `tied_bias`, `_encoder._weight`, `_encoder._bias`,
+            neuron_indices: The indices of the neurons to reset.
+            axis: The axis of the state values to reset (i.e. the input/output features axis, as
+                we're resetting all input/output features for a specific dead neuron).
+            component_idx: The component index of the state values to reset.
+
+        Raises:
+            ValueError: If the parameter has a components dimension, but has_components_dim is
+                False.
+        """
+        # Get the state of the parameter
+        state = self.state[parameter]
+
+        # If the number of dimensions is 3, we definitely have a components dimension. If 2, we may
+        # do (as the bias has 2 dimensions with components, but the weight has 2 dimensions without
+        # components).
+        definitely_has_components_dimension = 3
+        if (
+            not self._has_components_dim
+            and state["exp_avg"].ndim == definitely_has_components_dimension
+        ):
+            error_message = (
+                "The parameter has a components dimension, but has_components_dim is False. "
+                "This should not happen."
+            )
+            raise ValueError(error_message)
+
+        # Check if state is initialized
+        if len(state) == 0:
+            return
+
+        # Check there are any neurons to reset
+        if neuron_indices.numel() == 0:
+            return
+
+        # Move the neuron indices to the correct device
+        neuron_indices = neuron_indices.to(device=state["exp_avg"].device)
+
+        # Reset running averages for the specified neurons
+        if "exp_avg" in state:
+            if self._has_components_dim:
+                state["exp_avg"][component_idx].index_fill_(axis, neuron_indices, 0)
+            else:
+                state["exp_avg"].index_fill_(axis, neuron_indices, 0)
+
+        if "exp_avg_sq" in state:
+            if self._has_components_dim:
+                state["exp_avg_sq"][component_idx].index_fill_(axis, neuron_indices, 0)
+            else:
+                state["exp_avg_sq"].index_fill_(axis, neuron_indices, 0)
+
+        # If AdamW is used (weight decay fix), also reset the max exp_avg_sq
+        if "max_exp_avg_sq" in state:
+            if self._has_components_dim:
+                state["max_exp_avg_sq"][component_idx].index_fill_(axis, neuron_indices, 0)
+            else:
+                state["max_exp_avg_sq"].index_fill_(axis, neuron_indices, 0)

sparse_autoencoder/source_model/replace_activations_hook.py

Lines changed: 12 additions & 11 deletions
@@ -1,33 +1,37 @@
 """Replace activations hook."""
 from typing import TYPE_CHECKING

+from deepspeed import DeepSpeedEngine
+from jaxtyping import Float
 from torch import Tensor
+from torch.nn.parallel import DataParallel
 from transformer_lens.hook_points import HookPoint

 from sparse_autoencoder.autoencoder.model import SparseAutoencoder
-from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


 if TYPE_CHECKING:
     from sparse_autoencoder.tensor_types import Axis
-    from jaxtyping import Float


 def replace_activations_hook(
     value: Tensor,
     hook: HookPoint,  # noqa: ARG001
-    sparse_autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
+    sparse_autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder] | DeepSpeedEngine,
     component_idx: int | None = None,
+    n_components: int | None = None,
 ) -> Tensor:
     """Replace activations hook.

+    This should be pre-initialised with `functools.partial`.
+
     Args:
         value: The activations to replace.
         hook: The hook point.
-        sparse_autoencoder: The sparse autoencoder. This should be pre-initialised with
-            `functools.partial`.
+        sparse_autoencoder: The sparse autoencoder.
         component_idx: The component index to replace the activations with, if just replacing
             activations for a single component. Requires the model to have a component axis.
+        n_components: The number of components that the SAE is trained on.

     Returns:
         Replaced activations.
@@ -43,11 +47,8 @@ def replace_activations_hook(
     )

     if component_idx is not None:
-        if sparse_autoencoder.config.n_components is None:
-            error_message = (
-                "Cannot replace for a specific component, if the model does not have a "
-                "component axis."
-            )
+        if n_components is None:
+            error_message = "The number of model components must be set if component_idx is set."
             raise RuntimeError(error_message)

     # The approach here is to run a forward pass with dummy values for all components other than
@@ -56,7 +57,7 @@ def replace_activations_hook(
     # components.
     expanded_shape = [
         squashed_value.shape[0],
-        sparse_autoencoder.config.n_components,
+        n_components,
         squashed_value.shape[-1],
     ]
     expanded = squashed_value.unsqueeze(1).expand(*expanded_shape)
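As the docstring notes, the hook is meant to be pre-initialised with `functools.partial` before being attached to the source model. A hedged sketch of that pattern; the model name, hook point, and the trained `autoencoder` are illustrative, not taken from this commit:

import functools

from transformer_lens import HookedTransformer

from sparse_autoencoder.source_model.replace_activations_hook import replace_activations_hook

# Source model whose activations will be routed through the SAE. The model name
# and hook point below are placeholders.
source_model = HookedTransformer.from_pretrained("gelu-1l")

autoencoder = ...  # a trained SparseAutoencoder (construction omitted)

# Bind the SAE to the hook ahead of time, as the docstring suggests.
hook = functools.partial(replace_activations_hook, sparse_autoencoder=autoencoder)

# Run the source model with the SAE patched in at one hook point.
logits = source_model.run_with_hooks(
    "Example prompt",
    fwd_hooks=[("blocks.0.mlp.hook_post", hook)],
)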

0 commit comments
