
Commit 6b9c2a5

Add join run cli command (#177)
1 parent: b8cc0cc


10 files changed: +114, -513 lines


docs/content/SUMMARY.md

Lines changed: 0 additions & 1 deletion
@@ -1,7 +1,6 @@
 
 * [Home](index.md)
 * [Demo](demo.ipynb)
-* [Flexible demo](flexible_demo.ipynb)
 * [Source dataset pre-processing](pre-process-datasets.ipynb)
 * [Reference](reference/)
 * [Contributing](contributing.md)

docs/content/demo.ipynb

Lines changed: 88 additions & 69 deletions
@@ -13,11 +13,13 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"# Quick Start Training Demo\n",
+"# Training Demo\n",
 "\n",
 "This is a quick start demo to get training a SAE right away. All you need to do is choose a few\n",
-"hyperparameters (like the model to train on), and then set it off. By default it trains SAEs on all\n",
-"MLP layers from GPT2 small."
+"hyperparameters (like the model to train on), and then set it off.\n",
+"\n",
+"In this demo we'll train a sparse autoencoder on all MLP layer outputs in GPT-2 small (effectively\n",
+"training an SAE on each layer in parallel)."
 ]
 },
 {
@@ -68,6 +70,7 @@
 "\n",
 "from sparse_autoencoder import (\n",
 "    ActivationResamplerHyperparameters,\n",
+"    AutoencoderHyperparameters,\n",
 "    Hyperparameters,\n",
 "    LossHyperparameters,\n",
 "    Method,\n",
@@ -76,12 +79,10 @@
 "    PipelineHyperparameters,\n",
 "    SourceDataHyperparameters,\n",
 "    SourceModelHyperparameters,\n",
-"    sweep,\n",
 "    SweepConfig,\n",
+"    sweep,\n",
 ")\n",
 "\n",
-"\n",
-"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
 "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"demo.ipynb\""
 ]
 },
