Merge pull request #206 from naveenarun/naveenarun-205-black-v24
merging to integration branch
mivanit authored Jan 28, 2024
2 parents d9a849c + 08f60a1 commit 2a42383
Showing 11 changed files with 57 additions and 55 deletions.
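
The diff below is a mechanical reformat from upgrading black to v24: `...`-only stub bodies are collapsed onto the `def` line, long annotated assignments and conditional expressions get their right-hand side wrapped in parentheses instead of the annotation being split, and a blank line is inserted after module docstrings. The sketch below is illustrative only — every name in it (`ExampleProtocol`, `compute_values`, `lookup_old`, `lookup_new`) is hypothetical and not from the repository, and real black output also depends on line length:

# Sketch of the two main black v24 patterns seen in this diff; illustrative only.
import typing


class ExampleProtocol(typing.Protocol):
    # black < 24 kept the `...` stub body on its own line:
    def old_style(self, xs: list[str]) -> float:
        ...

    # black >= 24 collapses `...`-only bodies onto the `def` line:
    def new_style(self, xs: list[str]) -> float: ...


def compute_values(xs: list[str]) -> dict[str, float]:
    # toy helper so the assignments below have something to call
    return {x: float(len(x)) for x in xs}


xs = ["a", "bb", "ccc"]

# black < 24 split an over-long annotated assignment inside the annotation:
lookup_old: dict[
    str, float
] = compute_values(
    xs,
)

# black >= 24 keeps the annotation on one line and parenthesizes the right-hand side:
lookup_new: dict[str, float] = (
    compute_values(
        xs,
    )
)

The per-file hunks that follow all appear to be instances of these patterns, plus the blank line added after the module docstrings in the test files.
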
3 changes: 1 addition & 2 deletions maze_transformer/evaluation/path_evals.py
@@ -23,8 +23,7 @@ def __call__(
         maze: LatticeMaze | None = None,
         solution: CoordArray | None = None,
         prediction: CoordArray | None = None,
-    ) -> float:
-        ...
+    ) -> float: ...


 def path_as_segments_iter(path: CoordArray) -> typing.Iterable[tuple]:
19 changes: 10 additions & 9 deletions maze_transformer/mechinterp/direct_logit_attribution.py
@@ -126,22 +126,23 @@ def plot_direct_logit_attribution(
     answer_tokens: Int[torch.Tensor, "n_mazes"],
     do_neurons: bool = False,
     show: bool = True,
-    layer_index_normalization: typing.Callable[[float, int], float]
-    | None = lambda contrib, layer_idx: contrib,
+    layer_index_normalization: (
+        typing.Callable[[float, int], float] | None
+    ) = lambda contrib, layer_idx: contrib,
 ) -> tuple[plt.Figure, plt.Axes, dict[str, Float[np.ndarray, "layer head/neuron"]]]:
     """compute, process, and plot direct logit attribution
     Layer index normalization allows us to process the contribution according to the layer index.
     by default, its the identity map for contribs:
     `layer_index_normalization: typing.Callable[[float, int], float]|None = lambda contrib, layer_idx: contrib`
     """
-    dla_data: dict[
-        str, Float[np.ndarray, "layer head/neuron"]
-    ] = compute_direct_logit_attribution(
-        model=model,
-        cache=cache,
-        answer_tokens=answer_tokens,
-        do_neurons=do_neurons,
+    dla_data: dict[str, Float[np.ndarray, "layer head/neuron"]] = (
+        compute_direct_logit_attribution(
+            model=model,
+            cache=cache,
+            answer_tokens=answer_tokens,
+            do_neurons=do_neurons,
+        )
     )
     if layer_index_normalization is not None:
         dla_data = {
6 changes: 2 additions & 4 deletions maze_transformer/mechinterp/logit_attrib_task.py
@@ -22,8 +22,7 @@ def get_token_first_index(search_token: str, token_list: list[str]) -> int:
 class DLAProtocol(typing.Protocol):
     """should take a dataset's tokens, and return a tuple of (prompts, targets)"""

-    def __call__(self, dataset_tokens: list[list[str]], **kwargs) -> TaskSetup:
-        ...
+    def __call__(self, dataset_tokens: list[list[str]], **kwargs) -> TaskSetup: ...


 class DLAProtocolFixed(typing.Protocol):
@@ -32,8 +31,7 @@ class DLAProtocolFixed(typing.Protocol):
     this variant signifies it's ready to be used -- no keyword arguments are needed
     """

-    def __call__(self, dataset_tokens: list[list[str]]) -> TaskSetup:
-        ...
+    def __call__(self, dataset_tokens: list[list[str]]) -> TaskSetup: ...


 def token_after_fixed_start_token(
18 changes: 9 additions & 9 deletions maze_transformer/mechinterp/logit_diff.py
@@ -86,9 +86,9 @@ def logit_diff_residual_stream(
     vocab_tensor: Float[torch.Tensor, "d_vocab"] = torch.arange(
         d_vocab, dtype=torch.long
     )
-    vocab_residual_directions: Float[
-        torch.Tensor, "d_vocab d_model"
-    ] = model.tokens_to_residual_directions(vocab_tensor)
+    vocab_residual_directions: Float[torch.Tensor, "d_vocab d_model"] = (
+        model.tokens_to_residual_directions(vocab_tensor)
+    )
     # get embedding of answer tokens
     answer_residual_directions = vocab_residual_directions[tokens_correct]
     # get the directional difference between logits and corrent and logits on {all other tokens, comparison tokens}
@@ -108,12 +108,12 @@ def logit_diff_residual_stream(
     ][:, -1, :]

     # scaling the values in residual stream with layer norm
-    scaled_final_token_residual_stream: Float[
-        torch.Tensor, "samples d_model"
-    ] = cache.apply_ln_to_stack(
-        final_token_residual_stream,
-        layer=-1,
-        pos_slice=-1,
+    scaled_final_token_residual_stream: Float[torch.Tensor, "samples d_model"] = (
+        cache.apply_ln_to_stack(
+            final_token_residual_stream,
+            layer=-1,
+            pos_slice=-1,
+        )
     )

     # measure similarity between the logit diff directions and the residual stream at final layer directions
6 changes: 3 additions & 3 deletions maze_transformer/mechinterp/plot_attention.py
@@ -289,9 +289,9 @@ def mazeplot_attention(
         node_values=node_values,
         color_map=cmap,
         target_token_coord=target_coord,
-        preceeding_tokens_coords=[final_prompt_coord]
-        if final_prompt_coord is not None
-        else None,
+        preceeding_tokens_coords=(
+            [final_prompt_coord] if final_prompt_coord is not None else None
+        ),
         colormap_center=colormap_center_val,
         colormap_max=colormap_max,
         hide_colorbar=hide_colorbar,
14 changes: 8 additions & 6 deletions maze_transformer/mechinterp/residual_stream_structure.py
@@ -68,9 +68,11 @@ def process_tokens_for_pca(tokenizer: MazeTokenizer) -> list[TokenPlottingInfo]:
         tokenizer.token_arr,
         tokens_coords,
         [
-            coordinate_to_color(coord, max_val=max_coord)
-            if isinstance(coord, tuple)
-            else (0.0, 1.0, 0.0)
+            (
+                coordinate_to_color(coord, max_val=max_coord)
+                if isinstance(coord, tuple)
+                else (0.0, 1.0, 0.0)
+            )
             for coord in tokens_coords
         ],
     )
@@ -249,9 +251,9 @@ def compute_distances_and_correlation(
     # embedding_distances /= embedding_distances.max()

     # Convert the distances to a square matrix
-    embedding_distances_matrix: Float[
-        np.ndarray, "n_coord_tokens n_coord_tokens"
-    ] = squareform(embedding_distances)
+    embedding_distances_matrix: Float[np.ndarray, "n_coord_tokens n_coord_tokens"] = (
+        squareform(embedding_distances)
+    )

     # Calculate the correlation between the embedding and coordinate distances
     coordinate_coordinates: Float[np.ndarray, "n_coord_tokens 2"] = np.array(
27 changes: 14 additions & 13 deletions maze_transformer/training/config.py
@@ -214,7 +214,9 @@ def get_intervals(
             )

         except ValueError as e:
-            _debug_vals: str = f"{dataset_n_samples=}, {use_defaults_if_missing=}, {mod_batch_size=},\n{self.intervals=},\n{self.intervals_count=}"
+            _debug_vals: str = (
+                f"{dataset_n_samples=}, {use_defaults_if_missing=}, {mod_batch_size=},\n{self.intervals=},\n{self.intervals_count=}"
+            )
             raise ValueError(f"{_debug_vals}\ntriggered error:\n{e}") from e

         # disable if set to 0 or negative
@@ -235,9 +237,9 @@ def get_intervals(
         # actually return the intervals
         if mod_batch_size:
             return {
-                k: max(1, v // self.batch_size)
-                if isinstance(v, int)
-                else v  # if float, leave it as is since its float("inf")
+                k: (
+                    max(1, v // self.batch_size) if isinstance(v, int) else v
+                )  # if float, leave it as is since its float("inf")
                 for k, v in intervals_new.items()
             }
         else:
@@ -459,9 +461,11 @@ def summary(self) -> str:
             "model_cfg": self.model_cfg.summary(),
             "train_cfg": self.train_cfg.summary(),
             "pretrainedtokenizer_kwargs": self.pretrainedtokenizer_kwargs,
-            "maze_tokenizer": self.maze_tokenizer.summary()
-            if self.maze_tokenizer is not None
-            else None,
+            "maze_tokenizer": (
+                self.maze_tokenizer.summary()
+                if self.maze_tokenizer is not None
+                else None
+            ),
         }

     @property
@@ -655,12 +659,9 @@ def _load_state_dict_wrapper(
             self.zanj_model_config.model_cfg.weight_processing["are_layernorms_folded"]
             or fold_ln
         )
-        self.zanj_model_config.model_cfg.weight_processing[
-            "are_weights_processed"
-        ] = self.zanj_model_config.model_cfg.weight_processing[
-            "are_weights_processed"
-        ] or (
-            not recover_exact
+        self.zanj_model_config.model_cfg.weight_processing["are_weights_processed"] = (
+            self.zanj_model_config.model_cfg.weight_processing["are_weights_processed"]
+            or (not recover_exact)
         )

         self.load_and_process_state_dict(
10 changes: 4 additions & 6 deletions maze_transformer/training/train_save_files.py
@@ -20,12 +20,10 @@ class TRAIN_SAVE_FILES:
     config_holder: str = "config.json"
     checkpoints: str = "checkpoints"
     log: str = "log.jsonl"
-    model_checkpt_zanj: Callable[
-        [int], str
-    ] = lambda iteration: f"model.iter_{iteration}.zanj"
+    model_checkpt_zanj: Callable[[int], str] = (
+        lambda iteration: f"model.iter_{iteration}.zanj"
+    )
     model_final_zanj: str = "model.final.zanj"
-    model_run_dir: Callable[
-        [ConfigHolder], str
-    ] = (
+    model_run_dir: Callable[[ConfigHolder], str] = (
         lambda cfg: f"{sanitize_fname(cfg.name)}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
     )
1 change: 1 addition & 0 deletions tests/integration/test_eval_model.py
@@ -5,6 +5,7 @@
 a HookedTransformer with folding etc., as they would be from
 just applying the model to the input
 """
+
 import warnings
 from pathlib import Path

7 changes: 4 additions & 3 deletions tests/unit/maze_transformer/test_tokenizers.py
@@ -4,6 +4,7 @@
 We may want a separate set of tests for different tokenization schemes
 """
+
 from itertools import product

 import torch
@@ -81,11 +82,11 @@ def test_tokenization_encoding(
 )
 def test_to_ascii(tok_mode):
     # Check that the ascii encoding works for multiple different inputs
-    maze_str_tokens: list[
-        str
-    ] = """<ADJLIST_START> (1,1) <--> (2,1) ; (2,0) <--> (1,0) ; (0,1) <--> (0,0) ;
+    maze_str_tokens: list[str] = (
+        """<ADJLIST_START> (1,1) <--> (2,1) ; (2,0) <--> (1,0) ; (0,1) <--> (0,0) ;
 (2,2) <--> (2,1) ; (2,0) <--> (2,1) ; (0,2) <--> (1,2) ; (0,0) <--> (1,0) ; (0,2) <--> (0,1) ;
 <ADJLIST_END> <ORIGIN_START> (0,0) <ORIGIN_END> <TARGET_START> (2,1) <TARGET_END> <PATH_START> (0,0) (1,0) (2,0) (2,1) <PATH_END>""".split()
+    )

     target: list[str] = [
         "#######",
@@ -1,6 +1,7 @@
 """
 test loading of old style models
 """
+
 import json

 import pytest
