4,401 changes: 2,523 additions & 1,878 deletions poetry.lock

Large diffs are not rendered by default.

13 changes: 5 additions & 8 deletions pyproject.toml
@@ -26,16 +26,17 @@
{version=">=1.26,<2", python=">=3.12,<3.13"},
]
pandas=">=1.1.5"
protobuf="3.20.1"
python=">=3.8,<4.0"
rich=">=12.6.0"
sentencepiece="*"
torch=">=2.2"
torch=[{version="<2.6", python=">=3.8,<3.9"}, {version=">=2.6", python=">=3.9"}]
tqdm=">=4.64.1"
transformers=">=4.43"
transformers-stream-generator="^0.0.5"
typeguard="^4.2"
typing-extensions="*"
wandb=">=0.13.5"
typeguard = "^4.2"
transformers-stream-generator = "^0.0.5"

[tool.poetry.group]
[tool.poetry.group.dev.dependencies]
@@ -128,10 +129,7 @@
# All rules apart from base are shown explicitly below
deprecateTypingAliases=true
disableBytesTypePromotions=true
exclude = [
"*/**/*.py",
"!/transformer_lens/hook_points.py"
]
exclude=["!/transformer_lens/hook_points.py", "*/**/*.py"]
reportAssertAlwaysTrue=true
reportConstantRedefinition=true
reportDeprecated=true
@@ -187,4 +185,3 @@
strictListInference=true
strictParameterNoneValue=true
strictSetInference=true

103 changes: 71 additions & 32 deletions transformer_lens/HookedEncoder.py
@@ -8,18 +8,20 @@

import logging
import os
from typing import Dict, List, Optional, Tuple, TypeVar, Union, cast, overload
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, overload

import torch
import torch.nn as nn
from einops import repeat
from jaxtyping import Float, Int
from torch import nn
from transformers import AutoTokenizer
from transformers.models.auto.tokenization_auto import AutoTokenizer
from typing_extensions import Literal

import transformer_lens.loading_from_pretrained as loading
from transformer_lens.ActivationCache import ActivationCache
from transformer_lens.components import (
MLP,
Attention,
BertBlock,
BertEmbed,
BertMLMHead,
@@ -46,7 +48,13 @@ class HookedEncoder(HookedRootModule):
- There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model
"""

def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs):
def __init__(
self,
cfg: Union[HookedTransformerConfig, Dict],
tokenizer: Optional[Any] = None,
move_to_device: bool = True,
**kwargs: Any,
):
super().__init__()
if isinstance(cfg, Dict):
cfg = HookedTransformerConfig(**cfg)
@@ -85,6 +93,8 @@ def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs):
self.hook_full_embed = HookPoint()

if move_to_device:
if self.cfg.device is None:
raise ValueError("Cannot move to device when device is None")
self.to(self.cfg.device)

self.setup()
@@ -121,11 +131,13 @@ def to_tokens(
)

tokens = encodings.input_ids
token_type_ids = encodings.token_type_ids
attention_mask = encodings.attention_mask

if move_to_device:
tokens = tokens.to(self.cfg.device)
token_type_ids = encodings.token_type_ids.to(self.cfg.device)
attention_mask = encodings.attention_mask.to(self.cfg.device)
token_type_ids = token_type_ids.to(self.cfg.device)
attention_mask = attention_mask.to(self.cfg.device)

return tokens, token_type_ids, attention_mask
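
This hunk makes `to_tokens` move all three returned tensors, not just the input ids. A minimal usage sketch, assuming the truncated signature accepts a string or list of strings and that `move_to_device` defaults to `True`; the checkpoint name is illustrative, not part of the PR:

```python
# Hedged sketch: exercises the (tokens, token_type_ids, attention_mask) tuple
# returned by HookedEncoder.to_tokens after this change.
from transformer_lens import HookedEncoder

model = HookedEncoder.from_pretrained("bert-base-cased")  # illustrative checkpoint
tokens, token_type_ids, attention_mask = model.to_tokens(["The cat sat on the mat."])
# With move_to_device left at its default, all three tensors share model.cfg.device.
print(tokens.shape, token_type_ids.device, attention_mask.device)
```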

@@ -188,7 +200,7 @@ def forward(
return_type: Union[Literal["logits"], Literal["predictions"]],
token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
) -> Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str],]:
) -> Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]:
...

@overload
@@ -202,7 +214,7 @@ def forward(
return_type: Literal[None],
token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
) -> Optional[Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str],]]:
) -> Optional[Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]]:
...

def forward(
@@ -215,7 +227,7 @@ def forward(
return_type: Optional[Union[Literal["logits"], Literal["predictions"]]] = "logits",
token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
) -> Optional[Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str],]]:
) -> Optional[Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]]:
"""Forward pass through the HookedEncoder. Performs Masked Language Modelling on the given input.

Args:
@@ -277,6 +289,9 @@ def forward(
logits = self.unembed(resid)

if return_type == "predictions":
assert (
self.tokenizer is not None
), "Must have a tokenizer to use return_type='predictions'"
# Get predictions for masked tokens
logprobs = logits[tokens == self.tokenizer.mask_token_id].log_softmax(dim=-1)
predictions = self.tokenizer.decode(logprobs.argmax(dim=-1))
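
The new assertion makes the tokenizer requirement explicit before the `[MASK]` positions are decoded. A hedged sketch of the `return_type="predictions"` path, assuming the forward pass accepts a raw string (as the `str`/`List[str]` return annotations above suggest); the prompt and checkpoint are illustrative:

```python
# Sketch only: return_type="predictions" decodes the most likely token at each [MASK] position.
from transformer_lens import HookedEncoder

model = HookedEncoder.from_pretrained("bert-base-cased")  # ships with a tokenizer
# With tokenizer=None, return_type="predictions" now fails the assertion above
# with a clear message instead of an opaque error.
predictions = model("The capital of France is [MASK].", return_type="predictions")
print(predictions)
```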
@@ -295,22 +310,22 @@ def forward(

@overload
def run_with_cache(
self, *model_args, return_cache_object: Literal[True] = True, **kwargs
) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache,]:
self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any
) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
...

@overload
def run_with_cache(
self, *model_args, return_cache_object: Literal[False], **kwargs
) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor],]:
self, *model_args: Any, return_cache_object: Literal[False], **kwargs: Any
) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
...

def run_with_cache(
self,
*model_args,
*model_args: Any,
return_cache_object: bool = True,
remove_batch_dim: bool = False,
**kwargs,
**kwargs: Any,
) -> Tuple[
Float[torch.Tensor, "batch pos d_vocab"],
Union[ActivationCache, Dict[str, torch.Tensor]],
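
The `run_with_cache` overloads now annotate `*model_args` and `**kwargs` as `Any` without changing behaviour. A usage sketch, assuming string inputs are accepted by the forward pass as in the sketch above; the checkpoint name is illustrative:

```python
# Sketch: return_cache_object toggles between ActivationCache and a plain dict,
# per the overloads above.
from transformer_lens import HookedEncoder

model = HookedEncoder.from_pretrained("bert-base-cased")
logits, cache = model.run_with_cache("The [MASK] sat on the mat.")
print(logits.shape)      # [batch, pos, d_vocab]
print(type(cache))       # ActivationCache
_, cache_dict = model.run_with_cache(
    "The [MASK] sat on the mat.", return_cache_object=False
)
print(type(cache_dict))  # plain dict of hook name -> tensor
```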
@@ -354,12 +369,12 @@ def from_pretrained(
model_name: str,
checkpoint_index: Optional[int] = None,
checkpoint_value: Optional[int] = None,
hf_model=None,
hf_model: Optional[Any] = None,
device: Optional[str] = None,
tokenizer=None,
move_to_device=True,
dtype=torch.float32,
**from_pretrained_kwargs,
tokenizer: Optional[Any] = None,
move_to_device: bool = True,
dtype: torch.dtype = torch.float32,
**from_pretrained_kwargs: Any,
) -> HookedEncoder:
"""Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model."""
logging.warning(
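
`from_pretrained` keeps the same defaults but now spells out the types of `hf_model`, `tokenizer`, `move_to_device`, and `dtype`. A hedged loading sketch matching the signature above; the checkpoint name is an assumption, not taken from the PR:

```python
# Sketch: arguments mirror the typed signature above; bert-base-cased is an assumed checkpoint.
import torch
from transformer_lens import HookedEncoder

model = HookedEncoder.from_pretrained(
    "bert-base-cased",
    device="cpu",           # Optional[str]
    dtype=torch.float32,    # now annotated as torch.dtype
    move_to_device=True,    # now annotated as bool
)
```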
Expand Down Expand Up @@ -447,62 +462,86 @@ def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
@property
def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
"""Stacks the key weights across all layers"""
return torch.stack([cast(BertBlock, block).attn.W_K for block in self.blocks], dim=0)
for block in self.blocks:
assert isinstance(block.attn, Attention)
return torch.stack([block.attn.W_K for block in self.blocks], dim=0)

@property
def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
"""Stacks the query weights across all layers"""
return torch.stack([cast(BertBlock, block).attn.W_Q for block in self.blocks], dim=0)
for block in self.blocks:
assert isinstance(block.attn, Attention)
return torch.stack([block.attn.W_Q for block in self.blocks], dim=0)

@property
def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
"""Stacks the value weights across all layers"""
return torch.stack([cast(BertBlock, block).attn.W_V for block in self.blocks], dim=0)
for block in self.blocks:
assert isinstance(block.attn, Attention)
return torch.stack([block.attn.W_V for block in self.blocks], dim=0)

@property
def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
"""Stacks the attn output weights across all layers"""
return torch.stack([cast(BertBlock, block).attn.W_O for block in self.blocks], dim=0)
for block in self.blocks:
assert isinstance(block.attn, Attention)
return torch.stack([block.attn.W_O for block in self.blocks], dim=0)

@property
def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
"""Stacks the MLP input weights across all layers"""
return torch.stack([cast(BertBlock, block).mlp.W_in for block in self.blocks], dim=0)
for block in self.blocks:
assert isinstance(block.mlp, MLP)
return torch.stack([block.mlp.W_in for block in self.blocks], dim=0)

@property
def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
"""Stacks the MLP output weights across all layers"""
return torch.stack([cast(BertBlock, block).mlp.W_out for block in self.blocks], dim=0)
for block in self.blocks:
assert isinstance(block.mlp, MLP)
return torch.stack([block.mlp.W_out for block in self.blocks], dim=0)

@property
def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
"""Stacks the key biases across all layers"""
return torch.stack([cast(BertBlock, block).attn.b_K for block in self.blocks], dim=0)
for block in self.blocks:
assert isinstance(block.attn, Attention)
return torch.stack([block.attn.b_K for block in self.blocks], dim=0)

@property
def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
"""Stacks the query biases across all layers"""
return torch.stack([cast(BertBlock, block).attn.b_Q for block in self.blocks], dim=0)
for block in self.blocks:
assert isinstance(block.attn, Attention)
return torch.stack([block.attn.b_Q for block in self.blocks], dim=0)

@property
def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
"""Stacks the value biases across all layers"""
return torch.stack([cast(BertBlock, block).attn.b_V for block in self.blocks], dim=0)
for block in self.blocks:
assert isinstance(block.attn, Attention)
return torch.stack([block.attn.b_V for block in self.blocks], dim=0)

@property
def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
"""Stacks the attn output biases across all layers"""
return torch.stack([cast(BertBlock, block).attn.b_O for block in self.blocks], dim=0)
for block in self.blocks:
assert isinstance(block.attn, Attention)
return torch.stack([block.attn.b_O for block in self.blocks], dim=0)

@property
def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
"""Stacks the MLP input biases across all layers"""
return torch.stack([cast(BertBlock, block).mlp.b_in for block in self.blocks], dim=0)
for block in self.blocks:
assert isinstance(block.mlp, MLP)
return torch.stack([block.mlp.b_in for block in self.blocks], dim=0)

@property
def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
"""Stacks the MLP output biases across all layers"""
return torch.stack([cast(BertBlock, block).mlp.b_out for block in self.blocks], dim=0)
for block in self.blocks:
assert isinstance(block.mlp, MLP)
return torch.stack([block.mlp.b_out for block in self.blocks], dim=0)

@property
def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model]
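
The stacked-weight properties swap `cast(BertBlock, block)` for explicit `isinstance` assertions, so a mistyped block now fails loudly instead of being silently cast; the returned shapes are unchanged. A hedged sanity-check sketch based on the `Float[...]` annotations above, with an illustrative checkpoint:

```python
# Sketch: the shape checks restate the property annotations above; they are not new behaviour.
from transformer_lens import HookedEncoder

model = HookedEncoder.from_pretrained("bert-base-cased")
cfg = model.cfg
assert model.W_Q.shape == (cfg.n_layers, cfg.n_heads, cfg.d_model, cfg.d_head)
assert model.W_O.shape == (cfg.n_layers, cfg.n_heads, cfg.d_head, cfg.d_model)
assert model.W_in.shape == (cfg.n_layers, cfg.d_model, cfg.d_mlp)
assert model.b_out.shape == (cfg.n_layers, cfg.d_model)
```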