|
13 | 13 | "cell_type": "markdown",
|
14 | 14 | "metadata": {},
|
15 | 15 | "source": [
|
16 | | - "# Quick Start Training Demo\n", |
| 16 | + "# Training Demo\n", |
17 | 17 | "\n",
|
18 | 18 | "This is a quick start demo to get training a SAE right away. All you need to do is choose a few\n",
|
19 | | - "hyperparameters (like the model to train on), and then set it off. By default it trains SAEs on all\n", |
20 | | - "MLP layers from GPT2 small." |
| 19 | + "hyperparameters (like the model to train on), and then set it off.\n", |
| 20 | + "\n", |
| 21 | + "In this demo we'll train a sparse autoencoder on all MLP layer outputs in GPT-2 small (effectively\n", |
| 22 | + "training an SAE on each layer in parallel)." |
21 | 23 | ]
|
22 | 24 | },
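Editor's note: for readers new to sparse autoencoders, the sketch below shows the objective this notebook trains: reconstruct an activation vector through an over-complete hidden layer while penalising the hidden activations with an L1 term. It is an illustrative toy (the `TinySAE` class and `sae_loss` helper are invented here for exposition), not the library's actual implementation.

```python
# Illustrative toy SAE, NOT the sparse_autoencoder library's implementation.
# Assumes a d_model=768 activation stream (GPT-2 small's MLP output width).
import torch
from torch import nn


class TinySAE(nn.Module):
    """Minimal sparse autoencoder: encode, ReLU, decode."""

    def __init__(self, d_model: int = 768, expansion_factor: int = 4) -> None:
        super().__init__()
        self.encoder = nn.Linear(d_model, d_model * expansion_factor)
        self.decoder = nn.Linear(d_model * expansion_factor, d_model)

    def forward(self, activations: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        features = torch.relu(self.encoder(activations))
        return features, self.decoder(features)


def sae_loss(
    activations: torch.Tensor,
    features: torch.Tensor,
    reconstruction: torch.Tensor,
    l1_coefficient: float = 0.01,
) -> torch.Tensor:
    """Reconstruction (MSE) loss plus an L1 sparsity penalty on the features."""
    mse = nn.functional.mse_loss(reconstruction, activations)
    l1 = l1_coefficient * features.abs().sum(dim=-1).mean()
    return mse + l1
```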
|
23 | 25 | {
|
|
68 | 70 | "\n",
|
69 | 71 | "from sparse_autoencoder import (\n",
|
70 | 72 | " ActivationResamplerHyperparameters,\n",
|
| 73 | + " AutoencoderHyperparameters,\n", |
71 | 74 | " Hyperparameters,\n",
|
72 | 75 | " LossHyperparameters,\n",
|
73 | 76 | " Method,\n",
|
|
76 | 79 | " PipelineHyperparameters,\n",
|
77 | 80 | " SourceDataHyperparameters,\n",
|
78 | 81 | " SourceModelHyperparameters,\n",
|
79 | | - "    sweep,\n", |
80 | 82 | " SweepConfig,\n",
|
| 83 | + " sweep,\n", |
81 | 84 | ")\n",
|
82 | 85 | "\n",
|
83 | | - "\n", |
84 | | - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", |
85 | 86 | "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"demo.ipynb\""
|
86 | 87 | ]
|
87 | 88 | },
|
|
97 | 98 | "metadata": {},
|
98 | 99 | "source": [
|
99 | 100 | "Customize any hyperparameters you want below (by default we're sweeping over l1 coefficient and\n",
|
100 | | - "learning rate):" |
| 101 | + "learning rate).\n", |
| 102 | + "\n", |
| 103 | + "Note we are using the RANDOM sweep approach (try random combinations of hyperparameters), which\n", |
| 104 | + "works surprisingly well but will need to be stopped at some point (as otherwise it will continue\n", |
| 105 | + "forever). If you want to run pre-defined runs consider using `Parameter(values=[0.01, 0.05...])` for\n", |
| 106 | + "example rather than `Parameter(max=0.03, min=0.008)` for each parameter you are sweeping over. You\n", |
| 107 | + "can then set the strategy to `Method.GRID`." |
101 | 108 | ]
|
102 | 109 | },
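Editor's note: as a concrete illustration of the grid alternative described above, here is a hedged sketch. The `Parameter(values=[...])` and `Method.GRID` usage comes from the text above; the specific values are placeholders, not recommendations.

```python
# Sketch: discrete values + Method.GRID instead of sampling from random ranges.
# The specific values below are placeholders for illustration only.
from sparse_autoencoder import (
    LossHyperparameters,
    Method,
    OptimizerHyperparameters,
    Parameter,
)

grid_loss = LossHyperparameters(
    # Enumerate exact values rather than sampling from Parameter(max=..., min=...).
    l1_coefficient=Parameter(values=[0.008, 0.01, 0.03]),
)
grid_optimizer = OptimizerHyperparameters(
    lr=Parameter(values=[1e-5, 1e-4, 1e-3]),
)
# Plug these into Hyperparameters(...) as in the function below, and set
# method=Method.GRID on the SweepConfig so every combination is run exactly once.
```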
|
103 | 110 | {
|
104 | 111 | "cell_type": "code",
|
105 | | - "execution_count": 5, |
| 112 | + "execution_count": 3, |
106 | 113 | "metadata": {},
|
107 | | - "outputs": [ |
108 | | - { |
109 | | - "data": { |
110 | | - "text/plain": [ |
111 | | - "SweepConfig(parameters=Hyperparameters(\n", |
112 | | - "    source_data=SourceDataHyperparameters(dataset_path=Parameter(value=alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2), context_size=Parameter(value=128), dataset_column_name=Parameter(value=input_ids), dataset_dir=None, dataset_files=None, pre_download=Parameter(value=False), pre_tokenized=Parameter(value=True), tokenizer_name=None)\n", |
113 | | - "    source_model=SourceModelHyperparameters(name=Parameter(value=gpt2-small), cache_names=Parameter(value=['blocks.0.hook_mlp_out', 'blocks.1.hook_mlp_out', 'blocks.2.hook_mlp_out', 'blocks.3.hook_mlp_out', 'blocks.4.hook_mlp_out', 'blocks.5.hook_mlp_out', 'blocks.6.hook_mlp_out', 'blocks.7.hook_mlp_out', 'blocks.8.hook_mlp_out', 'blocks.9.hook_mlp_out', 'blocks.10.hook_mlp_out', 'blocks.11.hook_mlp_out']), hook_dimension=Parameter(value=768), dtype=Parameter(value=float32))\n", |
114 | | - "    activation_resampler=ActivationResamplerHyperparameters(resample_interval=Parameter(value=200000000), max_n_resamples=Parameter(value=4), n_activations_activity_collate=Parameter(value=100000000), resample_dataset_size=Parameter(value=200000), threshold_is_dead_portion_fires=Parameter(value=1e-06))\n", |
115 | | - "    autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(value=2))\n", |
116 | | - "    loss=LossHyperparameters(l1_coefficient=Parameter(max=0.01, min=0.004))\n", |
117 | | - "    optimizer=OptimizerHyperparameters(lr=Parameter(max=0.001, min=1e-05), adam_beta_1=Parameter(value=0.9), adam_beta_2=Parameter(value=0.99), adam_weight_decay=Parameter(value=0.0), amsgrad=Parameter(value=False), fused=Parameter(value=False))\n", |
118 | | - "    pipeline=PipelineHyperparameters(log_frequency=Parameter(value=100), source_data_batch_size=Parameter(value=16), train_batch_size=Parameter(value=1024), max_store_size=Parameter(value=300000), max_activations=Parameter(value=1000000000), checkpoint_frequency=Parameter(value=100000000), validation_frequency=Parameter(value=100000000), validation_n_activations=Parameter(value=8192))\n", |
119 | | - "    random_seed=Parameter(value=49)\n", |
120 | | - "), method=<Method.RANDOM: 'random'>, metric=Metric(name=train/loss/total_loss, goal=minimize), command=None, controller=None, description=None, earlyterminate=None, entity=None, imageuri=None, job=None, kind=None, name=None, program=None, project=None)" |
121 | | - ] |
122 | | - }, |
123 | | - "execution_count": 5, |
124 | | - "metadata": {}, |
125 | | - "output_type": "execute_result" |
126 | | - } |
127 | | - ], |
| 114 | + "outputs": [], |
128 | 115 | "source": [
|
129 | | - "n_layers_gpt2_small = 12\n", |
| 116 | + "def train_gpt_small_mlp_layers(\n", |
| 117 | + " expansion_factor: int = 4,\n", |
| 118 | + " n_layers: int = 12,\n", |
| 119 | + ") -> None:\n", |
| 120 | + " \"\"\"Run a new sweep experiment on GPT 2 Small's MLP layers.\n", |
130 | 121 | "\n",
|
131 | | - "sweep_config = SweepConfig(\n", |
132 | | - "    parameters=Hyperparameters(\n", |
133 | | - "        activation_resampler=ActivationResamplerHyperparameters(\n", |
134 | | - "            resample_interval=Parameter(200_000_000),\n", |
135 | | - "            n_activations_activity_collate=Parameter(100_000_000),\n", |
136 | | - "            threshold_is_dead_portion_fires=Parameter(1e-6),\n", |
137 | | - "            max_n_resamples=Parameter(4),\n", |
138 | | - "            resample_dataset_size=Parameter(200_000),\n", |
139 | | - "        ),\n", |
140 | | - "        loss=LossHyperparameters(\n", |
141 | | - "            l1_coefficient=Parameter(max=1e-2, min=4e-3),\n", |
142 | | - "        ),\n", |
143 | | - "        optimizer=OptimizerHyperparameters(\n", |
144 | | - "            lr=Parameter(max=1e-3, min=1e-5),\n", |
145 | | - "        ),\n", |
146 | | - "        source_model=SourceModelHyperparameters(\n", |
147 | | - "            name=Parameter(\"gpt2-small\"),\n", |
148 | | - "            # Train in parallel on all MLP layers\n", |
149 | | - "            cache_names=Parameter(\n", |
150 | | - "                [f\"blocks.{layer}.hook_mlp_out\" for layer in range(n_layers_gpt2_small)]\n", |
| 122 | + " Args:\n", |
| 123 | + " expansion_factor: Expansion factor for the autoencoder.\n", |
| 124 | + " n_layers: Number of layers to train on. Max is 12.\n", |
| 125 | + "\n", |
| 126 | + " \"\"\"\n", |
| 127 | + " sweep_config = SweepConfig(\n", |
| 128 | + " parameters=Hyperparameters(\n", |
| 129 | + " loss=LossHyperparameters(\n", |
| 130 | + " l1_coefficient=Parameter(max=0.03, min=0.008),\n", |
| 131 | + " ),\n", |
| 132 | + " optimizer=OptimizerHyperparameters(\n", |
| 133 | + " lr=Parameter(max=0.001, min=0.00001),\n", |
| 134 | + " ),\n", |
| 135 | + " source_model=SourceModelHyperparameters(\n", |
| 136 | + " name=Parameter(\"gpt2\"),\n", |
| 137 | + " cache_names=Parameter(\n", |
| 138 | + " [f\"blocks.{layer}.hook_mlp_out\" for layer in range(n_layers)]\n", |
| 139 | + " ),\n", |
| 140 | + " hook_dimension=Parameter(768),\n", |
| 141 | + " ),\n", |
| 142 | + " source_data=SourceDataHyperparameters(\n", |
| 143 | + " dataset_path=Parameter(\"alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2\"),\n", |
| 144 | + " context_size=Parameter(256),\n", |
| 145 | + " pre_tokenized=Parameter(value=True),\n", |
| 146 | + " pre_download=Parameter(value=False), # Default to streaming the dataset\n", |
| 147 | + " ),\n", |
| 148 | + " autoencoder=AutoencoderHyperparameters(\n", |
| 149 | + " expansion_factor=Parameter(value=expansion_factor)\n", |
| 150 | + " ),\n", |
| 151 | + " pipeline=PipelineHyperparameters(\n", |
| 152 | + " max_activations=Parameter(1_000_000_000),\n", |
| 153 | + " checkpoint_frequency=Parameter(100_000_000),\n", |
| 154 | + " validation_frequency=Parameter(100_000_000),\n", |
| 155 | + " max_store_size=Parameter(1_000_000),\n", |
| 156 | + " ),\n", |
| 157 | + " activation_resampler=ActivationResamplerHyperparameters(\n", |
| 158 | + " resample_interval=Parameter(200_000_000),\n", |
| 159 | + " n_activations_activity_collate=Parameter(100_000_000),\n", |
| 160 | + " threshold_is_dead_portion_fires=Parameter(1e-6),\n", |
| 161 | + " max_n_resamples=Parameter(4),\n", |
151 | 162 | " ),\n",
|
152 | | - "            hook_dimension=Parameter(768),\n", |
153 | | - "        ),\n", |
154 | | - "        source_data=SourceDataHyperparameters(\n", |
155 | | - "            dataset_path=Parameter(\"alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2\"),\n", |
156 | | - "            context_size=Parameter(128),\n", |
157 | | - "            pre_tokenized=Parameter(value=True),\n", |
158 | | - "        ),\n", |
159 | | - "        pipeline=PipelineHyperparameters(\n", |
160 | | - "            max_activations=Parameter(1_000_000_000),\n", |
161 | | - "            checkpoint_frequency=Parameter(100_000_000),\n", |
162 | | - "            validation_frequency=Parameter(100_000_000),\n", |
163 | | - "            train_batch_size=Parameter(1024),\n", |
164 | | - "            max_store_size=Parameter(300_000),\n", |
165 | 163 | " ),\n",
|
166 | | - "    ),\n", |
167 | | - "    method=Method.RANDOM,\n", |
168 | | - ")\n", |
169 | | - "sweep_config" |
| 164 | + " method=Method.RANDOM,\n", |
| 165 | + " )\n", |
| 166 | + "\n", |
| 167 | + " sweep(sweep_config=sweep_config)" |
170 | 168 | ]
|
171 | 169 | },
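Editor's note: before launching a full sweep, you could sanity-check the pipeline by calling the helper defined above with smaller settings. The values below are purely illustrative; both arguments are part of the function's signature shown above.

```python
# Illustrative smoke test: a smaller SAE on a single MLP layer.
train_gpt_small_mlp_layers(expansion_factor=2, n_layers=1)
```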
|
172 | 170 | {
|
|
176 | 174 | "### Run the sweep"
|
177 | 175 | ]
|
178 | 176 | },
|
| 177 | + { |
| 178 | + "cell_type": "markdown", |
| 179 | + "metadata": {}, |
| 180 | + "source": [ |
| 181 | + "This will start a sweep with just one agent (the current machine). If you have multiple GPUs, it\n", |
| 182 | + "will use them automatically. Similarly it will work on Apple silicon devices by automatically using MPS." |
| 183 | + ] |
| 184 | + }, |
179 | 185 | {
|
180 | 186 | "cell_type": "code",
|
181 | | - "execution_count": null, |
| 187 | + "execution_count": 4, |
182 | 188 | "metadata": {},
|
183 | 189 | "outputs": [],
|
184 | 190 | "source": [
|
185 | | - "sweep(sweep_config=sweep_config)" |
| 191 | + "train_gpt_small_mlp_layers()" |
| 192 | + ] |
| 193 | + }, |
| 194 | + { |
| 195 | + "cell_type": "markdown", |
| 196 | + "metadata": {}, |
| 197 | + "source": [ |
| 198 | + "Want to speed things up? You can trivially add extra machines to the sweep, each of which will peel\n", |
| 199 | + "of some runs from the sweep agent (stored on Wandb). To do this, on another machine simply run:\n", |
| 200 | + "\n", |
| 201 | + "```bash\n", |
| 202 | + "pip install sparse_autoencoder\n", |
| 203 | + "join-sae-sweep --id=SWEEP_ID_SHOWN_ON_WANDB\n", |
| 204 | + "```" |
186 | 205 | ]
|
187 | 206 | }
|
188 | 207 | ],
|