
Contextualize-then-Aggregate: Circuits for In-Context Learning in Gemma-2 2B

This is the official repository of the COLM 2025 paper "Contextualize-then-Aggregate: Circuits for In-Context Learning in Gemma-2 2B".

We used this code to patch edges inside attention heads in Gemma2-2b, Gemma2-9b, Gemma2-27b, Phi2, Smollm, and Llama3.

Getting Started

Dependencies

The code here is specific to the version of the transformers library. We used transformers==4.44; the code might work with other versions as well, but we did not test it. We modified the forward passes of models from the transformers library, so if the signature of the .forward() function changes, chances are high that this code will need to be adapted to work with a newer version.

If you want to use the same libraries we used, try installing them from environment.yml. You can copy our environment via conda env create -f environment.yml and use the environment called gemma2.
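For example (assuming conda is installed):

```bash
conda env create -f environment.yml
conda activate gemma2
```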

Directory structure

  • models/ - contains implementations of models: each model is a combination of a base model (e.g., Llama3, Gemma2-2b) and what you want to do (e.g., ablate some edges, calculate accuracy on a task, etc.). models/__init__.py stores a dict of all available model names, NAME_TO_MODEL.
  • tasks/ - contains implementations of tasks: each task file is responsible for constructing an in-context learning task out of a set of input-output pairs, including the choice of examples to include in a prompt, prompt formatting, construction of corrupted examples, and division into token types. tasks/__init__.py stores a dict of all available task names, NAME_TO_TASK.
  • datasets/ - contains dataset files (input-output pairs) for all tasks. They differ between models, because we keep only the pairs that are tokenized correctly (separators tokenized as tokens separate from the meaningful words).
  • visualization/ - contains a notebook with code for creating the visualizations from the paper.
  • run.py - the entrypoint, with examples of configuration files. We used different configurations in the paper; here we provide simple versions for quick testing.

Executing a program

  1. Clone the repository
  2. Modify the run.py file by adding a new entry to the CONFIG dict. A Config instance consists of the following fields:
    • task: str : REQUIRED, the name of the task to execute. Each task is basically a dataset; a task class is responsible for the data, prompt formatting, producing corrupted prompts for each prompt, and assigning token types to each token. A dict of all available task names, NAME_TO_TASK, is initialized in tasks/__init__.py. The actual task names can be found in the get_name() function in tasks/task_initialization_file.py.
    • model_name: str : REQUIRED, the name of the model to execute. Each model is a combination of a base model (e.g., Llama3, Gemma2-2b) and what you want to do (e.g., ablate some edges, calculate accuracy on a task, etc.). A dict of all available model names, NAME_TO_MODEL, is initialized in models/__init__.py.
    • batch_size: int : REQUIRED, self-explanatory.
    • output_path: str : REQUIRED, a path to the directory where output files will be created. If the directory does not exist, it will be created; any existing files with the same names as the outputs will be overwritten.
    • limit: Optional[int] = None : The maximum number of prompts to run the model on. If None, run on the whole task as defined in the task file in the tasks/ directory.
    • prune_using_imp_scores: Optional[List[str]] = None : Required for pruning models; paths to the files containing the importance scores used to prune the models. Since we prune models iteratively (the algorithm is: calculate importance scores, prune 10% of the heads/edges based on these scores, repeat), these files are created by previous runs of the model. Not needed for the first run of the model, when nothing is pruned before running it.
    • prune_k: Optional[List[int]] = None : Required for pruning models; the number of heads/edges to keep after the corresponding pruning round. This list must have the same length as prune_using_imp_scores. In pruning round i we keep prune_k[i] heads/edges based on the importance scores from the prune_using_imp_scores[i] file, and the model is run on the task after all pruning rounds.
    • affect_whom: list = None : Required for some models; its meaning depends on the model, see the model descriptions below. Usually it means something like "prune the edges between the token types listed in this list".

An example Config instance for each model is provided in the run.py file.

Suppose you named your config "my_awesome_config", so that CONFIG["my_awesome_config"] contains the list of Config instances you want to execute.
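As an illustration, a new entry might look like the following hedged sketch (the task name, output path, and values are hypothetical; consult run.py for real, working examples):

```python
# Hypothetical CONFIG entry in run.py; the task name and output path are
# placeholders and must be replaced with real values from NAME_TO_TASK.
CONFIG["my_awesome_config"] = [
    Config(
        task="copying",                     # a task name from NAME_TO_TASK
        model_name="gemma2_calc_accuracy",  # a model name from NAME_TO_MODEL
        batch_size=8,
        output_path="outputs/my_awesome_run",
        limit=100,                          # optional: cap the number of prompts
    ),
]
```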

  3. Run the models with python run.py --config "my_awesome_config". Do not forget to also set export PYTHONPATH="path_to_the_directory_with_repository"; both commands are spelled out after step 4.

  4. The results will be in the output_dir/logging_dir/model_logs.json file. Look for the accuracy field for accuracy; there are many other logs as well (for their descriptions, consult the model file in the models/ directory).
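The commands from step 3, spelled out (the repository path is a placeholder):

```bash
export PYTHONPATH="path_to_the_directory_with_repository"
python run.py --config "my_awesome_config"
```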

A few example configs are already provided; for a start, you could try those.

Supported models

  • *_ablate_edges:
    • Additional required argument: affect_whom: a list of pairs (token_type_1, token_type_2) - ablate all edges from token_type_1 to token_type_2.
    • * is a base model name: one of gemma2, phi2 or smollm.
    • Runs the model with all edges in the affect_whom list "removed", i.e. replaced with the activation on a corrupted example. The corrupted example is different for every prompt.
  • *_calc_accuracy:
    • * is a base model name: one of gemma2, phi2 or smollm.
    • Just runs the full model and reports the full-model accuracy.
  • gemma2_ablate_edges_differently_for_different_tokens:
    • Additional required argument: affect_whom: a dict with the fields "ablate" and "curcuit". affect_whom["curcuit"] is a dict that maps pairs (token_type_1, token_type_2) to 0, meaning that all edges from token_type_1 to token_type_2 will be ablated with corruption 0 from the list of corruptions. affect_whom["ablate"] is a dict that maps pairs (token_type_1, token_type_2) to an int > 0, meaning that edges from token_type_1 to token_type_2 will be ablated with the corruption from the list of corruptions specific to this edge. The list of corruptions is specified in the task class. You need to choose a task that has at least (the maximum value in affect_whom["ablate"]) + 1 different corruptions. A sketch of both affect_whom shapes follows after this list.
    • Runs the model with all edges in affect_whom["curcuit"] "removed with corruption 0", i.e. replaced with the activation on the first corrupted example, and all edges in affect_whom["ablate"] "removed with corruption x inside a circuit", i.e. replaced with the activation on an example where the edges from the "curcuit" part are also ablated with corruption 0. The list of corrupted examples is different for every prompt.
  • gemma2_prune_heads_in_circuits:
    • Additional required arguments:
      • prune_using_imp_scores: a list of paths to importance-score files used for pruning; they should point to the outputs of previous runs of this model on this task.
      • prune_k: a list of the numbers of heads to keep after each pruning round; it must have the same length as prune_using_imp_scores. We keep the prune_k[0] heads with the biggest importance scores according to the prune_using_imp_scores[0] file, then the prune_k[1] heads with the biggest importance scores according to the prune_using_imp_scores[1] file, and so on. prune_k should therefore be sorted in non-increasing order. A head is a tuple (layer, head, token_type).
      • affect_whom: same as in the *_ablate_edges models; edges to prune before starting the iterative pruning procedure.
    • Performs iterative importance-score-based pruning of heads in a model where the edges outside the "circuit" are already pruned.
  • gemma2_prune_edges_in_circuits:
    • Same as gemma2_prune_heads_in_circuits, but prunes edges. It should be run after gemma2_prune_heads_in_circuits, so its prune_using_imp_scores and prune_k lists should contain first the values for head pruning and then those for edge pruning. It can also be run without head pruning.
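A hedged sketch of the two affect_whom shapes described above (the token-type names here are illustrative placeholders; see the task files for the real ones):

```python
# For *_ablate_edges: a list of (source_token_type, target_token_type) pairs.
# The token-type names below are hypothetical examples.
affect_whom = [("x_i", "y_i"), ("y_i", "query")]

# For gemma2_ablate_edges_differently_for_different_tokens: a dict whose
# "curcuit" part maps edges to corruption 0 and whose "ablate" part maps
# edges to edge-specific corruption indices > 0. With a maximum index of 2,
# the chosen task must define at least 3 different corruptions.
affect_whom = {
    "curcuit": {("x_i", "y_i"): 0},
    "ablate": {("y_i", "query"): 2},
}
```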

Supported tasks

We support five tasks from the paper: Copying, Country-Capital, Present-Past, Person-Sport and Capitalization. There are also files for the Ambiguous Present-Past task.

The dataset we use for each task depends on the model's tokenizer, since we need to ensure that the items in the dataset are tokenized correctly: each separator must be tokenized separately from the "meaningful" tokens.
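One way to check this property is with the model's Hugging Face tokenizer (the model name, separator, and example string below are purely illustrative; this is not the repository's actual validation code):

```python
from transformers import AutoTokenizer

# Illustrative check: the separator should come out as its own token,
# not merged into the surrounding words.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
tokens = tokenizer.tokenize("France : Paris")
print(tokens)  # inspect whether ":" appears as a separate token
```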

Corruption datasets are usually listed in the task file in the tasks/ directory.

We also assign each token in a dataset a type: one of $x_i$, $y_i$, $t_i$, $n_i$, query, target, bos/eos, last_separator. For more information, consult the paper or the task file in the tasks/ directory.

Modifying

How to add more tasks

All the tasks are in-context-learning-formatted tasks and live in the tasks directory. To add a new task, do the following:

  1. If your task requires a new dataset of input-output pairs, add it to the datasets directory in Hugging Face format (via dataset.save_to_disk('dataset_name')). If it is just a new formatting / a new corruption type for an existing dataset, this step is not needed.
  2. Create a new file in the tasks directory. This file should contain a class named after your task and derived from the parent Task class.
  3. Fill in all the data for your task by analogy with the other task files in the folder.
  4. Add a get_name method returning the string name of your task (or set of tasks).
  5. Add your new class to the NAME_TO_TASK dictionary in tasks/__init__.py.
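A minimal skeleton of such a task file (the class name, dataset path, and the parent-class import path are assumptions; check them against the existing task files):

```python
# tasks/my_task.py - hypothetical skeleton of a new task.
from datasets import load_from_disk

from tasks.task import Task  # ASSUMPTION: adjust to the real Task import path


class MyTask(Task):
    def __init__(self):
        # Load the input-output pairs saved via dataset.save_to_disk(...).
        self.dataset = load_from_disk("datasets/my_task_dataset")

    def get_name(self):
        return "my_task"


# tasks/__init__.py - register the task (by analogy with existing entries):
# NAME_TO_TASK["my_task"] = MyTask
```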

How to add more models

All the models are in the models directory. To add a new model, do the following:

  1. If this is a fundamentally new model (e.g., a Qwen model, which has a Hugging Face class different from Gemma2, Phi2 or Llama3), add a new file to the models directory by analogy with the gemma2.py file. It should contain a base class, e.g. Qwen, and should download the new model from Hugging Face.
  2. Create a new file in the models directory, and create a new class there derived from the base model class created in the previous step (or from one of the existing ones). Create a run() method. Make sure its signature is consistent with the fields of the Config class in the run.py file.
  3. Add your new model class to the NAME_TO_MODEL dict in models/__init__.py.
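A hedged skeleton of steps 2-3 (the class names and the run() signature here are assumptions derived from the Config fields described above; compare with the existing model files before relying on it):

```python
# models/qwen_calc_accuracy.py - hypothetical skeleton of a new model.
from models.qwen import Qwen  # ASSUMPTION: base class created in step 1


class QwenCalcAccuracy(Qwen):
    def run(self, task, batch_size, output_path, limit=None,
            prune_using_imp_scores=None, prune_k=None, affect_whom=None):
        # Run the full model on the task and log accuracy, by analogy with
        # the existing *_calc_accuracy models.
        ...


# models/__init__.py - register the model (by analogy with existing entries):
# NAME_TO_MODEL["qwen_calc_accuracy"] = QwenCalcAccuracy
```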

License

This project is licensed under the MIT License.

Cite

@article{bakalova2025contextualize,
  title={Contextualize-then-Aggregate: Circuits for In-Context Learning in Gemma-2 2B},
  author={Bakalova, Aleksandra and Veitsman, Yana and Huang, Xinting and Hahn, Michael},
  journal={arXiv preprint arXiv:2504.00132},
  year={2025}
}
