Merge pull request #204 from naveenarun/naveenarun-train-nowandb
Allow training notebook to run without logger or wandb config
mivanit authored Jan 28, 2024
2 parents 2a42383 + 4d56e31 commit 70d494c
Showing 4 changed files with 2,616 additions and 280 deletions.
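
With this change, wandb_project may be passed as None to skip logger setup entirely; progress messages then fall back to plain printing. A minimal sketch of calling the updated entry point without wandb — the base_path and the three config names below are placeholders, not names shipped with the repo:

from maze_transformer.training.train_model import train_model

# Sketch only: substitute a real output directory and config names that are
# registered in MAZE_DATASET_CONFIGS / the model and train config registries.
train_model(
    base_path="./runs",
    wandb_project=None,  # new in this PR: no WandbLogger, messages go to print()
    cfg_names=["<dataset_cfg_name>", "<model_cfg_name>", "<train_cfg_name>"],
)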
74 changes: 46 additions & 28 deletions maze_transformer/training/train_model.py
@@ -7,7 +7,7 @@
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS
from muutils.json_serialize import SerializableDataclass, serializable_dataclass
from muutils.mlutils import get_device
from muutils.mlutils import get_device, pprint_summary
from torch.utils.data import DataLoader

from maze_transformer.training.config import (
@@ -36,7 +36,7 @@ def __str__(self):

def train_model(
base_path: str | Path,
wandb_project: Union[WandbProject, str],
wandb_project: Union[WandbProject, str] | None,
cfg: ConfigHolder | None = None,
cfg_file: str | Path | None = None,
cfg_names: typing.Sequence[str] | None = None,
@@ -59,6 +59,8 @@ def train_model(
- model config names: {model_cfg_names}
- train config names: {train_cfg_names}
"""
USES_LOGGER: bool = wandb_project is not None

if help:
print(train_model.__doc__)
return
@@ -84,26 +86,43 @@ def train_model(
(output_path / TRAIN_SAVE_FILES.checkpoints).mkdir(parents=True)

# set up logger
logger: WandbLogger = WandbLogger.create(
config=cfg.serialize(),
project=wandb_project,
job_type=WandbJobType.TRAIN_MODEL,
logger_cfg_dict = dict(
logger_cfg={
"output_dir": output_path.as_posix(),
"cfg.name": cfg.name,
"data_cfg.name": cfg.dataset_cfg.name,
"train_cfg.name": cfg.train_cfg.name,
"model_cfg.name": cfg.model_cfg.name,
"cfg_summary": cfg.summary(),
"cfg": cfg.serialize(),
},
)
logger.progress("Initialized logger")
logger.summary(
dict(
logger_cfg={
"output_dir": output_path.as_posix(),
"cfg.name": cfg.name,
"data_cfg.name": cfg.dataset_cfg.name,
"train_cfg.name": cfg.train_cfg.name,
"model_cfg.name": cfg.model_cfg.name,
"cfg_summary": cfg.summary(),
"cfg": cfg.serialize(),
},

# Set up logger if wandb project is specified
if USES_LOGGER:
logger: WandbLogger = WandbLogger.create(
config=cfg.serialize(),
project=wandb_project,
job_type=WandbJobType.TRAIN_MODEL,
)
)
logger.progress("Summary logged, getting dataset")
logger.progress("Initialized logger")
else:
logger = None

def log(msg: str | dict, log_type: str = "progress", **kwargs):
# Convenience function to let training routine work whether or not
# logger exists
if logger:
log_fn = getattr(logger, log_type)
log_fn(msg, **kwargs)
else:
if type(msg) == dict:
pprint_summary(msg)
else:
print(msg)

log(logger_cfg_dict, log_type="summary")
log("Summary logged, getting dataset")

# load dataset
if dataset is None:
@@ -115,18 +134,19 @@ def train_model(
)
else:
if dataset.cfg == cfg.dataset_cfg:
logger.progress(f"passed dataset has matching config, using that")
log(f"passed dataset has matching config, using that")
else:
if allow_dataset_override:
logger.progress(
log(
f"passed dataset has different config than cfg.dataset_cfg, but allow_dataset_override is True, so using passed dataset"
)
else:
raise ValueError(
f"dataset has different config than cfg.dataset_cfg, and allow_dataset_override is False"
)

logger.progress(f"finished getting training dataset with {len(dataset)} samples")
log(f"finished getting training dataset with {len(dataset)} samples")

# validation dataset, if applicable
val_dataset: MazeDataset | None = None
if cfg.train_cfg.validation_dataset_cfg is not None:
@@ -148,7 +168,7 @@ def train_model(
dataset.mazes = dataset.mazes[: split_dataset_sizes[0]]
dataset.update_self_config()
val_dataset.update_self_config()
logger.progress(
log(
f"got validation dataset by splitting training dataset into {len(dataset)} train and {len(val_dataset)} validation samples"
)
elif isinstance(cfg.train_cfg.validation_dataset_cfg, MazeDatasetConfig):
@@ -158,14 +178,12 @@ def train_model(
local_base_path=base_path,
verbose=dataset_verbose,
)
logger.progress(
f"got custom validation dataset with {len(val_dataset)} samples"
)
log(f"got custom validation dataset with {len(val_dataset)} samples")

# get dataloader and then train
dataloader: DataLoader = get_dataloader(dataset, cfg, logger)

logger.progress("finished dataloader, passing to train()")
log("finished dataloader, passing to train()")
trained_model: ZanjHookedTransformer = train(
cfg=cfg,
dataloader=dataloader,
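Both files gain the same convenience wrapper: when a WandbLogger exists, the named method (progress, summary, log_metric_hist, upload_model) is looked up with getattr and called; when it does not, dicts go through pprint_summary and everything else through print. A condensed, standalone sketch of that dispatch (the logger argument is typed loosely here for illustration):

from muutils.mlutils import pprint_summary

def log(msg, log_type: str = "progress", logger=None, **kwargs):
    # With a logger, forward to the requested WandbLogger method,
    # e.g. logger.progress(...) or logger.summary(...).
    if logger is not None:
        getattr(logger, log_type)(msg, **kwargs)
    # Without one, fall back to console output.
    elif isinstance(msg, dict):
        pprint_summary(msg)
    else:
        print(msg)
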
74 changes: 48 additions & 26 deletions maze_transformer/training/training.py
@@ -6,6 +6,7 @@
from jaxtyping import Float
from maze_dataset import MazeDataset, SolvedMaze
from maze_dataset.tokenization import MazeTokenizer
from muutils.mlutils import pprint_summary
from muutils.statcounter import StatCounter
from torch.utils.data import DataLoader
from transformer_lens.HookedTransformer import SingleLoss
@@ -24,12 +25,19 @@ def collate_batch(batch: list[SolvedMaze], maze_tokenizer: MazeTokenizer) -> lis


def get_dataloader(
dataset: MazeDataset, cfg: ConfigHolder, logger: WandbLogger
dataset: MazeDataset, cfg: ConfigHolder, logger: WandbLogger | None
) -> DataLoader:
def log_progress(msg):
# Convenience function for deciding whether to use logger or not
if logger:
logger.progress(msg)
else:
print(msg)

if len(dataset) == 0:
raise ValueError(f"Dataset is empty: {len(dataset) = }")
logger.progress(f"Loaded {len(dataset)} sequences")
logger.progress("Creating dataloader")
log_progress(f"Loaded {len(dataset)} sequences")
log_progress("Creating dataloader")
try:
dataloader: DataLoader = DataLoader(
dataset,
@@ -59,32 +67,45 @@ def train(
zanj: ZANJ | None = None,
model: ZanjHookedTransformer | None = None,
) -> ZanjHookedTransformer:
def log(msg: str | dict, log_type: str = "progress", **kwargs):
# Convenience function to let training routine work whether or not
# logger exists
if logger:
log_fn = getattr(logger, log_type)
log_fn(msg, **kwargs)
else:
if type(msg) == dict:
pprint_summary(msg)
else:
print(msg)

# initialize
# ==============================
if zanj is None:
zanj = ZANJ()

# init model & optimizer
if model is None:
logger.progress(f"Initializing model")
log(f"Initializing model")
model: ZanjHookedTransformer = cfg.create_model_zanj()
model.to(device)
else:
logger.progress("Using existing model")
log("Using existing model")

logger.summary({"device": str(device), "model.device": model.cfg.device})
log({"device": str(device), "model.device": model.cfg.device}, log_type="summary")

logger.progress("Initializing optimizer")
log("Initializing optimizer")
optimizer: torch.optim.Optimizer = cfg.train_cfg.optimizer(
model.parameters(),
**cfg.train_cfg.optimizer_kwargs,
)
logger.summary(dict(model_n_params=model.cfg.n_params))
log(dict(model_n_params=model.cfg.n_params), log_type="summary")

# add wandb run url to model
model.training_records = {
"wandb_url": logger.url,
}
if logger:
model.training_records = {
"wandb_url": logger.url,
}

# figure out whether to run evals, and validation dataset
evals_enabled: bool = cfg.train_cfg.validation_dataset_cfg is not None
@@ -116,10 +137,11 @@ def train(
key: value if not key.startswith("eval") else float("inf")
for key, value in intervals.items()
}
logger.summary(
{"n_batches": n_batches, "n_samples": n_samples, "intervals": intervals}
log(
{"n_batches": n_batches, "n_samples": n_samples, "intervals": intervals},
log_type="summary",
)
logger.progress(
log(
f"will train for {n_batches} batches, {evals_enabled=}, with intervals: {intervals}"
)

@@ -128,7 +150,7 @@ def train(
# start up training
# ==============================
model.train()
logger.progress("Starting training")
log("Starting training")

for iteration, batch in enumerate(dataloader):
# forward pass
@@ -153,7 +175,7 @@ def train(
if evals_enabled:
for interval_key, evals_dict in PathEvals.PATH_EVALS_MAP.items():
if iteration % intervals[interval_key] == 0:
logger.progress(f"Running evals: {interval_key}")
log(f"Running evals: {interval_key}")
scores: dict[str, StatCounter] = evaluate_model(
model=model,
dataset=val_dataset,
@@ -163,12 +185,10 @@ def train(
max_new_tokens=cfg.train_cfg.evals_max_new_tokens,
)
metrics.update(scores)
logger.log_metric_hist(metrics)
log(metrics, log_type="log_metric_hist")

if iteration % intervals["print_loss"] == 0:
logger.progress(
f"iteration {iteration}/{n_batches}: loss={loss.item():.3f}"
)
log(f"iteration {iteration}/{n_batches}: loss={loss.item():.3f}")

del loss

@@ -180,19 +200,21 @@ def train(
/ TRAIN_SAVE_FILES.checkpoints
/ TRAIN_SAVE_FILES.model_checkpt_zanj(iteration)
)
logger.progress(f"Saving model checkpoint to {model_save_path.as_posix()}")
log(f"Saving model checkpoint to {model_save_path.as_posix()}")
zanj.save(model, model_save_path)
logger.upload_model(
model_save_path, aliases=["latest", f"iter-{iteration}"]
log(
model_save_path,
log_type="upload_model",
aliases=["latest", f"iter-{iteration}"],
)

# save the final model
# ==============================
final_model_path: Path = output_dir / TRAIN_SAVE_FILES.model_final_zanj
logger.progress(f"Saving final model to {final_model_path.as_posix()}")
log(f"Saving final model to {final_model_path.as_posix()}")
zanj.save(model, final_model_path)
logger.upload_model(final_model_path, aliases=["latest", "final"])
log(final_model_path, log_type="upload_model", aliases=["latest", "final"])

logger.progress("Done training!")
log("Done training!")

return model