Skip to content

Reduce model complexity #171

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 43 additions & 18 deletions docs/content/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@
"# Quick Start Training Demo\n",
"\n",
"This is a quick start demo to get training an SAE right away. All you need to do is choose a few\n",
"hyperparameters (like the model to train on), and then set it off.\n",
"By default it replicates Neel Nanda's\n",
"[comment on the Anthropic dictionary learning\n",
"paper](https://transformer-circuits.pub/2023/monosemantic-features/index.html#comment-nanda)."
"hyperparameters (like the model to train on), and then set it off. By default it trains SAEs on all\n",
"MLP layers from GPT2 small."
]
},
{
Expand Down Expand Up @@ -75,6 +73,7 @@
" Method,\n",
" OptimizerHyperparameters,\n",
" Parameter,\n",
" PipelineHyperparameters,\n",
" SourceDataHyperparameters,\n",
" SourceModelHyperparameters,\n",
" sweep,\n",
Expand Down Expand Up @@ -103,26 +102,40 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "SourceModelHyperparameters.__init__() got an unexpected keyword argument 'hook_layer'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[3], line 12\u001b[0m\n\u001b[1;32m 1\u001b[0m sweep_config \u001b[38;5;241m=\u001b[39m SweepConfig(\n\u001b[1;32m 2\u001b[0m parameters\u001b[38;5;241m=\u001b[39mHyperparameters(\n\u001b[1;32m 3\u001b[0m activation_resampler\u001b[38;5;241m=\u001b[39mActivationResamplerHyperparameters(\n\u001b[1;32m 4\u001b[0m threshold_is_dead_portion_fires\u001b[38;5;241m=\u001b[39mParameter(\u001b[38;5;241m1e-6\u001b[39m),\n\u001b[1;32m 5\u001b[0m ),\n\u001b[1;32m 6\u001b[0m loss\u001b[38;5;241m=\u001b[39mLossHyperparameters(\n\u001b[1;32m 7\u001b[0m l1_coefficient\u001b[38;5;241m=\u001b[39mParameter(\u001b[38;5;28mmax\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-2\u001b[39m, \u001b[38;5;28mmin\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4e-3\u001b[39m),\n\u001b[1;32m 8\u001b[0m ),\n\u001b[1;32m 9\u001b[0m optimizer\u001b[38;5;241m=\u001b[39mOptimizerHyperparameters(\n\u001b[1;32m 10\u001b[0m lr\u001b[38;5;241m=\u001b[39mParameter(\u001b[38;5;28mmax\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-3\u001b[39m, \u001b[38;5;28mmin\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-5\u001b[39m),\n\u001b[1;32m 11\u001b[0m ),\n\u001b[0;32m---> 12\u001b[0m source_model\u001b[38;5;241m=\u001b[39m\u001b[43mSourceModelHyperparameters\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 13\u001b[0m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mParameter\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mgelu-2l\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 14\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_names\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mParameter\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmlp_out\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 15\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mhook_layer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mParameter\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 16\u001b[0m \u001b[43m \u001b[49m\u001b[43mhook_dimension\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mParameter\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m512\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 17\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 18\u001b[0m source_data\u001b[38;5;241m=\u001b[39mSourceDataHyperparameters(\n\u001b[1;32m 19\u001b[0m dataset_path\u001b[38;5;241m=\u001b[39mParameter(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNeelNanda/c4-code-tokenized-2b\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 20\u001b[0m ),\n\u001b[1;32m 21\u001b[0m ),\n\u001b[1;32m 22\u001b[0m method\u001b[38;5;241m=\u001b[39mMethod\u001b[38;5;241m.\u001b[39mRANDOM,\n\u001b[1;32m 23\u001b[0m )\n\u001b[1;32m 24\u001b[0m sweep_config\n",
"\u001b[0;31mTypeError\u001b[0m: SourceModelHyperparameters.__init__() got an unexpected keyword argument 'hook_layer'"
]
"data": {
"text/plain": [
"SweepConfig(parameters=Hyperparameters(\n",
" source_data=SourceDataHyperparameters(dataset_path=Parameter(value=alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2), context_size=Parameter(value=128), dataset_column_name=Parameter(value=input_ids), dataset_dir=None, dataset_files=None, pre_download=Parameter(value=False), pre_tokenized=Parameter(value=True), tokenizer_name=None)\n",
" source_model=SourceModelHyperparameters(name=Parameter(value=gpt2-small), cache_names=Parameter(value=['blocks.0.hook_mlp_out', 'blocks.1.hook_mlp_out', 'blocks.2.hook_mlp_out', 'blocks.3.hook_mlp_out', 'blocks.4.hook_mlp_out', 'blocks.5.hook_mlp_out', 'blocks.6.hook_mlp_out', 'blocks.7.hook_mlp_out', 'blocks.8.hook_mlp_out', 'blocks.9.hook_mlp_out', 'blocks.10.hook_mlp_out', 'blocks.11.hook_mlp_out']), hook_dimension=Parameter(value=768), dtype=Parameter(value=float32))\n",
" activation_resampler=ActivationResamplerHyperparameters(resample_interval=Parameter(value=200000000), max_n_resamples=Parameter(value=4), n_activations_activity_collate=Parameter(value=100000000), resample_dataset_size=Parameter(value=200000), threshold_is_dead_portion_fires=Parameter(value=1e-06))\n",
" autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(value=2))\n",
" loss=LossHyperparameters(l1_coefficient=Parameter(max=0.01, min=0.004))\n",
" optimizer=OptimizerHyperparameters(lr=Parameter(max=0.001, min=1e-05), adam_beta_1=Parameter(value=0.9), adam_beta_2=Parameter(value=0.99), adam_weight_decay=Parameter(value=0.0), amsgrad=Parameter(value=False), fused=Parameter(value=False))\n",
" pipeline=PipelineHyperparameters(log_frequency=Parameter(value=100), source_data_batch_size=Parameter(value=16), train_batch_size=Parameter(value=1024), max_store_size=Parameter(value=300000), max_activations=Parameter(value=1000000000), checkpoint_frequency=Parameter(value=100000000), validation_frequency=Parameter(value=100000000), validation_n_activations=Parameter(value=8192))\n",
" random_seed=Parameter(value=49)\n",
"), method=<Method.RANDOM: 'random'>, metric=Metric(name=train/loss/total_loss, goal=minimize), command=None, controller=None, description=None, earlyterminate=None, entity=None, imageuri=None, job=None, kind=None, name=None, program=None, project=None)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n_layers_gpt2_small = 12\n",
"\n",
"sweep_config = SweepConfig(\n",
" parameters=Hyperparameters(\n",
" activation_resampler=ActivationResamplerHyperparameters(\n",
" resample_interval=Parameter(200_000_000),\n",
" n_activations_activity_collate=Parameter(100_000_000),\n",
" threshold_is_dead_portion_fires=Parameter(1e-6),\n",
" max_n_resamples=Parameter(4),\n",
" resample_dataset_size=Parameter(200_000),\n",
" ),\n",
" loss=LossHyperparameters(\n",
" l1_coefficient=Parameter(max=1e-2, min=4e-3),\n",
Expand All @@ -131,12 +144,24 @@
" lr=Parameter(max=1e-3, min=1e-5),\n",
" ),\n",
" source_model=SourceModelHyperparameters(\n",
" name=Parameter(\"gelu-2l\"),\n",
" cache_names=Parameter([\"blocks.0.hook_mlp_out\", \"blocks.1.hook_mlp_out\"]),\n",
" hook_dimension=Parameter(512),\n",
" name=Parameter(\"gpt2-small\"),\n",
" # Train in parallel on all MLP layers\n",
" cache_names=Parameter(\n",
" [f\"blocks.{layer}.hook_mlp_out\" for layer in range(n_layers_gpt2_small)]\n",
" ),\n",
" hook_dimension=Parameter(768),\n",
" ),\n",
" source_data=SourceDataHyperparameters(\n",
" dataset_path=Parameter(\"NeelNanda/c4-code-tokenized-2b\"),\n",
" dataset_path=Parameter(\"alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2\"),\n",
" context_size=Parameter(128),\n",
" pre_tokenized=Parameter(value=True),\n",
" ),\n",
" pipeline=PipelineHyperparameters(\n",
" max_activations=Parameter(1_000_000_000),\n",
" checkpoint_frequency=Parameter(100_000_000),\n",
" validation_frequency=Parameter(100_000_000),\n",
" train_batch_size=Parameter(1024),\n",
" max_store_size=Parameter(300_000),\n",
" ),\n",
" ),\n",
" method=Method.RANDOM,\n",
Expand Down
146 changes: 21 additions & 125 deletions docs/content/flexible_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -83,17 +83,9 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: mps\n"
]
}
],
"outputs": [],
"source": [
"import os\n",
"from pathlib import Path\n",
Expand Down Expand Up @@ -140,7 +132,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -197,27 +189,9 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded pretrained model gelu-2l into HookedTransformer\n"
]
},
{
"data": {
"text/plain": [
"'Source: gelu-2l, Hook: blocks.0.hook_mlp_out, Features: 512'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# Source model setup with TransformerLens\n",
"src_model = HookedTransformer.from_pretrained(\n",
Expand Down Expand Up @@ -255,28 +229,9 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SparseAutoencoder(\n",
" (_pre_encoder_bias): TiedBias(position=pre_encoder)\n",
" (_encoder): LinearEncoder(\n",
" in_features=512, out_features=2048\n",
" (activation_function): ReLU()\n",
" )\n",
" (_decoder): UnitNormDecoder(in_features=2048, out_features=512)\n",
" (_post_decoder_bias): TiedBias(position=post_decoder)\n",
")"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"expansion_factor = hyperparameters[\"expansion_factor\"]\n",
"autoencoder = SparseAutoencoder(\n",
Expand All @@ -297,23 +252,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LossReducer(\n",
" (0): LearnedActivationsL1Loss(l1_coefficient=0.0003)\n",
" (1): L2ReconstructionLoss()\n",
")"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# We use a loss reducer, which simply adds up the losses from the underlying loss functions.\n",
"loss = LossReducer(\n",
Expand All @@ -327,32 +268,9 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AdamWithReset (\n",
"Parameter Group 0\n",
" amsgrad: False\n",
" betas: (0.9, 0.999)\n",
" capturable: False\n",
" differentiable: False\n",
" eps: 1e-08\n",
" foreach: None\n",
" fused: None\n",
" lr: 0.0001\n",
" maximize: False\n",
" weight_decay: 0.0\n",
")"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"optimizer = AdamWithReset(\n",
" params=autoencoder.parameters(),\n",
Expand All @@ -361,6 +279,7 @@
" betas=(float(hyperparameters[\"adam_beta_1\"]), float(hyperparameters[\"adam_beta_2\"])),\n",
" eps=float(hyperparameters[\"adam_epsilon\"]),\n",
" weight_decay=float(hyperparameters[\"adam_weight_decay\"]),\n",
" has_components_dim=True,\n",
")\n",
"optimizer"
]
Expand All @@ -374,7 +293,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -403,27 +322,13 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2fe4955deca9463dbed606c9452d518e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"source_data = PreTokenizedDataset(\n",
" dataset_path=\"NeelNanda/c4-code-tokenized-2b\", context_size=int(hyperparameters[\"context_size\"])\n",
" dataset_path=\"alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2\",\n",
" context_size=int(hyperparameters[\"context_size\"]),\n",
")"
]
},
Expand All @@ -447,7 +352,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -471,14 +376,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"pipeline = Pipeline(\n",
" activation_resampler=activation_resampler,\n",
" autoencoder=autoencoder,\n",
" cache_name=str(hyperparameters[\"source_model_hook_point\"]),\n",
" cache_names=[str(hyperparameters[\"source_model_hook_point\"])],\n",
" checkpoint_directory=checkpoint_path,\n",
" layer=int(hyperparameters[\"source_model_hook_point_layer\"]),\n",
" loss=loss,\n",
Expand All @@ -496,15 +401,6 @@
" validate_frequency=int(hyperparameters[\"validation_frequency\"]),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"wandb.finish()"
]
}
],
"metadata": {
Expand Down
6 changes: 1 addition & 5 deletions docs/content/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,4 @@ The library is designed to be modular. By default it takes the approach from [To
Monosemanticity: Decomposing Language Models With Dictionary Learning
](https://transformer-circuits.pub/2023/monosemantic-features/index.html), so you can pip install
the library and get started quickly. Then when you need to customise something, you can just extend
the abstract class for that component (e.g. you can extend
[`AbstractEncoder`][sparse_autoencoder.autoencoder.components.abstract_encoder] if you want to
customise the encoder layer, and then easily drop it in the standard
[`SparseAutoencoder`][sparse_autoencoder.autoencoder.model] model to keep everything else as is).
Every component is fully documented, so it's nice and easy to do this.
the abstract class for that component (every component is documented so that it's easy to do this).
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def test_updates_dead_neuron_parameters(
# Check the updated ones have changed
for component_idx, neuron_idx in dead_neurons:
# Decoder
decoder_weights = current_parameters["_decoder._weight"]
decoder_weights = current_parameters["decoder._weight"]
current_dead_neuron_weights = decoder_weights[component_idx, neuron_idx]
updated_dead_decoder_weights = parameter_updates[
component_idx
Expand All @@ -353,7 +353,7 @@ def test_updates_dead_neuron_parameters(
), "Dead decoder weights should have changed."

# Encoder
current_dead_encoder_weights = current_parameters["_encoder._weight"][
current_dead_encoder_weights = current_parameters["encoder._weight"][
component_idx, neuron_idx
]
updated_dead_encoder_weights = parameter_updates[
Expand All @@ -363,7 +363,7 @@ def test_updates_dead_neuron_parameters(
current_dead_encoder_weights, updated_dead_encoder_weights
), "Dead encoder weights should have changed."

current_dead_encoder_bias = current_parameters["_encoder._bias"][
current_dead_encoder_bias = current_parameters["encoder._bias"][
component_idx, neuron_idx
]
updated_dead_encoder_bias = parameter_updates[component_idx].dead_encoder_bias_updates
Expand Down
Loading