This is the official repository of the COLM 2025 paper "Contextualize-then-Aggregate: Circuits for In-Context Learning in Gemma-2 2B".
We used this code for patching edges inside attention heads in Gemma2-2b, Gemma2-9b, Gemma2-27b, Phi2, Smollm, and Llama3.
The code here is specific to the version of the transformers library. We used transformers==4.44; the code might work in other versions as well, but we did not test it. We changed the forward passes of models from the transformers library, so if the signature of the `.forward()` function changes, chances are high that this code will need to be adapted to work in a new version.
If you want to use the same libraries we used, try installing them from environment.yml. You can copy our environment via `conda env create -f environment.yml` and use the environment called `gemma2`.
- `models/` - contains the implementations for models: each model is a combination of a base model (e.g., Llama3, Gemma2-2b) and what you want to do (e.g., ablate some edges, calculate accuracy on a task, etc.). `models/__init__.py` contains the dict of all available model names, `NAME_TO_MODEL`.
- `tasks/` - contains the implementations for tasks: each task file is responsible for constructing an in-context learning task out of a set of input-output pairs, including the choice of examples to include in a prompt, prompt formatting, construction of corrupted examples, and division by token types. `tasks/__init__.py` contains the dict of all available task names, `NAME_TO_TASK`.
- `datasets/` - contains the dataset files (input-output pairs) for all tasks. They differ between models, because we need to make sure that we only keep pairs that get tokenized correctly (separators tokenized as separate tokens from the meaningful words).
- `visualization/` - contains a notebook with the code for creating the visualizations from the paper.
- `run.py` - the entrypoint, with examples of configuration files. We used different configurations in the paper; here we provide simple versions for quick testing.
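To see which names are available, you can list the keys of those dicts (a minimal sketch; run it from the repository root with `PYTHONPATH` set as described below):

```python
# NAME_TO_TASK and NAME_TO_MODEL are the dicts initialized in the
# package __init__ files, as described above.
from tasks import NAME_TO_TASK
from models import NAME_TO_MODEL

print(sorted(NAME_TO_TASK))   # valid values for Config.task
print(sorted(NAME_TO_MODEL))  # valid values for Config.model_name
```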
- Clone the repository
- Modify the `run.py` file by adding a new instance to the `CONFIG` dict. A `Config` instance consists of the following fields:
  - `task: str`: REQUIRED, the name of the task to execute. Each task is basically a dataset; a task class is responsible for the data, prompt formatting, getting corrupted prompts for each prompt, and getting token types for each token. The dict of all available task names, `NAME_TO_TASK`, is initialized in `tasks/__init__.py`. The actual task names can be found in the `get_name()` function of each task file in `tasks/`.
  - `model_name: str`: REQUIRED, the name of the model to execute. Each model is a combination of a base model (e.g., Llama3, Gemma2-2b) and what you want to do (e.g., ablate some edges, calculate accuracy on a task, etc.). The dict of all available model names, `NAME_TO_MODEL`, is initialized in `models/__init__.py`.
  - `batch_size: int`: REQUIRED, self-explanatory.
  - `output_path: str`: REQUIRED, the path to a directory where the output files will be created. If the directory does not exist, it will be created; if it contains files with the same names as the ones we create, they will be overwritten.
  - `limit: Optional[int] = None`: the maximum number of prompts to run the model on. If None, run on the whole task as defined in the task file in the `tasks/` directory.
  - `prune_using_imp_scores: Optional[List[str]] = None`: required for pruning models; the paths to the files containing the importance scores according to which to prune the model. Since we prune models iteratively (the algorithm is: calculate importance scores, prune 10% of the heads/edges based on these scores, repeat), these files are created by previous runs of the model. Not needed for the first run of the model, when we don't prune anything before running it.
  - `prune_k: Optional[List[int]] = None`: required for pruning models; the number of heads/edges to leave after the corresponding pruning round. The length of this list must be the same as that of `prune_using_imp_scores`. In pruning round i we keep `prune_k[i]` heads/edges based on the importance scores from the `prune_using_imp_scores[i]` file, and then we run the model on the task after all pruning rounds.
  - `affect_whom: list = None`: required for some models, and means different things depending on the model; see the model descriptions below for more info. Usually means something like "prune the edges between the types listed in this list".
An example `Config` instance for each model is in the `run.py` file.
Suppose you named your config "my_awesome_config", so that `CONFIG["my_awesome_config"]` now contains the list of `Config` instances you want to execute.
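For illustration, a minimal sketch of such an entry (the task and model names below are placeholders; use keys from `NAME_TO_TASK` and `NAME_TO_MODEL`):

```python
# In run.py -- a sketch of a CONFIG entry; names and paths are placeholders.
CONFIG["my_awesome_config"] = [
    Config(
        task="copying",                      # a key of NAME_TO_TASK
        model_name="gemma2_calc_accuracy",   # a key of NAME_TO_MODEL
        batch_size=8,
        output_path="outputs/my_awesome_config",
        limit=100,                           # optional: cap the number of prompts
    )
]
```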
- Run the models with `python run.py --config "my_awesome_config"`. Do not forget to also do `export PYTHONPATH="path_to_the_directory_with_repository"`.
- The results will be in the `output_dir/logging_dir/model_logs.json` file. Look for the `accuracy` field for accuracy; there are many other logs as well (for their description, consult the corresponding model file in the `models/` directory).
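To pull the accuracy out programmatically, a minimal sketch (assuming `model_logs.json` is a single JSON object; the path below is a placeholder):

```python
import json

# Placeholder path: <output_path>/logging_dir/model_logs.json
with open("outputs/my_awesome_config/logging_dir/model_logs.json") as f:
    logs = json.load(f)
print("accuracy:", logs["accuracy"])
```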
There are already a few configs as examples; for a start, you could try them.
- `*_ablate_edges`, where `*` is a base model name (one of gemma2, phi2 or smollm):
  - Runs the model with all edges in the `affect_whom` list "removed", i.e., replaced with the activation on a corrupted example. The corrupted example is different for every prompt.
  - Optional arguments required:
    - `affect_whom`: a list of pairs (token_type_1, token_type_2) - ablate all edges from token_type_1 to token_type_2. See the sketch below.
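  For illustration, a minimal sketch of such a config (the task name, token-type names and paths are placeholders; the real token types are defined per task in `tasks/`):

  ```python
  # Sketch only: token-type names below are hypothetical.
  Config(
      task="copying",                     # placeholder task name
      model_name="gemma2_ablate_edges",
      batch_size=8,
      output_path="outputs/ablate_demo",
      affect_whom=[("token_type_1", "token_type_2")],
  )
  ```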
- `*_calc_accuracy`, where `*` is a base model name (one of gemma2, phi2 or smollm):
  - Just runs the full model and gets the full-model accuracy.
- `gemma2_ablate_edges_differently_for_different_tokens`:
  - Runs the model with all edges in the `affect_whom["curcuit"]` list "removed with corruption 0", i.e., replaced with the activation on the first corrupted example, and all edges in `affect_whom["ablate"]` "removed with corruption x inside a circuit", i.e., replaced with the activation on an example where the edges from "curcuit" are also ablated with corruption 0. The list of corrupted examples is different for every prompt.
  - Optional arguments required:
    - `affect_whom`: a dict that contains the fields "ablate" and "curcuit". `affect_whom["curcuit"]` is a dict that maps pairs (token_type_1, token_type_2) to 0, meaning that we will ablate all edges from token_type_1 to token_type_2 with corruption 0 from the list of corruptions. `affect_whom["ablate"]` is a dict that maps pairs (token_type_1, token_type_2) to an int > 0, meaning that we will ablate the edges from token_type_1 to token_type_2 with the corruption from the list of corruptions specific to this edge. The list of corruptions is specified in the task class; you need to choose a task that has at least as many different corruptions as the maximum number in `affect_whom["ablate"]` + 1. See the sketch below.
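  A sketch of the `affect_whom` dict for this model (token-type names are hypothetical; note that "curcuit" is the key spelling the code expects):

  ```python
  affect_whom = {
      # edges ablated with corruption 0 (the circuit):
      "curcuit": {("token_type_1", "token_type_2"): 0},
      # edges ablated with corruption 1, computed on an example where
      # the circuit edges are also ablated with corruption 0:
      "ablate": {("token_type_2", "token_type_3"): 1},
  }
  ```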
- `gemma2_prune_heads_in_circuits`:
  - Does iterative importance-scores-based pruning of the heads in a model, where the edges outside the "circuit" are already pruned. A head is a tuple (layer, head, token_type).
  - Optional arguments required:
    - `prune_using_imp_scores`: a list of paths to the importance scores according to which to prune; they should lead to the outputs of previous runs of this model on this task.
    - `prune_k`: a list of the numbers of heads to leave after each pruning round; it should be of the same length as the `prune_using_imp_scores` list. What will happen is that we will keep the `prune_k[0]` heads with the biggest importance scores according to the scores in the `prune_using_imp_scores[0]` file, then the `prune_k[1]` heads with the biggest importance scores according to the scores in the `prune_using_imp_scores[1]` file, and so on. `prune_k` should thus be sorted in non-increasing order.
    - `affect_whom`: same as in the `*_ablate_edges` models; the edges to prune before starting the iterative pruning procedure. See the sketch below.
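  A sketch of a two-round pruning config (the task name and the score-file paths are placeholders; the score files are produced by earlier runs of this model on this task):

  ```python
  Config(
      task="copying",                     # placeholder task name
      model_name="gemma2_prune_heads_in_circuits",
      batch_size=8,
      output_path="outputs/prune_round_2",
      # keep the 512 highest-scoring heads, then 256 (non-increasing):
      prune_using_imp_scores=[
          "outputs/prune_round_0/importance_scores.json",  # placeholder file names
          "outputs/prune_round_1/importance_scores.json",
      ],
      prune_k=[512, 256],
  )
  ```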
- `gemma2_prune_edges_in_circuits`:
  - Same as `gemma2_prune_heads_in_circuits`, but here we prune edges. It should be run after `gemma2_prune_heads_in_circuits`, and thus its `prune_using_imp_scores` and `prune_k` lists should first contain the values for head pruning, and then those for edge pruning. It can be run without head pruning as well.
We support five tasks from the paper: Copying, Country-Capital, Present-Past, Person-Sport and Capitalization. There are also files for the Ambiguous Present-Past task.
The dataset we use for each task depends on the tokenizer of the model, since we need to ensure that the items in the dataset are tokenized correctly: each separator must be tokenized separately from the "meaningful" tokens.
Corruption datasets are usually listed in the task file in the `tasks/` directory.
We also assign each token in a dataset its type; the token types for each task are defined in the task files in the `tasks/` directory.
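For example, a quick way to check that a separator tokenizes on its own (a sketch; the prompt format and separator here are illustrative, not the exact ones the tasks use):

```python
from transformers import AutoTokenizer

# Use the tokenizer of the model you run (placeholder checkpoint name).
tok = AutoTokenizer.from_pretrained("google/gemma-2-2b")
ids = tok("go:went", add_special_tokens=False).input_ids
print(tok.convert_ids_to_tokens(ids))  # ":" should come out as its own token
```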
All the tasks are In-Context-Learning-formatted tasks and live in the `tasks/` directory. To add a new task, do the following:
- If your task requires a new dataset of input-output pairs, add it to the `datasets` directory in huggingface format (via `dataset.save_to_disk('dataset_name')`). If it is just a new formatting / new corruption type for an existing dataset, this step is not needed.
- Create a new file in the `tasks` directory. This file should contain a class with your task name, derived from the parent `Task` class.
- Fill in all the data for your task in analogy with the other task files in the folder.
- Add a `get_name` method with the string name of your task (or set of tasks).
- Add your new class to the `NAME_TO_TASK` dictionary in `tasks/__init__.py`. A skeletal sketch of these steps follows this list.
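Putting the steps together (the import path and everything beyond `get_name` are assumptions; mirror an existing task file for the real interface):

```python
# tasks/my_task.py -- skeletal sketch; copy the real structure from an
# existing task file, since only get_name() is documented here.
from tasks import Task  # hypothetical import path for the parent class

class MyTask(Task):
    def get_name(self):
        return "my_task"

# In tasks/__init__.py (registration sketch):
# NAME_TO_TASK["my_task"] = MyTask
```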
All the models are in the `models` directory. To add a new model, do the following:
- If this is a fundamentally new model (e.g., a Qwen model, which has a huggingface class different from Gemma2, Phi2 or Llama3), add a new file to the `models` directory in analogy to the `gemma2.py` file. It should contain a base class (e.g., `Qwen`) and should download the new model from huggingface.
- Create a new file in the `models` directory, and create a new class there derived from the base model class created in the previous step (or taken from one of the existing ones). Create a `run()` method, and make sure its signature is consistent with the fields we have in the `Config` class in the `run.py` file.
- Add your new model class to the `NAME_TO_MODEL` dict in `models/__init__.py`. A skeletal sketch of these steps follows this list.
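A skeletal sketch of these steps (the file names, the base-class import, and the `run()` signature are assumptions; check the existing files in `models/`):

```python
# models/qwen_calc_accuracy.py -- skeletal sketch.
from models.qwen import Qwen  # the base class created in the previous step

class QwenCalcAccuracy(Qwen):
    def run(self, config):
        # Run the full model on the task and log accuracy; keep the
        # signature consistent with the Config fields in run.py.
        ...

# In models/__init__.py (registration sketch):
# NAME_TO_MODEL["qwen_calc_accuracy"] = QwenCalcAccuracy
```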
This project is licensed under the MIT License.
@article{bakalova2025contextualize,
title={Contextualize-then-Aggregate: Circuits for In-Context Learning in Gemma-2 2B},
author={Bakalova, Aleksandra and Veitsman, Yana and Huang, Xinting and Hahn, Michael},
journal={arXiv preprint arXiv:2504.00132},
year={2025}
}