19 changes: 10 additions & 9 deletions .github/workflows/checks.yml
@@ -73,7 +73,7 @@ jobs:
- name: Authenticate HuggingFace CLI
if: env.HF_TOKEN != ''
run: |
pip install huggingface_hub
pip install huggingface_hub==0.33.0
huggingface-cli login --token "$HF_TOKEN"
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
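For reference, the same authentication step can also be done through `huggingface_hub`'s Python API instead of the CLI. This is a hedged sketch, not part of the PR: the `HF_TOKEN` environment variable and the empty-token guard mirror the workflow's `if: env.HF_TOKEN != ''` condition.

```python
import os

from huggingface_hub import login

# Mirror the workflow's guard: only authenticate when a token is provided.
token = os.environ.get("HF_TOKEN", "")
if token:
    login(token=token)
```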
@@ -93,11 +93,11 @@ jobs:
- uses: actions/checkout@v3
- name: Install Poetry
uses: snok/install-poetry@v1
with:
virtualenvs-create: true
virtualenvs-in-project: true
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
cache: "poetry"
- name: Cache Models used with Tests
uses: actions/cache@v3
with:
@@ -119,7 +119,7 @@ jobs:
- name: Authenticate HuggingFace CLI
if: env.HF_TOKEN != ''
run: |
pip install huggingface_hub
pip install huggingface_hub==0.33.0
huggingface-cli login --token "$HF_TOKEN"
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
@@ -188,9 +188,10 @@ jobs:
build-docs:
# When running on a PR, this just checks we can build the docs without errors
# When running on merge to main, it builds the docs and then another job deploys them
# Only runs on the original repo, not forks
name: 'Build Docs'
runs-on: ubuntu-latest
if: github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev') || contains(github.head_ref, 'docs')
if: github.repository == 'TransformerLensOrg/TransformerLens' && (github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev') || contains(github.head_ref, 'docs'))
needs: code-checks
steps:
- uses: actions/checkout@v4
@@ -216,7 +217,7 @@ jobs:
- name: Authenticate HuggingFace CLI
if: env.HF_TOKEN != ''
run: |
pip install huggingface_hub
pip install huggingface_hub==0.33.0
huggingface-cli login --token "$HF_TOKEN"
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
@@ -233,8 +234,8 @@ jobs:
deploy-docs:
name: Deploy Docs
runs-on: ubuntu-latest
# Only run if merging a PR into main
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
# Only run if merging a PR into main on the original repo, not forks
if: github.repository == 'TransformerLensOrg/TransformerLens' && github.event_name == 'push' && github.ref == 'refs/heads/main'
needs: build-docs
steps:
- uses: actions/checkout@v4
849 changes: 773 additions & 76 deletions demos/Colab_Compatibility.ipynb

Large diffs are not rendered by default.

979 changes: 500 additions & 479 deletions demos/Interactive_Neuroscope.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/source/content/news/release-2.0.md
@@ -1,4 +1,5 @@
# TransformerLens 2.0
**May 29, 2024**

I am very happy to announce TransformerLens now has a 2.0 release! If you have been using recent versions of TransformerLens, then the good news is that not much has changed at all. The primary motivation behind this jump is to transition the project to strictly following semantic versioning as described [here](https://semver.org/). At the last minute we did also remove the recently added HookedSAE, so if you had been using that, I would direct you to Joseph Bloom’s [SAELens](http://github.com/jbloomAus/SAELens). Bundled with this major version change are also a handful of internal modifications that only affect contributors.

2,315 changes: 2,221 additions & 94 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion pyproject.toml
@@ -19,13 +19,15 @@
datasets=">=2.7.1"
einops=">=0.6.0"
fancy-einsum=">=0.0.3"
huggingface-hub=">=0.23.2,<1.0"
jaxtyping=">=0.2.11"
numpy=[
{version=">=1.20,<1.25", python=">=3.8,<3.9"},
{version=">=1.24,<2", python=">=3.9,<3.12"},
{version=">=1.26,<2", python=">=3.12,<3.13"},
]
pandas=">=1.1.5"
protobuf=">=3.20.0"
python=">=3.8,<4.0"
rich=">=12.6.0"
sentencepiece="*"
@@ -41,6 +43,7 @@
[tool.poetry.group.dev.dependencies]
black="^23.3.0"
circuitsvis=">=1.38.1"
gradio=">=4.0.0"
isort="5.8.0"
jupyter=">=1.0.0"
mypy=">=1.10.0"
@@ -185,4 +188,4 @@
strictDictionaryInference=true
strictListInference=true
strictParameterNoneValue=true
strictSetInference=true
strictSetInference=true
2 changes: 1 addition & 1 deletion tests/acceptance/test_activation_cache.py
@@ -150,7 +150,7 @@ def test_logit_attrs_works_for_all_input_shapes():
tokens=answer_tokens[:, 0],
incorrect_tokens=answer_tokens[:, 1],
)
assert torch.isclose(ref_logit_diffs, logit_diffs).all()
assert torch.isclose(ref_logit_diffs, logit_diffs, atol=1.1e-7).all()

# Single token
batch = -1
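The `atol` bump in this hunk matters because `torch.isclose`'s default absolute tolerance (`1e-8`) is easily violated by elements near zero, where the relative term `rtol * |other|` contributes almost nothing. A minimal standalone sketch of the failure mode (the tensors are illustrative, not the test's actual logit diffs):

```python
import torch

ref = torch.tensor([0.0, 1.0])
out = torch.tensor([1e-7, 1.0 + 1e-7])  # tiny numerical noise, e.g. from a different op ordering

# isclose passes when |a - b| <= atol + rtol * |b| (defaults: rtol=1e-5, atol=1e-8).
# The near-zero element fails under the defaults, since rtol * |b| is ~0 there.
default_ok = torch.isclose(ref, out).all().item()
loose_ok = torch.isclose(ref, out, atol=1.1e-7).all().item()
```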
2 changes: 1 addition & 1 deletion tests/acceptance/test_hooked_encoder_decoder.py
@@ -217,7 +217,7 @@ def test_cross_attention(our_model, huggingface_model, hello_world_tokens, decod
huggingface_cross_attn_out = huggingface_cross_attn(
decoder_hidden, key_value_states=encoder_hidden, cache_position=encoder_hidden
)[0]
assert_close(our_cross_attn_out, huggingface_cross_attn_out, rtol=2e-4, atol=1e-5)
assert_close(our_cross_attn_out, huggingface_cross_attn_out, rtol=2e-3, atol=1e-4)


def test_cross_attention_layer(our_model, huggingface_model, hello_world_tokens, decoder_input_ids):
43 changes: 31 additions & 12 deletions tests/unit/factored_matrix/test_properties.py
@@ -32,6 +32,19 @@ def factored_matrices_leading_ones(random_matrices_leading_ones):
return [FactoredMatrix(a, b) for a, b in random_matrices_leading_ones]


@pytest.fixture(scope="module")
def random_matrices_bf16():
return [
(randn(3, 2).to(torch.bfloat16), randn(2, 3).to(torch.bfloat16)),
(randn(10, 4).to(torch.bfloat16), randn(4, 10).to(torch.bfloat16)),
]


@pytest.fixture(scope="module")
def factored_matrices_bf16(random_matrices_bf16):
return [FactoredMatrix(a, b) for a, b in random_matrices_bf16]


class TestFactoredMatrixProperties:
def test_AB_property(self, factored_matrices, random_matrices):
for i, factored_matrix in enumerate(factored_matrices):
@@ -79,18 +92,6 @@ def test_svd_property_leading_ones(self, factored_matrices_leading_ones):
assert torch.allclose(U.mT @ U, torch.eye(U.shape[-1]), atol=1e-5)
assert torch.allclose(Vh.mT @ Vh, torch.eye(Vh.shape[-1]), atol=1e-5)

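The orthogonality assertions above rely on a standard SVD property: the singular-vector factors are orthonormal up to floating-point error. A standalone torch sketch of the same check, outside the `FactoredMatrix` fixtures:

```python
import torch

M = torch.randn(5, 3)
# Reduced SVD: U is 5x3 with orthonormal columns, Vh is 3x3 orthogonal.
U, S, Vh = torch.linalg.svd(M, full_matrices=False)

# U^T U and Vh^T Vh should both be (near-)identity matrices.
assert torch.allclose(U.mT @ U, torch.eye(U.shape[-1]), atol=1e-5)
assert torch.allclose(Vh.mT @ Vh, torch.eye(Vh.shape[-1]), atol=1e-5)
```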
@pytest.mark.skip(
"""
Jaxtyping throws a TypeError when this test is run.
TypeError: type of the return value must be jaxtyping.Float[Tensor, '*leading_dims mdim']; got torch.Tensor instead
I'm not sure why. The error is not very informative. When debugging the shape was equal to mdim, and *leading_dims should
match zero or more leading dims according to the [docs](https://github.com/google/jaxtyping/blob/main/API.md).
Sort of related to https://github.com/TransformerLensOrg/TransformerLens/issues/190 because jaxtyping
is only enabled at test time and not runtime.
"""
)
def test_eigenvalues_property(self, factored_matrices):
for factored_matrix in factored_matrices:
if factored_matrix.ldim == factored_matrix.rdim:
@@ -159,3 +160,21 @@ def test_unsqueeze(self, factored_matrices_leading_ones):
assert isinstance(result, FactoredMatrix)
assert torch.allclose(result.A, unsqueezed_A)
assert torch.allclose(result.B, unsqueezed_B)

def test_eigenvalues_bfloat16_support(self, factored_matrices_bf16):
"""
Test that the eigenvalues calculation does not crash for bfloat16 matrices.
"""
for factored_matrix in factored_matrices_bf16:
if factored_matrix.ldim == factored_matrix.rdim:
eigenvalues = factored_matrix.eigenvalues

assert eigenvalues.dtype == torch.complex64

expected_eigenvalues = torch.linalg.eig(
factored_matrix.BA.to(torch.float32)
).eigenvalues

assert torch.allclose(
torch.abs(eigenvalues), torch.abs(expected_eigenvalues), atol=1e-2, rtol=1e-2
)
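The `float32` upcast in the expected-value computation reflects the working assumption this test encodes: `torch.linalg.eig` does not accept low-precision dtypes such as bfloat16, so the product is upcast before decomposition. A minimal standalone sketch of that workaround, mirroring the fixture shapes:

```python
import torch

A = torch.randn(3, 2).to(torch.bfloat16)
B = torch.randn(2, 3).to(torch.bfloat16)
BA = B @ A  # 2x2 product; bf16 matmul itself is supported

# Upcast to float32 before eig; float32 input yields complex64 eigenvalues.
eigenvalues = torch.linalg.eig(BA.to(torch.float32)).eigenvalues
```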