5 changes: 5 additions & 0 deletions .github/workflows/checks.yml
@@ -38,6 +38,10 @@ on:
HF_TOKEN:
required: false


env:
HF_HUB_DOWNLOAD_TIMEOUT: 120

permissions:
actions: write
contents: write
@@ -211,6 +215,7 @@ jobs:
run: make coverage-report-test
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_HUB_DOWNLOAD_TIMEOUT: 120
- name: Build check
run: uv build
- name: Upload Coverage Report Artifact
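Note: HF_HUB_DOWNLOAD_TIMEOUT is read by huggingface_hub and sets the per-request timeout, in seconds, when fetching files from the Hub; the library default (10 seconds at the time of writing) is easy to exceed on slow CI runners, which is what the two additions above work around. A minimal local sketch of the same setting (the gpt2/config.json download is purely illustrative):

import os

# Must be set before huggingface_hub makes the request; 120 matches the
# value added to the workflow above.
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "120"

from huggingface_hub import hf_hub_download  # noqa: E402  (imported after the env var on purpose)

# Any repo/file works the same way; gpt2's config is used only as an example.
path = hf_hub_download(repo_id="gpt2", filename="config.json")
print(path)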
2 changes: 1 addition & 1 deletion .gitignore
@@ -22,4 +22,4 @@ docs/source/generated
# docs/source/_static/model_table
**.orig
.venv

.env
2 changes: 1 addition & 1 deletion .vscode/settings.json
@@ -33,7 +33,7 @@
"notebook.formatOnSave.enabled": true,
"pylint.importStrategy": "fromEnvironment",
"python.testing.pytestArgs": [
"transformer_lens",
"tests"
],
"python.testing.pytestEnabled": true,
"rewrap.autoWrap.enabled": true,
27 changes: 9 additions & 18 deletions demos/ARENA_Content.ipynb
@@ -98,7 +98,7 @@
"source": [
"\n",
"# [1.1] Transformer From Scratch\n",
"# 1️⃣ UNDERSTANDING INPUTS & OUTPUTS OF A TRANSFORMER\n",
"# 1\ufe0f\u20e3 UNDERSTANDING INPUTS & OUTPUTS OF A TRANSFORMER\n",
"\n",
"sorted_vocab = sorted(list(reference_gpt2.tokenizer.vocab.items()), key=lambda n: n[1])\n",
"first_vocab = sorted_vocab[0]\n",
@@ -235,23 +235,14 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' I'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"\n",
"most_likely_next_tokens = reference_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1)[0])\n",
"most_likely_next_tokens = reference_gpt2.tokenizer.batch_decode(\n",
" [[int(t)] for t in logits.argmax(dim=-1)[0].tolist()]\n",
")\n",
"most_likely_next_tokens[-1]\n",
"\n"
]
@@ -285,7 +276,7 @@
}
],
"source": [
"# 2️⃣ CLEAN TRANSFORMER IMPLEMENTATION\n",
"# 2\ufe0f\u20e3 CLEAN TRANSFORMER IMPLEMENTATION\n",
"\n",
"layer_0_hooks = [\n",
" (name, tuple(tensor.shape)) for name, tensor in cache.items() if \".0.\" in name\n",
@@ -371,7 +362,7 @@
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# [1.2] Intro to mech interp\n",
"# 2️⃣ FINDING INDUCTION HEADS\n",
"# 2\ufe0f\u20e3 FINDING INDUCTION HEADS\n",
"\n",
"if not IN_GITHUB:\n",
" # Cannot run in CI\n",
Expand Down Expand Up @@ -449,4 +440,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
50 changes: 27 additions & 23 deletions makefile
@@ -1,5 +1,9 @@
RUN := uv run

# Rerun args for flaky tests (httpx timeouts during HF Hub downloads)
# Remove this line when no longer needed
RERUN_ARGS := --reruns 2 --reruns-delay 5

dep:
uv sync

@@ -14,45 +18,45 @@ check-format:
$(RUN) black --check .

unit-test:
$(RUN) pytest tests/unit
$(RUN) pytest tests/unit $(RERUN_ARGS)

integration-test:
$(RUN) pytest tests/integration
$(RUN) pytest tests/integration $(RERUN_ARGS)

acceptance-test:
$(RUN) pytest tests/acceptance
$(RUN) pytest tests/acceptance $(RERUN_ARGS)

benchmark-test:
$(RUN) pytest tests/benchmarks
$(RUN) pytest tests/benchmarks $(RERUN_ARGS)

coverage-report-test:
$(RUN) pytest --cov=transformer_lens/ --cov-report=html --cov-branch tests/integration tests/benchmarks tests/unit tests/acceptance
$(RUN) pytest --cov=transformer_lens/ --cov-report=html --cov-branch tests/integration tests/benchmarks tests/unit tests/acceptance $(RERUN_ARGS)

docstring-test:
$(RUN) pytest transformer_lens/
$(RUN) pytest transformer_lens/ $(RERUN_ARGS)

notebook-test:
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/BERT.ipynb
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Exploratory_Analysis_Demo.ipynb
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb

$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Head_Detector_Demo.ipynb
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Interactive_Neuroscope.ipynb
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/LLaMA.ipynb
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/No_Position_Experiment.ipynb
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Othello_GPT.ipynb
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Qwen.ipynb
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Santa_Coder.ipynb
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Stable_Lm.ipynb
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/SVD_Interpreter_Demo.ipynb
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Tracr_to_Transformer_Lens_Demo.ipynb
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/BERT.ipynb $(RERUN_ARGS)
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Exploratory_Analysis_Demo.ipynb $(RERUN_ARGS)
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb $(RERUN_ARGS)

$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Head_Detector_Demo.ipynb $(RERUN_ARGS)
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Interactive_Neuroscope.ipynb $(RERUN_ARGS)
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/LLaMA.ipynb $(RERUN_ARGS)
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/No_Position_Experiment.ipynb $(RERUN_ARGS)
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Othello_GPT.ipynb $(RERUN_ARGS)
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Qwen.ipynb $(RERUN_ARGS)
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Santa_Coder.ipynb $(RERUN_ARGS)
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Stable_Lm.ipynb $(RERUN_ARGS)
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/SVD_Interpreter_Demo.ipynb $(RERUN_ARGS)
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Tracr_to_Transformer_Lens_Demo.ipynb $(RERUN_ARGS)

# Contains failing cells

# Causes CI to hang
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Activation_Patching_in_TL_Demo.ipynb
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Attribution_Patching_Demo.ipynb
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Grokking_Demo.ipynb
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Activation_Patching_in_TL_Demo.ipynb $(RERUN_ARGS)
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Attribution_Patching_Demo.ipynb $(RERUN_ARGS)
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Grokking_Demo.ipynb $(RERUN_ARGS)

test:
make unit-test
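As an aside, pytest-rerunfailures (added to the dependencies in pyproject.toml below) also supports a per-test marker, which could later replace the global RERUN_ARGS if only a handful of tests turn out to be flaky. A sketch, with a hypothetical test name:

import pytest

# Equivalent to passing --reruns 2 --reruns-delay 5 on the command line,
# but scoped to a single known-flaky test.
@pytest.mark.flaky(reruns=2, reruns_delay=5)
def test_loads_model_from_hub():
    # Hypothetical body: stands in for any test that downloads from the
    # HF Hub and can fail transiently on httpx timeouts.
    assert True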
37 changes: 20 additions & 17 deletions pyproject.toml
@@ -45,6 +45,7 @@
"pytest>=7.2.0",
"pytest-cov>=4.0.0",
"pytest-doctestplus>=1.0.0",
"pytest-rerunfailures>=12.0",
"pytest-xdist>=3.8.0",
]
jupyter = [
@@ -102,6 +103,7 @@
pytest=">=7.2.0"
pytest-cov=">=4.0.0"
pytest-doctestplus="^1.0.0"
pytest-rerunfailures=">=12.0"
pytest-xdist="^3.8.0" # For parallel test execution

[tool.poetry.group.jupyter.dependencies]
@@ -112,23 +114,24 @@
default-groups=["dev", "jupyter", "docs"]

[tool.pytest]
[tool.pytest.ini_options]
addopts=[
"--doctest-modules",
"--doctest-plus",
"--jaxtyping-packages=transformer_lens,beartype.beartype",
"--nbval",
"-W ignore::beartype.roar.BeartypeDecorHintPep585DeprecationWarning",
]
testpaths = ["tests", "transformer_lens"] # Only test these directories
doctest_optionflags="NORMALIZE_WHITESPACE ELLIPSIS FLOAT_CMP"
filterwarnings=[
"ignore:pkg_resources is deprecated as an API:DeprecationWarning",
# Ignore numpy.distutils deprecation warning caused by pandas
# More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils
"ignore:distutils Version classes are deprecated:DeprecationWarning",
]
pythonpath=["."]

[tool.pytest.ini_options]
addopts=[
"--doctest-modules",
"--doctest-plus",
"--jaxtyping-packages=transformer_lens,beartype.beartype",
"--nbval",
"-W ignore::beartype.roar.BeartypeDecorHintPep585DeprecationWarning",
]
testpaths = ["tests", "transformer_lens"] # Only test these directories
doctest_optionflags="NORMALIZE_WHITESPACE ELLIPSIS FLOAT_CMP"
filterwarnings=[
"ignore:pkg_resources is deprecated as an API:DeprecationWarning",
# Ignore numpy.distutils deprecation warning caused by pandas
# More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils
"ignore:distutils Version classes are deprecated:DeprecationWarning",
]
pythonpath=["."]

[tool.isort]
extend_skip=[".venv/", "__init__.py"]
4 changes: 2 additions & 2 deletions tests/acceptance/test_hooked_encoder.py
@@ -225,6 +225,6 @@ def test_input_list_of_strings_mlm(our_bert, huggingface_bert, tokenizer):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA device")
def test_cuda(mlm_tokens):
def test_cuda(tokens):
model = HookedEncoder.from_pretrained(MODEL_NAME)
model(mlm_tokens)
model(tokens)
2 changes: 1 addition & 1 deletion tests/acceptance/test_multi_gpu.py
@@ -111,7 +111,7 @@ def test_cache_device():
torch.device("cuda:1")
)

logits, cache = model.run_with_cache("Hello there", device="cpu")
logits, cache = model.run_with_cache("Hello there", device=torch.device("cpu"))
assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(torch.device("cpu"))

model.to("cuda")
19 changes: 19 additions & 0 deletions tests/unit/components/test_attention.py
@@ -80,6 +80,25 @@ def test_attention_load_in_4bit():
assert torch.all(attn.b_V == 0)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for half/bfloat16 tests")
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_attention_forward_half_precisions(dtype):
# Construct a small attention block
cfg = HookedTransformerConfig(
d_model=64, d_head=16, n_heads=4, n_layers=1, n_ctx=8, dtype=dtype
)
attn = Attention(cfg)
# Random inputs in the matching dtype
batch = 1
seq = 4
x = torch.rand((batch, seq, cfg.d_model), dtype=dtype).to("cuda")
# Run forward through attention (q,k,v = x)
out = attn(x, x, x)
# Forward should not raise; the output should be a tensor on CUDA
assert isinstance(out, torch.Tensor)
assert out.device.type == "cuda"


def test_attention_config_dict():
cfg = {
"n_layers": 12,
1 change: 1 addition & 0 deletions tests/unit/factored_matrix/test_multiply_by_scalar.py
@@ -23,6 +23,7 @@
), # Non-scalar Tensor. AssertionError expected.
(torch.rand(2), AssertionError), # Non-scalar Tensor. AssertionError expected.
],
ids=["tensor", "float", "int", "tensor_2d", "tensor_1d"],
)
@pytest.mark.parametrize("leading_dim", [False, True])
@pytest.mark.parametrize("multiply_from_left", [False, True])
6 changes: 3 additions & 3 deletions tests/unit/model_bridge/test_gpt_oss_moe.py
@@ -239,7 +239,7 @@ def test_gpt_oss_run_with_cache_with_random_weights():
assert "blocks.0.mlp.hook_router_scores" in cache
assert "blocks.1.mlp.hook_router_scores" in cache

# Router scores should have shape [seq_len, num_experts]
# GPT-OSS has 32 experts
# Router scores should have shape [seq_len, num_experts_per_tok]
# GPT-OSS has 32 experts with top-4 routing, so router_scores is (seq_len, 4)
router_scores_0 = cache["blocks.0.mlp.hook_router_scores"]
assert router_scores_0.shape == (5, 32) # seq_len=5, num_experts=32
assert router_scores_0.shape == (5, 4) # seq_len=5, num_experts_per_tok=4
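To make the corrected shape concrete: with top-k routing only the scores of the selected experts are kept per token, so 32 experts with top-4 routing yields a (seq_len, 4) tensor. A rough sketch of that bookkeeping, assuming a softmax-then-topk router (the real GPT-OSS routing details may differ):

import torch

seq_len, num_experts, top_k = 5, 32, 4

# Hypothetical router logits: one score per expert for each token position.
router_logits = torch.randn(seq_len, num_experts)

# Keep only the top-k experts per token; a hook caching scores at this point
# (like hook_router_scores) would see a (seq_len, top_k) tensor.
router_scores, expert_indices = torch.topk(router_logits.softmax(dim=-1), k=top_k, dim=-1)
assert router_scores.shape == (seq_len, top_k)    # (5, 4)
assert expert_indices.shape == (seq_len, top_k)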
2 changes: 1 addition & 1 deletion transformer_lens/HookedEncoderDecoder.py
@@ -479,7 +479,7 @@ def generate(
if return_type == "str":
assert self.tokenizer is not None
# Convert tokens to string
return self.tokenizer.decode(decoder_input[0], skip_special_tokens=True)
return cast(str, self.tokenizer.decode(decoder_input[0], skip_special_tokens=True))

else:
return decoder_input
15 changes: 12 additions & 3 deletions transformer_lens/HookedTransformer.py
@@ -950,7 +950,16 @@ def to_str_tokens(
), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
else:
raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}")
str_tokens = self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False)
# In transformers v5, batch_decode treats a flat list as a single sequence,
# not individual token IDs, so would return a single string. To maintain backward
# compatibility with v4, we wrap each token to decode them individually.
if isinstance(tokens, np.ndarray):
tokens_list = [[int(t)] for t in tokens]
else:
tokens_list = [[int(t)] for t in tokens.tolist()]
str_tokens = self.tokenizer.batch_decode(
tokens_list, clean_up_tokenization_spaces=False
)
return str_tokens

def to_single_token(self, string):
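The nested-list workaround is unambiguous because batch_decode always treats each inner list as one sequence, whatever the transformers version; only the flat-list case changed behaviour. A small sketch of the two call patterns (GPT-2 tokenizer assumed):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
token_ids = tokenizer("Hello world")["input_ids"]  # e.g. [15496, 995]

# One inner list == one sequence -> a single joined string back.
print(tokenizer.batch_decode([token_ids]))                 # ['Hello world']

# One id per inner list, as in to_str_tokens above -> one string per token.
print(tokenizer.batch_decode([[t] for t in token_ids]))    # ['Hello', ' world']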
@@ -2085,8 +2094,8 @@ def generate(
output_tokens = sampled_tokens

if return_type == "str":
decoded_texts = [
self.tokenizer.decode(tokens, skip_special_tokens=True)
decoded_texts: List[str] = [
cast(str, self.tokenizer.decode(tokens, skip_special_tokens=True))
for tokens in output_tokens
]
return decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts
Expand Down
25 changes: 18 additions & 7 deletions transformer_lens/components/abstract_attention.py
@@ -8,6 +8,7 @@
import torch.nn.functional as F
from better_abc import abstract_attribute
from jaxtyping import Float, Int
from torch import Tensor
from transformers.utils.import_utils import is_bitsandbytes_available

from transformer_lens.cache.key_value_cache_entry import (
@@ -280,8 +281,7 @@ def forward(
raise TypeError(f"Expected 'pattern' to be a Tensor, got {type(pattern)}")
pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos]
pattern = pattern.to(self.cfg.dtype)
pattern = pattern.to(v.device)
pattern = pattern.to(device=v.device, dtype=v.dtype)
z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head]
if not self.cfg.use_attn_result:
if self.cfg.load_in_4bit:
@@ -301,15 +301,21 @@
self.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
)

if self.b_O.device != w.device:
w = w.to(self.b_O.device)
if self.b_O.device != z.device:
z = z.to(self.b_O.device)
# Move output projection weights and bias to the same device as z
# so that the final linear operation occurs on the device of the inputs
if w.device != z.device:
w = w.to(z.device)
b_O: Tensor = self.b_O
if b_O.device != z.device:
b_O = b_O.to(z.device)
# Ensure z has the same dtype as weights used in the output projection
if z.dtype != w.dtype:
z = z.to(w.dtype)

out = F.linear(
z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
w,
self.b_O,
b_O,
)
else:
# Explicitly calculate the attention result so it can be accessed by a hook
Expand All @@ -329,6 +335,11 @@ def forward(
self.W_O,
"head_index d_head d_model -> 1 1 head_index d_head d_model",
)
if w.device != z.device:
w = w.to(z.device)
# Ensure z has the same dtype as w before multiplication
if z.dtype != w.dtype:
z = z.to(w.dtype)
z = einops.rearrange(
z, "batch pos head_index d_head -> batch pos head_index d_head 1"
)
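For context on the dtype/device alignment added above: F.linear raises a RuntimeError when its input, weight, and bias disagree on dtype or device, which the new half-precision attention test in this PR is presumably meant to exercise. A minimal sketch of the failure mode and the fix (shapes are arbitrary):

import torch
import torch.nn.functional as F

z = torch.randn(1, 4, 64, dtype=torch.float16)  # attention output, half precision
w = torch.randn(64, 64, dtype=torch.float32)    # e.g. an output projection kept in float32
b = torch.zeros(64, dtype=torch.float32)

# F.linear(z, w, b) would raise a dtype-mismatch RuntimeError here.
# Casting z to the weight dtype first, as the diff does, avoids that:
out = F.linear(z.to(w.dtype), w, b)
print(out.shape, out.dtype)  # torch.Size([1, 4, 64]) torch.float32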