40 changes: 27 additions & 13 deletions transformer_lens/patching.py
@@ -166,11 +166,16 @@ def generic_activation_patch(
if index_df is None:
assert index_axis_names is not None

number_of_heads = model.cfg.n_heads
# For some models, the number of key value heads is not the same as the number of attention heads
if activation_name in ["k", "v"] and model.cfg.n_key_value_heads is not None:
number_of_heads = model.cfg.n_key_value_heads

# Get the max range for all possible axes
max_axis_range = {
"layer": model.cfg.n_layers,
"pos": corrupted_tokens.shape[-1],
"head_index": model.cfg.n_heads,
"head_index": number_of_heads,
}
max_axis_range["src_pos"] = max_axis_range["pos"]
max_axis_range["dest_pos"] = max_axis_range["pos"]
@@ -466,7 +471,7 @@ def layer_head_dest_src_pos_pattern_patch_setter(
index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_k_by_pos.__doc__ = """
Function to get activation patching results for the keys of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]
Function to get activation patching results for the keys of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.

See generic_activation_patch for a more detailed explanation of activation patching

@@ -477,7 +482,7 @@ def layer_head_dest_src_pos_pattern_patch_setter(
patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

Returns:
patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.
"""

get_act_patch_attn_head_v_by_pos = partial(
@@ -487,7 +492,7 @@ def layer_head_dest_src_pos_pattern_patch_setter(
index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_v_by_pos.__doc__ = """
Function to get activation patching results for the values of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]
Function to get activation patching results for the values of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.

See generic_activation_patch for a more detailed explanation of activation patching

@@ -498,7 +503,7 @@ def layer_head_dest_src_pos_pattern_patch_setter(
patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

Returns:
patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.
"""
# %%
get_act_patch_attn_head_pattern_by_pos = partial(
@@ -593,7 +598,7 @@ def layer_head_dest_src_pos_pattern_patch_setter(
index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_k_all_pos.__doc__ = """
Function to get activation patching results for the keys of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]
Function to get activation patching results for the keys of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.

See generic_activation_patch for a more detailed explanation of activation patching

@@ -604,7 +609,7 @@ def layer_head_dest_src_pos_pattern_patch_setter(
patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

Returns:
patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.
"""

get_act_patch_attn_head_v_all_pos = partial(
@@ -614,7 +619,7 @@ def layer_head_dest_src_pos_pattern_patch_setter(
index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_v_all_pos.__doc__ = """
Function to get activation patching results for the values of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]
Function to get activation patching results for the values of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.

See generic_activation_patch for a more detailed explanation of activation patching

@@ -625,7 +630,7 @@ def layer_head_dest_src_pos_pattern_patch_setter(
patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

Returns:
patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.
"""

get_act_patch_attn_head_pattern_all_pos = partial(
@@ -673,12 +678,17 @@ def get_act_patch_attn_head_all_pos_every(
act_patch_results.append(
get_act_patch_attn_head_q_all_pos(model, corrupted_tokens, clean_cache, metric)
)

# Pad the k and v results along the head axis so they match the other results when n_key_value_heads != n_heads
k_results = get_act_patch_attn_head_k_all_pos(model, corrupted_tokens, clean_cache, metric)
act_patch_results.append(
get_act_patch_attn_head_k_all_pos(model, corrupted_tokens, clean_cache, metric)
torch.nn.functional.pad(k_results, (0, act_patch_results[-1].size(-1) - k_results.size(-1)))
)
v_results = get_act_patch_attn_head_v_all_pos(model, corrupted_tokens, clean_cache, metric)
act_patch_results.append(
get_act_patch_attn_head_v_all_pos(model, corrupted_tokens, clean_cache, metric)
torch.nn.functional.pad(v_results, (0, act_patch_results[-1].size(-1) - v_results.size(-1)))
)

act_patch_results.append(
get_act_patch_attn_head_pattern_all_pos(model, corrupted_tokens, clean_cache, metric)
)
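
A minimal sketch of the padding step above, with hypothetical result tensors; because the k/v grids have only `n_key_value_heads` columns, they are right-padded with zeros along the head axis so all per-activation results can be stacked into a single tensor (the padded columns carry no signal):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes for a grouped-query attention model
n_layers, n_heads, n_kv_heads = 4, 16, 4
q_results = torch.randn(n_layers, n_heads)     # [n_layers, n_heads]
k_results = torch.randn(n_layers, n_kv_heads)  # [n_layers, n_key_value_heads]

# Right-pad the last (head) axis with zeros to match the q results
k_padded = F.pad(k_results, (0, q_results.size(-1) - k_results.size(-1)))
assert k_padded.shape == q_results.shape

# The per-activation results can now be stacked along a new leading axis
stacked = torch.stack([q_results, k_padded], dim=0)  # [2, n_layers, n_heads]
print(stacked.shape)
```
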
@@ -706,11 +716,15 @@ def get_act_patch_attn_head_by_pos_every(
act_patch_results.append(
get_act_patch_attn_head_q_by_pos(model, corrupted_tokens, clean_cache, metric)
)

# Pad the k and v results along the head axis so they match the other results when n_key_value_heads != n_heads
k_results = get_act_patch_attn_head_k_by_pos(model, corrupted_tokens, clean_cache, metric)
act_patch_results.append(
get_act_patch_attn_head_k_by_pos(model, corrupted_tokens, clean_cache, metric)
torch.nn.functional.pad(k_results, (0, act_patch_results[-1].size(-1) - k_results.size(-1)))
)
v_results = get_act_patch_attn_head_v_by_pos(model, corrupted_tokens, clean_cache, metric)
act_patch_results.append(
get_act_patch_attn_head_v_by_pos(model, corrupted_tokens, clean_cache, metric)
torch.nn.functional.pad(v_results, (0, act_patch_results[-1].size(-1) - v_results.size(-1)))
)

# Reshape pattern to be compatible with the rest of the results