This is a beta release (public testing).
pyvene supports customizable interventions on different neural architectures (e.g., RNNs or Transformers). It supports complex intervention schemes (e.g., parallel or serialized interventions) and a wide range of intervention modes (e.g., static or trained interventions), at scale, to help you gain interpretability insights.
```bash
pip install pyvene
```
You can intervene on supported models as follows:
```python
import pyvene
from pyvene import IntervenableRepresentationConfig, IntervenableConfig, IntervenableModel
from pyvene import VanillaIntervention

# provided wrapper for the Hugging Face GPT-2 model
_, tokenizer, gpt2 = pyvene.create_gpt2()

# turn gpt2 into intervenable_gpt2
intervenable_gpt2 = IntervenableModel(
    intervenable_config = IntervenableConfig(
        intervenable_representations=[
            IntervenableRepresentationConfig(
                0,             # intervening layer 0
                "mlp_output",  # intervening mlp output
            ),
        ],
        intervenable_interventions_type=VanillaIntervention
    ),
    model = gpt2
)

# intervene on the base with the source at token position 4 (0-indexed),
# i.e., the last token " is"
original_outputs, intervened_outputs = intervenable_gpt2(
    tokenizer("The capital of Spain is", return_tensors="pt"),
    [tokenizer("The capital of Italy is", return_tensors="pt")],
    {"sources->base": 4}
)
original_outputs.last_hidden_state - intervened_outputs.last_hidden_state
```
which returns:
```
tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
         [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
         [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
         [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
         [ 0.0008, -0.0078, -0.0066, ..., 0.0007, -0.0018, 0.0060]]])
```
This shows that the intervention has a causal effect only on the last token, as expected. You can share your interventions with others through the Hugging Face Hub with a single call:
```python
intervenable_gpt2.save(
    save_directory="./your_gpt2_mounting_point/",
    save_to_hf_hub=True,
    hf_repo_name="your_gpt2_mounting_point",
)
```
In this sense, interventions are knobs that can be mounted on models, and people can share their knobs with others to share knowledge about how to steer models. You can try this at [Intervention Sharing].
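Why is only the last row of the difference tensor nonzero? GPT-2 is a causal (left-to-right) model, so a change made at position 4 can only propagate to positions 4 and later, and position 4 is the last token of "The capital of Spain is". The plain-Python toy below illustrates that argument; it is not pyvene code, and the running-sum "layers", the numeric token values, and the `run` helper are hypothetical stand-ins.

```python
from typing import List, Optional, Tuple


def causal_layer(xs: List[float]) -> List[float]:
    """Toy causal layer: output[t] depends only on inputs 0..t (a running sum)."""
    out, total = [], 0.0
    for x in xs:
        total += x
        out.append(total)
    return out


def run(tokens: List[float], patch: Optional[Tuple[int, float]] = None) -> List[float]:
    """Two stacked causal layers; `patch=(position, value)` overwrites layer 0 there."""
    h = causal_layer(tokens)        # "layer 0" representation
    if patch is not None:
        position, value = patch
        h[position] = value         # interchange intervention at one position
    return causal_layer(h)          # "layer 1" output


base = [1.0, 2.0, 3.0, 4.0, 5.0]    # stands in for "The capital of Spain is"
source = [1.0, 2.0, 3.0, 9.0, 5.0]  # differs at position 3, like Spain vs. Italy

position = 4
source_value = causal_layer(source)[position]  # layer-0 value taken from the source run
original = run(base)
intervened = run(base, patch=(position, source_value))
diff = [o - i for o, i in zip(original, intervened)]
print(diff)  # → [0.0, 0.0, 0.0, 0.0, -5.0]
```

Positions 0 to 3 are untouched because no causal layer lets information flow backward from position 4.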
You can also use `intervenable_gpt2` just like a regular torch model component inside another model or pipeline:
```python
import torch
import torch.nn as nn
from typing import List, Optional, Dict

class ModelWithIntervenables(nn.Module):
    def __init__(self):
        super(ModelWithIntervenables, self).__init__()
        # reuses the intervenable_gpt2 built above
        self.intervenable_gpt2 = intervenable_gpt2
        self.relu = nn.ReLU()
        self.fc = nn.Linear(768, 1)  # 768 = hidden size of GPT-2 (small)
        # Your other downstream components go here

    def forward(
        self,
        base,
        sources: Optional[List] = None,
        unit_locations: Optional[Dict] = None,
        activations_sources: Optional[Dict] = None,
        subspaces: Optional[List] = None,
    ):
        # keep only the intervened (counterfactual) outputs
        _, counterfactual_x = self.intervenable_gpt2(
            base,
            sources,
            unit_locations,
            activations_sources,
            subspaces
        )
        counterfactual_x = counterfactual_x.last_hidden_state
        counterfactual_x = self.relu(counterfactual_x)
        counterfactual_x = self.fc(counterfactual_x)
        return counterfactual_x
```
Level | Tutorial | Description
---|---|---
Beginner | Getting Started | Introduces basic static intervention on factual recall examples
Beginner | Intervened Model Generation | Shows how to intervene on a model during generation
Intermediate | Intervene Your Local Models | Illustrates how to run this library with your own models
Intermediate | ROME Causal Tracing | Reproduces ROME's results on factual associations with GPT2-XL
Intermediate | Intervention v.s. Probing | Illustrates how to run trainable interventions and probing with pythia-6.9B
Advanced | Trainable Interventions for Causal Abstraction | Illustrates how to train an intervention to discover causal mechanisms of a neural model
Basic interventions are fun, but they do not let us make causal claims systematically. To gain actual interpretability insights, we want to measure the counterfactual behaviors of a model in a data-driven fashion. In other words, if the model responds systematically to your interventions, you can start to associate certain regions of the network with a high-level concept. We call this process an alignment search over model internals.
Here is a more concrete example:
```python
def add_three_numbers(a, b, c):
    var_x = a + b
    return var_x + c
```
The function solves a three-number sum problem. Suppose we trained a neural network to solve this problem perfectly. Can we find the representation of (a + b) in the neural network? We can use this library to answer this question. Specifically, we can do the following:
- Step 1: Form an Interpretability (Alignment) Hypothesis: We hypothesize that a set of neurons N aligns with (a + b).
- Step 2: Counterfactual Testing: If our hypothesis is correct, then swapping the values of N between examples should produce the expected counterfactual behaviors. For instance, when we swap the values of N for (1+2)+3 with those for (2+3)+4, the output should be (2+3)+3 or (1+2)+4, depending on the direction of the swap.
- Step 3: Rejection Sampling of Hypotheses: Run the tests many times and aggregate statistics on how well the counterfactual behaviors match. Propose a new hypothesis based on the results.
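The three steps can be sketched at toy scale in plain Python. Everything here is hypothetical and does not use pyvene: instead of a trained network we hand-write a "network" whose hidden layer has three neurons, `[a + b, c, a - b]`, and whose output layer adds the first two. We extend `add_three_numbers` with an illustrative `var_x` override so the causal model itself can produce the expected counterfactual output, then score each single-neuron hypothesis by interchange interventions.

```python
import random
from typing import List, Optional, Sequence, Tuple

Example = Tuple[int, int, int]  # (a, b, c)


def add_three_numbers(a: int, b: int, c: int, var_x: Optional[int] = None) -> int:
    """High-level causal model; `var_x` optionally overrides the value of a + b."""
    if var_x is None:
        var_x = a + b
    return var_x + c


def hidden(a: int, b: int, c: int) -> List[int]:
    """Hypothetical 'trained network' hidden layer with three neurons."""
    return [a + b, c, a - b]


def readout(h: Sequence[int]) -> int:
    """Hypothetical output layer: the network computes a + b + c exactly."""
    return h[0] + h[1]


def interchange(base: Example, source: Example, neurons: List[int]) -> int:
    """Step 2: run on `base`, but copy the values of `neurons` from the `source` run."""
    h = hidden(*base)
    h_source = hidden(*source)
    for n in neurons:
        h[n] = h_source[n]
    return readout(h)


def interchange_accuracy(neurons: List[int], trials: int = 200, seed: int = 0) -> float:
    """Step 3: how often the network matches the causal model's counterfactual output."""
    rng = random.Random(seed)
    hits = 0
    for _ in range(trials):
        base = tuple(rng.randint(0, 9) for _ in range(3))
        source = tuple(rng.randint(0, 9) for _ in range(3))
        expected = add_three_numbers(*base, var_x=source[0] + source[1])
        hits += interchange(base, source, neurons) == expected
    return hits / trials


# The swap from Step 2, in both directions.
print(interchange((1, 2, 3), (2, 3, 4), [0]))  # → 8, i.e. (2+3)+3
print(interchange((2, 3, 4), (1, 2, 3), [0]))  # → 7, i.e. (1+2)+4

# Step 1: candidate hypotheses "neuron set N aligns with (a + b)".
for hypothesis in ([0], [1], [2]):
    print(hypothesis, interchange_accuracy(hypothesis))  # only [0] reaches 1.0
```

Neurons 1 and 2 only match the counterfactual label by coincidence, so their hypotheses would be rejected in Step 3.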
With the library, the steps above translate into a single API call:
```python
intervenable.evaluate(
    train_dataloader=test_dataloader,
    compute_metrics=compute_metrics,
    inputs_collator=inputs_collator
)
```
where you provide test data (the interventional inputs together with the counterfactual behavior you are looking for) along with your metric functions. The library then evaluates the alignment hypothesized by the intervention you specified in the config.
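The "counterfactual behavior you are looking for" comes from running the same intervention on the high-level causal model. The plain-Python sketch below builds such test data for the three-number sum example. It is only an illustration: the extended `add_three_numbers` with a `var_x` override, the dictionary keys, `make_counterfactual_dataset`, and `counterfactual_accuracy` are all hypothetical and are not the data format or signatures that pyvene's `evaluate`, `inputs_collator`, or `compute_metrics` expect.

```python
import random
from typing import Dict, List, Optional


def add_three_numbers(a: int, b: int, c: int, var_x: Optional[int] = None) -> int:
    """High-level causal model; `var_x` optionally overrides the value of a + b."""
    if var_x is None:
        var_x = a + b
    return var_x + c


def make_counterfactual_dataset(n: int, seed: int = 0) -> List[Dict]:
    """Pair each base input with a source input and the label the hypothesis predicts."""
    rng = random.Random(seed)
    examples = []
    for _ in range(n):
        base = [rng.randint(0, 9) for _ in range(3)]
        source = [rng.randint(0, 9) for _ in range(3)]
        # Intervene on the causal model: var_x takes its value under `source`.
        label = add_three_numbers(*base, var_x=source[0] + source[1])
        examples.append({"base": base, "source": source, "label": label})
    return examples


def counterfactual_accuracy(predictions: List[int], examples: List[Dict]) -> float:
    """Fraction of intervened model predictions that match the counterfactual labels."""
    correct = sum(p == ex["label"] for p, ex in zip(predictions, examples))
    return correct / len(examples)


print(add_three_numbers(1, 2, 3, var_x=2 + 3))  # → 8, i.e. (2+3)+3
data = make_counterfactual_dataset(4)
print(data[0])
```

In practice, the predictions would be the outputs of the intervened neural model on each (base, source) pair.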
The alignment search process outlined above can be tedious when your neural network is large. For a single hypothesized alignment, you need to set up different intervention configs targeting different layers and positions to verify your hypothesis. Instead of this brute-force search, you can turn the search into an optimization problem, which has other benefits as well, such as finding distributed alignments.
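To see why a brute-force search over individual neurons can miss an alignment, consider a hypothetical toy network (plain Python, not pyvene) that stores x = a + b distributed over two neurons as `[x + c, x - c]`. The sketch tries every subset of neurons as the hypothesized location of (a + b); in a real network you would also loop over layers and token positions. The names `hidden`, `readout`, `interchange`, and `accuracy` are illustrative.

```python
import itertools
import random
from typing import List, Sequence, Tuple

Example = Tuple[int, int, int]  # (a, b, c)


def hidden(a: int, b: int, c: int) -> List[int]:
    """Hypothetical network that spreads x = a + b across both neurons."""
    x = a + b
    return [x + c, x - c]


def readout(h: Sequence[int]) -> int:
    """Output layer: equals (a + b) + c on unintervened inputs."""
    return h[0]


def interchange(base: Example, source: Example, neurons: List[int]) -> int:
    h, h_source = hidden(*base), hidden(*source)
    for n in neurons:
        h[n] = h_source[n]
    return readout(h)


def accuracy(neurons: List[int], pairs) -> float:
    hits = 0
    for base, source in pairs:
        expected = source[0] + source[1] + base[2]  # (a + b) from source, c from base
        hits += interchange(base, source, neurons) == expected
    return hits / len(pairs)


rng = random.Random(0)
pairs = []
for _ in range(200):
    base = tuple(rng.randint(0, 9) for _ in range(3))
    source = tuple(rng.randint(0, 9) for _ in range(3))
    pairs.append((base, source))

# Brute force: every non-empty subset of neurons is one candidate config.
results = {}
for size in (1, 2):
    for neurons in itertools.combinations(range(2), size):
        results[neurons] = accuracy(list(neurons), pairs)
print(results)  # no subset reaches 1.0
```

Even though the network computes the sum perfectly, no neuron-aligned hypothesis passes, because (a + b) lives along a direction that mixes both neurons.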
At its core, we want to train an intervention that produces our desired counterfactual behaviors. If we can indeed train such an intervention, we claim that causally informative information lives in the intervened representations. Below, we show one type of trainable intervention, `models.interventions.RotatedSpaceIntervention`:
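```python
import torch

# Excerpt: TrainableIntervention, rotate_layer and interchange_dim are defined in the library.
class RotatedSpaceIntervention(TrainableIntervention):
    """Intervention in the rotated space."""
    def forward(self, base, source):
        # rotate both representations with the learned rotation
        rotated_base = self.rotate_layer(base)
        rotated_source = self.rotate_layer(source)
        # interchange the first interchange_dim rotated dimensions (last axis)
        rotated_base[..., :self.interchange_dim] = rotated_source[..., :self.interchange_dim]
        # rotate back: the transpose of an orthogonal rotation is its inverse
        output = torch.matmul(rotated_base, self.rotate_layer.weight.T)
        return output
```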
Instead of swapping activations in the original representation space, we first rotate the representations, then do the swap, and finally un-rotate the intervened representation. Additionally, we use SGD to try to learn a rotation that produces the expected counterfactual behavior. If we can find such a rotation, we claim there is an alignment. The `If the cost is between X and Y.ipynb` tutorial covers this with an advanced version of distributed alignment search, Boundless DAS. Recent works also outline potential limitations of distributed alignment search.
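Here is a toy, plain-Python illustration of rotate, swap, un-rotate, under stated assumptions: a hypothetical two-neuron hidden layer stores x = a + b as `[x + c, x - c]`, so neither neuron alone carries x. The 2×2 `rotate`/`unrotate` helpers and `rotated_interchange` are illustrative names, not pyvene's `RotatedSpaceIntervention`; the angle is fixed by hand here rather than learned.

```python
import math
import random


def hidden(a, b, c):
    """Hypothetical network: x = a + b is spread across both neurons."""
    x = a + b
    return [x + c, x - c]


def readout(h):
    """Output layer: equals (a + b) + c on unintervened inputs."""
    return h[0]


def rotate(h, theta):
    """Multiply by the 2x2 rotation matrix R(theta)."""
    cos, sin = math.cos(theta), math.sin(theta)
    return [cos * h[0] + sin * h[1], -sin * h[0] + cos * h[1]]


def unrotate(r, theta):
    """Multiply by R(theta)^T, the inverse of the orthogonal rotation."""
    cos, sin = math.cos(theta), math.sin(theta)
    return [cos * r[0] - sin * r[1], sin * r[0] + cos * r[1]]


def rotated_interchange(base, source, theta, k=1):
    """Rotate, copy the first k rotated dimensions from source, rotate back."""
    rotated_base = rotate(hidden(*base), theta)
    rotated_source = rotate(hidden(*source), theta)
    rotated_base[:k] = rotated_source[:k]
    return readout(unrotate(rotated_base, theta))


def interchange_accuracy(theta, trials=200, seed=0):
    rng = random.Random(seed)
    hits = 0
    for _ in range(trials):
        base = tuple(rng.randint(0, 9) for _ in range(3))
        source = tuple(rng.randint(0, 9) for _ in range(3))
        expected = source[0] + source[1] + base[2]  # (a + b) from source, c from base
        hits += abs(rotated_interchange(base, source, theta) - expected) < 1e-6
    return hits / trials


# theta = 0 is the identity rotation, i.e. a vanilla swap of neuron 0.
print(rotated_interchange((1, 2, 3), (2, 3, 4), 0.0))  # → 9.0 (expected 8)
print(round(rotated_interchange((1, 2, 3), (2, 3, 4), math.pi / 4), 6))  # → 8.0
print(interchange_accuracy(0.0), interchange_accuracy(math.pi / 4))
```

At 45 degrees the first rotated dimension is proportional to a + b, so swapping it alone reproduces the causal model's counterfactual on every pair.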
You can also train your intervention with a single API call:
```python
intervenable.train(
    train_dataloader=train_dataloader,
    compute_loss=compute_loss,
    compute_metrics=compute_metrics,
    inputs_collator=inputs_collator
)
```
where you pass in a training dataset along with your custom loss and metric functions. The trained interventions can later be saved to disk. You can also use `intervenable.evaluate()` to evaluate your interventions in terms of custom objectives.
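The `train` call hides the optimization loop. As a rough, hypothetical illustration of what learning a rotation with a custom loss means, the toy below learns a single rotation angle `theta` for a two-neuron hidden layer that stores x = a + b as `[x + c, x - c]`, using the mean squared error against counterfactual labels as the loss. To stay dependency-free it swaps in plain full-batch gradient descent with a finite-difference gradient instead of SGD with autograd; `hidden`, `intervened_output`, and `loss` are illustrative names, not pyvene functions.

```python
import math
import random


def hidden(a, b, c):
    """Hypothetical network: x = a + b is spread across both neurons."""
    x = a + b
    return [x + c, x - c]


def intervened_output(base, source, theta):
    """Rotate by theta, swap the first rotated dim from source, rotate back, read h[0]."""
    cos, sin = math.cos(theta), math.sin(theta)
    h_base, h_source = hidden(*base), hidden(*source)
    r0_source = cos * h_source[0] + sin * h_source[1]  # swapped-in rotated dimension
    r1_base = -sin * h_base[0] + cos * h_base[1]       # rotated dimension kept from base
    return cos * r0_source - sin * r1_base             # first coordinate after R^T


def loss(theta, data):
    """Mean squared error between intervened outputs and counterfactual labels."""
    return sum((intervened_output(b, s, theta) - y) ** 2 for b, s, y in data) / len(data)


rng = random.Random(0)
data = []
for _ in range(256):
    base = tuple(rng.randint(0, 9) for _ in range(3))
    source = tuple(rng.randint(0, 9) for _ in range(3))
    data.append((base, source, source[0] + source[1] + base[2]))  # counterfactual label

theta, lr, eps = 0.0, 0.005, 1e-4
for _ in range(300):
    grad = (loss(theta + eps, data) - loss(theta - eps, data)) / (2 * eps)
    theta -= lr * grad

print(math.degrees(theta), loss(theta, data))  # theta should settle near 45 degrees
```

Starting from `theta = 0`, which amounts to swapping raw neuron 0, the loss falls to essentially zero as the angle approaches 45 degrees, where the first rotated dimension is proportional to a + b.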
Please see our guidelines about how to contribute to this repository.
Pull requests, bug reports, and all other forms of contribution are welcome and highly encouraged!
Method 2: Install from the Repo
```bash
pip install git+https://github.com/stanfordnlp/pyvene.git
```
Method 3: Clone and Import
```bash
git clone https://github.com/stanfordnlp/pyvene.git
```
and, from a parallel folder, import it into your project as:
```python
from pyvene import pyvene
_, tokenizer, gpt2 = pyvene.create_gpt2()
```
If you would like to read more work in this area, here is a list of papers that try to align with or discover the causal mechanisms of LLMs.
- Causal Abstractions of Neural Networks: This paper introduces interchange intervention (a.k.a. activation patching or causal scrubbing). It tries to align a causal model with the model's representations.
- Inducing Causal Structure for Interpretable Neural Networks: Interchange intervention training (IIT) induces causal structures into the model's representations.
- Localizing Model Behavior with Path Patching: Path patching (or causal scrubbing) to uncover causal paths in neural models.
- Towards Automated Circuit Discovery for Mechanistic Interpretability: A scalable method to prune a neural network down to a small set of connections that can still complete a task.
- Interpretability in the Wild: a Circuit for Indirect Object Identification in GPT-2 small: Path patching plus post-hoc representation study to uncover a circuit that solves the indirect object identification (IOI) task.
- Rigorously Assessing Natural Language Explanations of Neurons: Using causal abstraction to validate neuron explanations released by OpenAI.
The library paper is forthcoming. For now, if you use this repository, please consider citing the relevant papers:
```bibtex
@article{geiger-etal-2023-DAS,
  title={Finding Alignments Between Interpretable Causal Variables and Distributed Neural Representations},
  author={Geiger, Atticus and Wu, Zhengxuan and Potts, Christopher and Icard, Thomas and Goodman, Noah},
  year={2023},
  journal={arXiv}
}

@inproceedings{wu-etal-2023-Boundless-DAS,
  title={Interpretability at Scale: Identifying Causal Mechanisms in Alpaca},
  author={Wu, Zhengxuan and Geiger, Atticus and Icard, Thomas and Potts, Christopher and Goodman, Noah},
  year={2023},
  booktitle={NeurIPS}
}
```