
Commit

Merge branch 'misc-code-cleanup' of https://github.com/understanding-search/maze-transformer into misc-code-cleanup
mivanit committed Jan 28, 2024
2 parents 8d3916a + b799a92 commit 4f82556
Showing 16 changed files with 2,687 additions and 335 deletions.
14 changes: 14 additions & 0 deletions README.md
@@ -68,6 +68,20 @@ Most of the functionality is demonstrated in the ipython notebooks in the `noteb
* Restart VSCode
* In VSCode, select the python interpreter located in `maze-transformer/.venv/bin` as your jupyter kernel

## Instructions for Conda users

* Create a new Conda environment: `conda create -n mazetransformer python=3.10 poetry`
* Activate the environment: `conda activate mazetransformer`
* Update poetry and install dev dependencies
```
poetry self update
poetry config virtualenvs.in-project true
poetry install --with dev
```
* Run unit, integration, and notebook tests
```
make test
```
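* Optionally, verify the install with a quick smoke test (a suggested check, not part of the documented workflow):
```python
# Minimal smoke test for the environment set up above; run inside the
# activated conda/poetry environment.
import maze_dataset       # core dependency used by the training code
import maze_transformer   # the package installed by `poetry install`

print("maze_transformer imported from:", maze_transformer.__file__)
print("maze_dataset imported from:", maze_dataset.__file__)
```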

## Testing & Static analysis

3 changes: 1 addition & 2 deletions maze_transformer/evaluation/path_evals.py
@@ -23,8 +23,7 @@ def __call__(
maze: LatticeMaze | None = None,
solution: CoordArray | None = None,
prediction: CoordArray | None = None,
) -> float:
...
) -> float: ...


def path_as_segments_iter(path: CoordArray) -> typing.Iterable[tuple]:
19 changes: 10 additions & 9 deletions maze_transformer/mechinterp/direct_logit_attribution.py
@@ -126,22 +126,23 @@ def plot_direct_logit_attribution(
answer_tokens: Int[torch.Tensor, "n_mazes"],
do_neurons: bool = False,
show: bool = True,
layer_index_normalization: typing.Callable[[float, int], float]
| None = lambda contrib, layer_idx: contrib,
layer_index_normalization: (
typing.Callable[[float, int], float] | None
) = lambda contrib, layer_idx: contrib,
) -> tuple[plt.Figure, plt.Axes, dict[str, Float[np.ndarray, "layer head/neuron"]]]:
"""compute, process, and plot direct logit attribution
Layer index normalization allows us to process the contribution according to the layer index.
by default, it's the identity map for contribs:
`layer_index_normalization: typing.Callable[[float, int], float]|None = lambda contrib, layer_idx: contrib`
"""
dla_data: dict[
str, Float[np.ndarray, "layer head/neuron"]
] = compute_direct_logit_attribution(
model=model,
cache=cache,
answer_tokens=answer_tokens,
do_neurons=do_neurons,
dla_data: dict[str, Float[np.ndarray, "layer head/neuron"]] = (
compute_direct_logit_attribution(
model=model,
cache=cache,
answer_tokens=answer_tokens,
do_neurons=do_neurons,
)
)
if layer_index_normalization is not None:
dla_data = {
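For illustration, a hedged sketch of how the reformatted `layer_index_normalization` argument might be used with a non-default value; the depth scaling is an invented choice, and the `model`, `cache`, and `answer_tokens` parameters are assumed to match the ones in this module:

```python
from maze_transformer.mechinterp.direct_logit_attribution import (
    plot_direct_logit_attribution,
)

def plot_dla_depth_scaled(model, cache, answer_tokens):
    """Sketch only: same call as in this hunk, but rescaling each contribution
    by layer depth instead of the default identity normalization (an invented
    choice, shown only to illustrate the argument)."""
    return plot_direct_logit_attribution(
        model=model,
        cache=cache,
        answer_tokens=answer_tokens,
        do_neurons=False,
        # divide each contribution by (layer_idx + 1) -- illustrative only
        layer_index_normalization=lambda contrib, layer_idx: contrib / (layer_idx + 1),
    )
```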
6 changes: 2 additions & 4 deletions maze_transformer/mechinterp/logit_attrib_task.py
@@ -22,8 +22,7 @@ def get_token_first_index(search_token: str, token_list: list[str]) -> int:
class DLAProtocol(typing.Protocol):
"""should take a dataset's tokens, and return a tuple of (prompts, targets)"""

def __call__(self, dataset_tokens: list[list[str]], **kwargs) -> TaskSetup:
...
def __call__(self, dataset_tokens: list[list[str]], **kwargs) -> TaskSetup: ...


class DLAProtocolFixed(typing.Protocol):
@@ -32,8 +31,7 @@ class DLAProtocolFixed(typing.Protocol):
this variant signifies it's ready to be used -- no keyword arguments are needed
"""

def __call__(self, dataset_tokens: list[list[str]]) -> TaskSetup:
...
def __call__(self, dataset_tokens: list[list[str]]) -> TaskSetup: ...


def token_after_fixed_start_token(
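A hedged sketch of a callable satisfying `DLAProtocol` as described in the docstring above; the `TaskSetup` defined here is a stand-in assumed to be a (prompts, targets) pair, not the module's real class:

```python
import typing

# Stand-in for this module's TaskSetup -- assumed here to be a
# (prompts, targets) pair, per the DLAProtocol docstring above.
class TaskSetup(typing.NamedTuple):
    prompts: list[list[str]]
    targets: list[str]

def last_token_task(dataset_tokens: list[list[str]], **kwargs) -> TaskSetup:
    """Illustrative DLAProtocol implementation: predict each sequence's final token."""
    return TaskSetup(
        prompts=[tokens[:-1] for tokens in dataset_tokens],
        targets=[tokens[-1] for tokens in dataset_tokens],
    )

# tiny self-check with made-up tokens
setup = last_token_task([["a", "b", "c"], ["x", "y", "z"]])
assert setup.prompts == [["a", "b"], ["x", "y"]]
assert setup.targets == ["c", "z"]
```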
18 changes: 9 additions & 9 deletions maze_transformer/mechinterp/logit_diff.py
@@ -86,9 +86,9 @@ def logit_diff_residual_stream(
vocab_tensor: Float[torch.Tensor, "d_vocab"] = torch.arange(
d_vocab, dtype=torch.long
)
vocab_residual_directions: Float[
torch.Tensor, "d_vocab d_model"
] = model.tokens_to_residual_directions(vocab_tensor)
vocab_residual_directions: Float[torch.Tensor, "d_vocab d_model"] = (
model.tokens_to_residual_directions(vocab_tensor)
)
# get embedding of answer tokens
answer_residual_directions = vocab_residual_directions[tokens_correct]
# get the directional difference between logits on correct tokens and logits on {all other tokens, comparison tokens}
@@ -108,12 +108,12 @@
][:, -1, :]

# scaling the values in residual stream with layer norm
scaled_final_token_residual_stream: Float[
torch.Tensor, "samples d_model"
] = cache.apply_ln_to_stack(
final_token_residual_stream,
layer=-1,
pos_slice=-1,
scaled_final_token_residual_stream: Float[torch.Tensor, "samples d_model"] = (
cache.apply_ln_to_stack(
final_token_residual_stream,
layer=-1,
pos_slice=-1,
)
)

# measure similarity between the logit diff directions and the residual stream at final layer directions
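The comment at the end of this hunk refers to measuring similarity between the logit-diff directions and the final-layer residual stream; a hedged sketch of that projection step, with shapes and names assumed rather than taken from the file:

```python
import torch

def project_onto_answer_directions(
    scaled_residual: torch.Tensor,    # (samples, d_model), after layer-norm scaling
    answer_directions: torch.Tensor,  # (samples, d_model), from tokens_to_residual_directions
) -> torch.Tensor:
    """Per-sample dot product between the scaled final residual stream and the
    answer (logit-diff) directions, giving one attribution score per sample."""
    return (scaled_residual * answer_directions).sum(dim=-1)

# tiny shape check with random stand-ins for the real tensors
scores = project_onto_answer_directions(torch.randn(4, 128), torch.randn(4, 128))
assert scores.shape == (4,)
```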
6 changes: 3 additions & 3 deletions maze_transformer/mechinterp/plot_attention.py
@@ -289,9 +289,9 @@ def mazeplot_attention(
node_values=node_values,
color_map=cmap,
target_token_coord=target_coord,
preceeding_tokens_coords=[final_prompt_coord]
if final_prompt_coord is not None
else None,
preceeding_tokens_coords=(
[final_prompt_coord] if final_prompt_coord is not None else None
),
colormap_center=colormap_center_val,
colormap_max=colormap_max,
hide_colorbar=hide_colorbar,
14 changes: 8 additions & 6 deletions maze_transformer/mechinterp/residual_stream_structure.py
@@ -68,9 +68,11 @@ def process_tokens_for_pca(tokenizer: MazeTokenizer) -> list[TokenPlottingInfo]:
tokenizer.token_arr,
tokens_coords,
[
coordinate_to_color(coord, max_val=max_coord)
if isinstance(coord, tuple)
else (0.0, 1.0, 0.0)
(
coordinate_to_color(coord, max_val=max_coord)
if isinstance(coord, tuple)
else (0.0, 1.0, 0.0)
)
for coord in tokens_coords
],
)
@@ -249,9 +251,9 @@ def compute_distances_and_correlation(
# embedding_distances /= embedding_distances.max()

# Convert the distances to a square matrix
embedding_distances_matrix: Float[
np.ndarray, "n_coord_tokens n_coord_tokens"
] = squareform(embedding_distances)
embedding_distances_matrix: Float[np.ndarray, "n_coord_tokens n_coord_tokens"] = (
squareform(embedding_distances)
)

# Calculate the correlation between the embedding and coordinate distances
coordinate_coordinates: Float[np.ndarray, "n_coord_tokens 2"] = np.array(
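A hedged, self-contained sketch of the distance-matrix and correlation pattern used in this hunk, with random stand-ins for the real token embeddings and coordinates:

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.stats import pearsonr

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(10, 32))   # stand-in for coordinate-token embeddings
coordinates = rng.normal(size=(10, 2))   # stand-in for the (row, col) coordinates

# condensed pairwise distances, then the square matrix as in the hunk above
embedding_distances = pdist(embeddings, metric="cosine")
embedding_distances_matrix = squareform(embedding_distances)  # (10, 10)

# correlation between embedding distances and coordinate distances
coordinate_distances = pdist(coordinates, metric="euclidean")
corr, p_value = pearsonr(embedding_distances, coordinate_distances)
print(f"correlation between embedding and coordinate distances: {corr:.3f}")
```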
27 changes: 14 additions & 13 deletions maze_transformer/training/config.py
@@ -214,7 +214,9 @@ def get_intervals(
)

except ValueError as e:
_debug_vals: str = f"{dataset_n_samples=}, {use_defaults_if_missing=}, {mod_batch_size=},\n{self.intervals=},\n{self.intervals_count=}"
_debug_vals: str = (
f"{dataset_n_samples=}, {use_defaults_if_missing=}, {mod_batch_size=},\n{self.intervals=},\n{self.intervals_count=}"
)
raise ValueError(f"{_debug_vals}\ntriggered error:\n{e}") from e

# disable if set to 0 or negative
@@ -235,9 +237,9 @@ def get_intervals(
# actually return the intervals
if mod_batch_size:
return {
k: max(1, v // self.batch_size)
if isinstance(v, int)
else v # if float, leave it as is since its float("inf")
k: (
max(1, v // self.batch_size) if isinstance(v, int) else v
) # if float, leave it as is since its float("inf")
for k, v in intervals_new.items()
}
else:
Expand Down Expand Up @@ -459,9 +461,11 @@ def summary(self) -> str:
"model_cfg": self.model_cfg.summary(),
"train_cfg": self.train_cfg.summary(),
"pretrainedtokenizer_kwargs": self.pretrainedtokenizer_kwargs,
"maze_tokenizer": self.maze_tokenizer.summary()
if self.maze_tokenizer is not None
else None,
"maze_tokenizer": (
self.maze_tokenizer.summary()
if self.maze_tokenizer is not None
else None
),
}

@property
Expand Down Expand Up @@ -655,12 +659,9 @@ def _load_state_dict_wrapper(
self.zanj_model_config.model_cfg.weight_processing["are_layernorms_folded"]
or fold_ln
)
self.zanj_model_config.model_cfg.weight_processing[
"are_weights_processed"
] = self.zanj_model_config.model_cfg.weight_processing[
"are_weights_processed"
] or (
not recover_exact
self.zanj_model_config.model_cfg.weight_processing["are_weights_processed"] = (
self.zanj_model_config.model_cfg.weight_processing["are_weights_processed"]
or (not recover_exact)
)

self.load_and_process_state_dict(
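A worked example of the `mod_batch_size` conversion shown above; the interval names and batch size are invented for illustration:

```python
# Integer intervals given in samples become batch counts, while floats
# (e.g. float("inf"), used to disable an interval) pass through unchanged.
batch_size = 32
intervals_new = {"print_loss": 1000, "checkpoint": 50_000, "eval_fast": float("inf")}

converted = {
    k: (max(1, v // batch_size) if isinstance(v, int) else v)
    for k, v in intervals_new.items()
}
assert converted == {"print_loss": 31, "checkpoint": 1562, "eval_fast": float("inf")}
```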
74 changes: 46 additions & 28 deletions maze_transformer/training/train_model.py
@@ -7,7 +7,7 @@
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS
from muutils.json_serialize import SerializableDataclass, serializable_dataclass
from muutils.mlutils import get_device
from muutils.mlutils import get_device, pprint_summary
from torch.utils.data import DataLoader

from maze_transformer.training.config import (
@@ -36,7 +36,7 @@ def __str__(self):

def train_model(
base_path: str | Path,
wandb_project: Union[WandbProject, str],
wandb_project: Union[WandbProject, str] | None,
cfg: ConfigHolder | None = None,
cfg_file: str | Path | None = None,
cfg_names: typing.Sequence[str] | None = None,
@@ -59,6 +59,8 @@ def train_model(
- model config names: {model_cfg_names}
- train config names: {train_cfg_names}
"""
USES_LOGGER: bool = wandb_project is not None

if help:
print(train_model.__doc__)
return
@@ -84,26 +86,43 @@
(output_path / TRAIN_SAVE_FILES.checkpoints).mkdir(parents=True)

# set up logger
logger: WandbLogger = WandbLogger.create(
config=cfg.serialize(),
project=wandb_project,
job_type=WandbJobType.TRAIN_MODEL,
logger_cfg_dict = dict(
logger_cfg={
"output_dir": output_path.as_posix(),
"cfg.name": cfg.name,
"data_cfg.name": cfg.dataset_cfg.name,
"train_cfg.name": cfg.train_cfg.name,
"model_cfg.name": cfg.model_cfg.name,
"cfg_summary": cfg.summary(),
"cfg": cfg.serialize(),
},
)
logger.progress("Initialized logger")
logger.summary(
dict(
logger_cfg={
"output_dir": output_path.as_posix(),
"cfg.name": cfg.name,
"data_cfg.name": cfg.dataset_cfg.name,
"train_cfg.name": cfg.train_cfg.name,
"model_cfg.name": cfg.model_cfg.name,
"cfg_summary": cfg.summary(),
"cfg": cfg.serialize(),
},

# Set up logger if wandb project is specified
if USES_LOGGER:
logger: WandbLogger = WandbLogger.create(
config=cfg.serialize(),
project=wandb_project,
job_type=WandbJobType.TRAIN_MODEL,
)
)
logger.progress("Summary logged, getting dataset")
logger.progress("Initialized logger")
else:
logger = None

def log(msg: str | dict, log_type: str = "progress", **kwargs):
# Convenience function to let training routine work whether or not
# logger exists
if logger:
log_fn = getattr(logger, log_type)
log_fn(msg, **kwargs)
else:
if type(msg) == dict:
pprint_summary(msg)
else:
print(msg)

log(logger_cfg_dict, log_type="summary")
log("Summary logged, getting dataset")

# load dataset
if dataset is None:
@@ -115,18 +134,19 @@ def train_model(
)
else:
if dataset.cfg == cfg.dataset_cfg:
logger.progress(f"passed dataset has matching config, using that")
log(f"passed dataset has matching config, using that")
else:
if allow_dataset_override:
logger.progress(
log(
f"passed dataset has different config than cfg.dataset_cfg, but allow_dataset_override is True, so using passed dataset"
)
else:
raise ValueError(
f"dataset has different config than cfg.dataset_cfg, and allow_dataset_override is False"
)

logger.progress(f"finished getting training dataset with {len(dataset)} samples")
log(f"finished getting training dataset with {len(dataset)} samples")

# validation dataset, if applicable
val_dataset: MazeDataset | None = None
if cfg.train_cfg.validation_dataset_cfg is not None:
@@ -148,7 +168,7 @@
dataset.mazes = dataset.mazes[: split_dataset_sizes[0]]
dataset.update_self_config()
val_dataset.update_self_config()
logger.progress(
log(
f"got validation dataset by splitting training dataset into {len(dataset)} train and {len(val_dataset)} validation samples"
)
elif isinstance(cfg.train_cfg.validation_dataset_cfg, MazeDatasetConfig):
@@ -158,14 +178,12 @@
local_base_path=base_path,
verbose=dataset_verbose,
)
logger.progress(
f"got custom validation dataset with {len(val_dataset)} samples"
)
log(f"got custom validation dataset with {len(val_dataset)} samples")

# get dataloader and then train
dataloader: DataLoader = get_dataloader(dataset, cfg, logger)

logger.progress("finished dataloader, passing to train()")
log("finished dataloader, passing to train()")
trained_model: ZanjHookedTransformer = train(
cfg=cfg,
dataloader=dataloader,
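With `wandb_project` now optional, a hedged sketch of calling `train_model` without wandb logging; the config names below are placeholders, not checked against the registered configs:

```python
from maze_transformer.training.train_model import train_model

# Placeholder config names -- the valid ones are listed in train_model's
# docstring (dataset / model / train config name lists).
result = train_model(
    base_path="./runs",
    wandb_project=None,  # no wandb: progress falls back to print/pprint_summary
    cfg_names=("demo-dataset", "demo-model", "demo-train"),
)
```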
10 changes: 4 additions & 6 deletions maze_transformer/training/train_save_files.py
@@ -20,12 +20,10 @@ class TRAIN_SAVE_FILES:
config_holder: str = "config.json"
checkpoints: str = "checkpoints"
log: str = "log.jsonl"
model_checkpt_zanj: Callable[
[int], str
] = lambda iteration: f"model.iter_{iteration}.zanj"
model_checkpt_zanj: Callable[[int], str] = (
lambda iteration: f"model.iter_{iteration}.zanj"
)
model_final_zanj: str = "model.final.zanj"
model_run_dir: Callable[
[ConfigHolder], str
] = (
model_run_dir: Callable[[ConfigHolder], str] = (
lambda cfg: f"{sanitize_fname(cfg.name)}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
)
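A hedged usage sketch of the reformatted callable fields, assuming they are accessed directly on the class as plain callables:

```python
from maze_transformer.training.train_save_files import TRAIN_SAVE_FILES

print(TRAIN_SAVE_FILES.model_checkpt_zanj(5000))  # expected: "model.iter_5000.zanj"
print(TRAIN_SAVE_FILES.model_final_zanj)          # "model.final.zanj"
```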