@@ -97,76 +98,73 @@
 "metadata": {},
 "source": [
 "Customize any hyperparameters you want below (by default we're sweeping over l1 coefficient and\n",
-"learning rate):"
+"learning rate).\n",
+"\n",
+"Note we are using the RANDOM sweep approach (try random combinations of hyperparameters), which\n",
+"works surprisingly well but will need to be stopped at some point (as otherwise it will continue\n",
+"forever). If you want to run pre-defined runs consider using `Parameter(values=[0.01, 0.05...])` for\n",
+"example rather than `Parameter(max=0.03, min=0.008)` for each parameter you are sweeping over. You\n",
+"can then set the strategy to `Method.GRID`."
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 3,
 "metadata": {},
-"outputs": [
-{
-"data": {
-"text/plain": [
-"SweepConfig(parameters=Hyperparameters(\n",
-" source_data=SourceDataHyperparameters(dataset_path=Parameter(value=alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2), context_size=Parameter(value=128), dataset_column_name=Parameter(value=input_ids), dataset_dir=None, dataset_files=None, pre_download=Parameter(value=False), pre_tokenized=Parameter(value=True), tokenizer_name=None)\n",
-" source_model=SourceModelHyperparameters(name=Parameter(value=gpt2-small), cache_names=Parameter(value=['blocks.0.hook_mlp_out', 'blocks.1.hook_mlp_out', 'blocks.2.hook_mlp_out', 'blocks.3.hook_mlp_out', 'blocks.4.hook_mlp_out', 'blocks.5.hook_mlp_out', 'blocks.6.hook_mlp_out', 'blocks.7.hook_mlp_out', 'blocks.8.hook_mlp_out', 'blocks.9.hook_mlp_out', 'blocks.10.hook_mlp_out', 'blocks.11.hook_mlp_out']), hook_dimension=Parameter(value=768), dtype=Parameter(value=float32))\n",
-" activation_resampler=ActivationResamplerHyperparameters(resample_interval=Parameter(value=200000000), max_n_resamples=Parameter(value=4), n_activations_activity_collate=Parameter(value=100000000), resample_dataset_size=Parameter(value=200000), threshold_is_dead_portion_fires=Parameter(value=1e-06))\n",
-" autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(value=2))\n",
-" loss=LossHyperparameters(l1_coefficient=Parameter(max=0.01, min=0.004))\n",
-" optimizer=OptimizerHyperparameters(lr=Parameter(max=0.001, min=1e-05), adam_beta_1=Parameter(value=0.9), adam_beta_2=Parameter(value=0.99), adam_weight_decay=Parameter(value=0.0), amsgrad=Parameter(value=False), fused=Parameter(value=False))\n",
-" pipeline=PipelineHyperparameters(log_frequency=Parameter(value=100), source_data_batch_size=Parameter(value=16), train_batch_size=Parameter(value=1024), max_store_size=Parameter(value=300000), max_activations=Parameter(value=1000000000), checkpoint_frequency=Parameter(value=100000000), validation_frequency=Parameter(value=100000000), validation_n_activations=Parameter(value=8192))\n",
-" random_seed=Parameter(value=49)\n",
-"), method=<Method.RANDOM: 'random'>, metric=Metric(name=train/loss/total_loss, goal=minimize), command=None, controller=None, description=None, earlyterminate=None, entity=None, imageuri=None, job=None, kind=None, name=None, program=None, project=None)"
-]
-},
-"execution_count": 5,
-"metadata": {},
-"output_type": "execute_result"
-}
-],
+"outputs": [],
 "source": [
-"n_layers_gpt2_small = 12\n",
+"def train_gpt_small_mlp_layers(\n",
+"    expansion_factor: int = 4,\n",
+"    n_layers: int = 12,\n",
+") -> None:\n",
+"    \"\"\"Run a new sweep experiment on GPT 2 Small's MLP layers.\n",
 "\n",
-"sweep_config = SweepConfig(\n",
-"    parameters=Hyperparameters(\n",
-"        activation_resampler=ActivationResamplerHyperparameters(\n",
-"            resample_interval=Parameter(200_000_000),\n",
-"            n_activations_activity_collate=Parameter(100_000_000),\n",
-"            threshold_is_dead_portion_fires=Parameter(1e-6),\n",
-"            max_n_resamples=Parameter(4),\n",
-"            resample_dataset_size=Parameter(200_000),\n",
-"        ),\n",
-"        loss=LossHyperparameters(\n",
-"            l1_coefficient=Parameter(max=1e-2, min=4e-3),\n",
-"        ),\n",
-"        optimizer=OptimizerHyperparameters(\n",
-"            lr=Parameter(max=1e-3, min=1e-5),\n",
-"        ),\n",
-"        source_model=SourceModelHyperparameters(\n",
-"            name=Parameter(\"gpt2-small\"),\n",
-"            # Train in parallel on all MLP layers\n",
-"            cache_names=Parameter(\n",
-"                [f\"blocks.{layer}.hook_mlp_out\" for layer in range(n_layers_gpt2_small)]\n",
+"    Args:\n",
+"        expansion_factor: Expansion factor for the autoencoder.\n",
+"        n_layers: Number of layers to train on. Max is 12.\n",
+"\n",
+"    \"\"\"\n",
+"    sweep_config = SweepConfig(\n",
+"        parameters=Hyperparameters(\n",
+"            loss=LossHyperparameters(\n",
+"                l1_coefficient=Parameter(max=0.03, min=0.008),\n",
+"            ),\n",
+"            optimizer=OptimizerHyperparameters(\n",
+"                lr=Parameter(max=0.001, min=0.00001),\n",
+"            ),\n",
+"            source_model=SourceModelHyperparameters(\n",
+"                name=Parameter(\"gpt2\"),\n",
+"                cache_names=Parameter(\n",
+"                    [f\"blocks.{layer}.hook_mlp_out\" for layer in range(n_layers)]\n",
+"                ),\n",
+"                hook_dimension=Parameter(768),\n",
+"            ),\n",
+"            source_data=SourceDataHyperparameters(\n",
+"                dataset_path=Parameter(\"alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2\"),\n",
+"                context_size=Parameter(256),\n",
+"                pre_tokenized=Parameter(value=True),\n",
+"                pre_download=Parameter(value=False),  # Default to streaming the dataset\n",
+"            ),\n",
+"            autoencoder=AutoencoderHyperparameters(\n",
+"                expansion_factor=Parameter(value=expansion_factor)\n",
+"            ),\n",
+"            pipeline=PipelineHyperparameters(\n",
+"                max_activations=Parameter(1_000_000_000),\n",
+"                checkpoint_frequency=Parameter(100_000_000),\n",
+"                validation_frequency=Parameter(100_000_000),\n",
+"                max_store_size=Parameter(1_000_000),\n",
+"            ),\n",
+"            activation_resampler=ActivationResamplerHyperparameters(\n",
+"                resample_interval=Parameter(200_000_000),\n",
+"                n_activations_activity_collate=Parameter(100_000_000),\n",
+"                threshold_is_dead_portion_fires=Parameter(1e-6),\n",
+"                max_n_resamples=Parameter(4),\n",
 "            ),\n",
-"            hook_dimension=Parameter(768),\n",
-"        ),\n",
-"        source_data=SourceDataHyperparameters(\n",
-"            dataset_path=Parameter(\"alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2\"),\n",
-"            context_size=Parameter(128),\n",
-"            pre_tokenized=Parameter(value=True),\n",
-"        ),\n",
-"        pipeline=PipelineHyperparameters(\n",
-"            max_activations=Parameter(1_000_000_000),\n",
-"            checkpoint_frequency=Parameter(100_000_000),\n",
-"            validation_frequency=Parameter(100_000_000),\n",
-"            train_batch_size=Parameter(1024),\n",
-"            max_store_size=Parameter(300_000),\n",
 "        ),\n",
-"    ),\n",
-"    method=Method.RANDOM,\n",
-")\n",
-"sweep_config"
+"        method=Method.RANDOM,\n",
+"    )\n",
+"\n",
+"    sweep(sweep_config=sweep_config)"
 ]
 },
 {
@@ -176,13 +174,34 @@
 "### Run the sweep"
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"This will start a sweep with just one agent (the current machine). If you have multiple GPUs, it\n",
+"will use them automatically. Similarly it will work on Apple silicon devices by automatically using MPS."
+]
+},
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 4,
 "metadata": {},
 "outputs": [],
 "source": [
-"sweep(sweep_config=sweep_config)"
+"train_gpt_small_mlp_layers()"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Want to speed things up? You can trivially add extra machines to the sweep, each of which will peel\n",
+"off some runs from the sweep agent (stored on Wandb). To do this, on another machine simply run:\n",
+"\n",
+"```bash\n",
+"pip install sparse_autoencoder\n",
+"join-sae-sweep --id=SWEEP_ID_SHOWN_ON_WANDB\n",
+"```"
 ]
 },
 ],
