From c973442d46507e3658a2147c9518e6181b2c5ba4 Mon Sep 17 00:00:00 2001
From: Naveen Arun
Date: Thu, 25 Jan 2024 19:04:23 -0600
Subject: [PATCH 1/4] Add conda instructions to README.md

---
 README.md | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/README.md b/README.md
index 98ea526d..12c671f1 100644
--- a/README.md
+++ b/README.md
@@ -68,6 +68,20 @@ Most of the functionality is demonstrated in the ipython notebooks in the `noteb
 * Restart VSCode
 * In VSCode, select the python interpreter located in `maze-transformer/.venv/bin` as your jupyter kernel
 
+## Instructions for Conda users
+
+* Create a new Conda environment: `conda create -n mazetransformer python=3.10 poetry`
+* Activate the environment: `conda activate mazetransformer`
+* Update poetry and install dev dependencies:
+  ```
+  poetry self update
+  poetry config virtualenvs.in-project true
+  poetry install --with dev
+  ```
+* Run unit, integration, and notebook tests:
+  ```
+  make test
+  ```
 
 ## Testing & Static analysis

From c39070a87c5939f37b4258e2cfd94400ac06fe00 Mon Sep 17 00:00:00 2001
From: naveenarun
Date: Fri, 26 Jan 2024 17:31:44 -0600
Subject: [PATCH 2/4] Allow training notebook to run without logger or wandb
 config

---
 maze_transformer/training/train_model.py |   60 +-
 maze_transformer/training/training.py    |   75 +-
 notebooks/train_model.ipynb              | 2730 ++++++++++++++++++++--
 3 files changed, 2594 insertions(+), 271 deletions(-)

diff --git a/maze_transformer/training/train_model.py b/maze_transformer/training/train_model.py
index 3421414b..4dab03f5 100644
--- a/maze_transformer/training/train_model.py
+++ b/maze_transformer/training/train_model.py
@@ -7,7 +7,7 @@
 from maze_dataset import MazeDataset, MazeDatasetConfig
 from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS
 from muutils.json_serialize import SerializableDataclass, serializable_dataclass
-from muutils.mlutils import get_device
+from muutils.mlutils import get_device, pprint_summary
 from torch.utils.data import DataLoader
 
 from maze_transformer.training.config import (
@@ -36,7 +36,7 @@ def __str__(self):
 
 def train_model(
     base_path: str | Path,
-    wandb_project: Union[WandbProject, str],
+    wandb_project: Union[WandbProject, str] | None,
     cfg: ConfigHolder | None = None,
     cfg_file: str | Path | None = None,
     cfg_names: typing.Sequence[str] | None = None,
@@ -59,6 +59,8 @@
     - model config names: {model_cfg_names}
     - train config names: {train_cfg_names}
     """
+    USES_LOGGER: bool = wandb_project is not None
+
     if help:
         print(train_model.__doc__)
         return
@@ -84,14 +86,7 @@
     (output_path / TRAIN_SAVE_FILES.checkpoints).mkdir(parents=True)
 
     # set up logger
-    logger: WandbLogger = WandbLogger.create(
-        config=cfg.serialize(),
-        project=wandb_project,
-        job_type=WandbJobType.TRAIN_MODEL,
-    )
-    logger.progress("Initialized logger")
-    logger.summary(
-        dict(
+    logger_cfg_dict = dict(
             logger_cfg={
                 "output_dir": output_path.as_posix(),
                 "cfg.name": cfg.name,
@@ -102,8 +97,32 @@
                 "cfg": cfg.serialize(),
             },
         )
-    )
-    logger.progress("Summary logged, getting dataset")
+
+    # Set up logger only if a wandb project is specified
+    if USES_LOGGER:
+        logger: WandbLogger = WandbLogger.create(
+            config=cfg.serialize(),
+            project=wandb_project,
+            job_type=WandbJobType.TRAIN_MODEL,
+        )
+        logger.progress("Initialized logger")
+    else:
+        logger = None
+
+    def log(msg: str | dict, log_type: str = 'progress', **kwargs):
+        # Convenience function to let training routine work whether or not
+        # logger exists
+        if logger:
+            log_fn = getattr(logger, log_type)
+            log_fn(msg, **kwargs)
+        else:
+            if isinstance(msg, dict):
+                pprint_summary(msg)
+            else:
+                print(msg)
+
+    log(logger_cfg_dict, log_type='summary')
+    log("Summary logged, getting dataset")
 
     # load dataset
     if dataset is None:
@@ -115,18 +134,19 @@
         )
     else:
         if dataset.cfg == cfg.dataset_cfg:
-            logger.progress(f"passed dataset has matching config, using that")
+            log("passed dataset has matching config, using that")
         else:
             if allow_dataset_override:
-                logger.progress(
-                    f"passed dataset has different config than cfg.dataset_cfg, but allow_dataset_override is True, so using passed dataset"
-                )
+                log("passed dataset has different config than cfg.dataset_cfg, but allow_dataset_override is True, so using passed dataset")
             else:
                 raise ValueError(
                     f"dataset has different config than cfg.dataset_cfg, and allow_dataset_override is False"
                 )
 
-    logger.progress(f"finished getting training dataset with {len(dataset)} samples")
+    log(
+        f"finished getting training dataset with {len(dataset)} samples"
+    )
+
     # validation dataset, if applicable
     val_dataset: MazeDataset | None = None
     if cfg.train_cfg.validation_dataset_cfg is not None:
@@ -148,7 +168,7 @@
             dataset.mazes = dataset.mazes[: split_dataset_sizes[0]]
             dataset.update_self_config()
             val_dataset.update_self_config()
-            logger.progress(
+            log(
                 f"got validation dataset by splitting training dataset into {len(dataset)} train and {len(val_dataset)} validation samples"
             )
         elif isinstance(cfg.train_cfg.validation_dataset_cfg, MazeDatasetConfig):
@@ -158,14 +178,14 @@
                 local_base_path=base_path,
                 verbose=dataset_verbose,
             )
-            logger.progress(
+            log(
                 f"got custom validation dataset with {len(val_dataset)} samples"
             )
 
     # get dataloader and then train
     dataloader: DataLoader = get_dataloader(dataset, cfg, logger)
 
-    logger.progress("finished dataloader, passing to train()")
+    log("finished dataloader, passing to train()")
     trained_model: ZanjHookedTransformer = train(
         cfg=cfg,
         dataloader=dataloader,
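Note on the pattern above: `log()` dispatches to the `WandbLogger` method named by `log_type` when a logger exists, and falls back to plain printing (with `pprint_summary` for dicts) when `wandb_project` is `None`. Below is a minimal, self-contained sketch of that dispatch pattern; `DummyLogger` and `make_log` are hypothetical stand-ins for `WandbLogger` and the inline closure, and the stdlib `pprint` stands in for muutils' `pprint_summary`:

```python
from pprint import pprint


class DummyLogger:
    # hypothetical stand-in for WandbLogger, for illustration only
    def progress(self, msg: str) -> None:
        print(f"[progress] {msg}")

    def summary(self, data: dict) -> None:
        print(f"[summary] {data}")


def make_log(logger: DummyLogger | None):
    def log(msg, log_type: str = "progress", **kwargs):
        if logger is not None:
            # dispatch to the logger method named by log_type,
            # as the patch does via getattr
            getattr(logger, log_type)(msg, **kwargs)
        elif isinstance(msg, dict):
            pprint(msg)  # stand-in for muutils' pprint_summary
        else:
            print(msg)

    return log


log = make_log(None)                           # no logger: fall back to print/pprint
log("getting dataset")                         # -> getting dataset
log({"n_samples": 10000}, log_type="summary")  # -> pretty-printed dict

log = make_log(DummyLogger())                  # with a logger: dispatch via getattr
log("getting dataset")                         # -> [progress] getting dataset
```

One consequence of the fallback branch worth noting: extra keyword arguments such as `aliases` are silently dropped when no logger is configured, which is intended here, since they only affect wandb uploads.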
diff --git a/maze_transformer/training/training.py b/maze_transformer/training/training.py
index 15af763d..92af184a 100644
--- a/maze_transformer/training/training.py
+++ b/maze_transformer/training/training.py
@@ -7,6 +7,7 @@
 from maze_dataset import MazeDataset, SolvedMaze
 from maze_dataset.tokenization import MazeTokenizer
 from muutils.statcounter import StatCounter
+from muutils.mlutils import pprint_summary
 from torch.utils.data import DataLoader
 from transformer_lens.HookedTransformer import SingleLoss
 from zanj import ZANJ
@@ -24,12 +25,19 @@ def collate_batch(batch: list[SolvedMaze], maze_tokenizer: MazeTokenizer) -> lis
 
 def get_dataloader(
-    dataset: MazeDataset, cfg: ConfigHolder, logger: WandbLogger
+    dataset: MazeDataset, cfg: ConfigHolder, logger: WandbLogger | None
 ) -> DataLoader:
+    def log_progress(msg: str):
+        # Convenience function for deciding whether to use logger or not
+        if logger:
+            logger.progress(msg)
+        else:
+            print(msg)
+
     if len(dataset) == 0:
         raise ValueError(f"Dataset is empty: {len(dataset) = }")
-    logger.progress(f"Loaded {len(dataset)} sequences")
-    logger.progress("Creating dataloader")
+    log_progress(f"Loaded {len(dataset)} sequences")
+    log_progress("Creating dataloader")
     try:
         dataloader: DataLoader = DataLoader(
             dataset,
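Note: later in this patch the notebook's `dataloader_cfg` switches `num_workers` from 8 to 0 (and `shuffle` from `false` to `true`). With `num_workers=0`, PyTorch collates batches in the main process, so no worker processes are spawned and nothing has to be pickled over to a child process — the step where the `MemoryError` in the old notebook output further below was raised. A small, self-contained illustration with toy data (not the repo's maze dataset):

```python
from torch.utils.data import DataLoader

data = list(range(8))  # any map-style sequence works as a toy Dataset

# num_workers=0 loads and collates in the main process: no worker
# startup, and no pickling of the dataset to child processes
loader = DataLoader(data, batch_size=4, shuffle=False, num_workers=0)

for batch in loader:
    print(batch.tolist())  # [0, 1, 2, 3] then [4, 5, 6, 7]
```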
@@ -59,32 +67,46 @@ def train(
     zanj: ZANJ | None = None,
     model: ZanjHookedTransformer | None = None,
 ) -> ZanjHookedTransformer:
+
+    def log(msg: str | dict, log_type: str = 'progress', **kwargs):
+        # Convenience function to let training routine work whether or not
+        # logger exists
+        if logger:
+            log_fn = getattr(logger, log_type)
+            log_fn(msg, **kwargs)
+        else:
+            if isinstance(msg, dict):
+                pprint_summary(msg)
+            else:
+                print(msg)
+
     # initialize
     # ==============================
     if zanj is None:
         zanj = ZANJ()
-    
+
     # init model & optimizer
     if model is None:
-        logger.progress(f"Initializing model")
+        log("Initializing model")
         model: ZanjHookedTransformer = cfg.create_model_zanj()
         model.to(device)
     else:
-        logger.progress("Using existing model")
+        log("Using existing model")
 
-    logger.summary({"device": str(device), "model.device": model.cfg.device})
+    log({"device": str(device), "model.device": model.cfg.device}, log_type='summary')
 
-    logger.progress("Initializing optimizer")
+    log("Initializing optimizer")
     optimizer: torch.optim.Optimizer = cfg.train_cfg.optimizer(
         model.parameters(),
         **cfg.train_cfg.optimizer_kwargs,
     )
-    logger.summary(dict(model_n_params=model.cfg.n_params))
+    log(dict(model_n_params=model.cfg.n_params), log_type='summary')
 
     # add wandb run url to model
-    model.training_records = {
-        "wandb_url": logger.url,
-    }
+    if logger:
+        model.training_records = {
+            "wandb_url": logger.url,
+        }
 
     # figure out whether to run evals, and validation dataset
     evals_enabled: bool = cfg.train_cfg.validation_dataset_cfg is not None
@@ -116,10 +138,11 @@
         key: value if not key.startswith("eval") else float("inf")
         for key, value in intervals.items()
     }
-    logger.summary(
-        {"n_batches": n_batches, "n_samples": n_samples, "intervals": intervals}
+    log(
+        {"n_batches": n_batches, "n_samples": n_samples, "intervals": intervals},
+        log_type='summary'
     )
-    logger.progress(
+    log(
         f"will train for {n_batches} batches, {evals_enabled=}, with intervals: {intervals}"
     )
 
@@ -128,7 +151,7 @@
     # start up training
     # ==============================
     model.train()
-    logger.progress("Starting training")
+    log("Starting training")
 
     for iteration, batch in enumerate(dataloader):
         # forward pass
@@ -153,7 +176,7 @@
             if evals_enabled:
                 for interval_key, evals_dict in PathEvals.PATH_EVALS_MAP.items():
                     if iteration % intervals[interval_key] == 0:
-                        logger.progress(f"Running evals: {interval_key}")
+                        log(f"Running evals: {interval_key}")
                         scores: dict[str, StatCounter] = evaluate_model(
                             model=model,
                             dataset=val_dataset,
@@ -163,10 +186,10 @@
                             max_new_tokens=cfg.train_cfg.evals_max_new_tokens,
                         )
                         metrics.update(scores)
-            logger.log_metric_hist(metrics)
+            log(metrics, log_type='log_metric_hist')
 
         if iteration % intervals["print_loss"] == 0:
-            logger.progress(
+            log(
                 f"iteration {iteration}/{n_batches}: loss={loss.item():.3f}"
             )
 
@@ -180,19 +203,21 @@
                 / TRAIN_SAVE_FILES.checkpoints
                 / TRAIN_SAVE_FILES.model_checkpt_zanj(iteration)
             )
-            logger.progress(f"Saving model checkpoint to {model_save_path.as_posix()}")
+            log(f"Saving model checkpoint to {model_save_path.as_posix()}")
             zanj.save(model, model_save_path)
-            logger.upload_model(
-                model_save_path, aliases=["latest", f"iter-{iteration}"]
+            log(
+                model_save_path,
+                log_type='upload_model',
+                aliases=["latest", f"iter-{iteration}"],
             )
 
     # save the final model
     # ==============================
     final_model_path: Path = output_dir / TRAIN_SAVE_FILES.model_final_zanj
-    logger.progress(f"Saving final model to {final_model_path.as_posix()}")
+    log(f"Saving final model to {final_model_path.as_posix()}")
     zanj.save(model, final_model_path)
-    logger.upload_model(final_model_path, aliases=["latest", "final"])
+    log(final_model_path, log_type='upload_model', aliases=["latest", "final"])
 
-    logger.progress("Done training!")
+    
log("Done training!") return model diff --git a/notebooks/train_model.ipynb b/notebooks/train_model.ipynb index ac3fa2d6..d50d0db3 100644 --- a/notebooks/train_model.ipynb +++ b/notebooks/train_model.ipynb @@ -6,6 +6,9 @@ "metadata": {}, "outputs": [], "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", "# Generic\n", "import typing\n", "import os\n", @@ -37,15 +40,7 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "DEVICE = device(type='cuda')\n" - ] - } - ], + "outputs": [], "source": [ "# set global defaults for ZANJ\n", "ZANJ_GLOBAL_DEFAULTS.external_array_threshold = 1024\n", @@ -56,8 +51,8 @@ "PATH_DATA: Path = Path(\"../data/\")\n", "\n", "# reproducibility and device\n", - "DEVICE = configure_notebook(seed=42, dark_mode=True)\n", - "print(f\"{DEVICE = }\")" + "#DEVICE = configure_notebook(seed=42, dark_mode=True)\n", + "#print(f\"{DEVICE = }\")" ] }, { @@ -171,17 +166,17 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "# here is where to specify which config to actually use\n", - "CFG: ConfigHolder = CFG_TEST" + "CFG: ConfigHolder = CFG_CUSTOM" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -189,71 +184,63 @@ "output_type": "stream", "text": [ "{\n", - " \"name\": \"hallway-medium\",\n", + " \"name\": \"custom\",\n", " \"dataset_cfg\": {\n", - " \"name\": \"custom-hallway\",\n", - " \"fname\": \"custom-hallway-g8-n3.0M-a_dfs-h66295\",\n", - " \"sdc_hash\": 38558245839505946605550258446345100654876335964881215358923857434720332766295,\n", + " \"name\": \"custom-dataset\",\n", + " \"fname\": \"custom-dataset-g6-n10K-a_dfs-h74621\",\n", + " \"sdc_hash\": 30263437081808734576530068202485318808735944160954955199116767321681194074621,\n", " \"seed\": 42,\n", " \"seq_len_min\": 1,\n", - " \"seq_len_max\": 256,\n", - " \"applied_filters\": [\n", - " {\n", - " \"name\": \"collect_generation_meta\",\n", - " \"args\": [],\n", - " \"kwargs\": {}\n", - " }\n", - " ],\n", - " \"grid_n\": 8,\n", + " \"seq_len_max\": 512,\n", + " \"applied_filters\": [],\n", + " \"grid_n\": 6,\n", " \"grid_shape\": [\n", - " 8,\n", - " 8\n", + " 6,\n", + " 6\n", " ],\n", - " \"n_mazes\": 3000000,\n", + " \"n_mazes\": 10000,\n", " \"maze_ctor_name\": \"gen_dfs\",\n", - " \"maze_ctor_kwargs\": {\n", - " \"do_forks\": false\n", - " }\n", + " \"maze_ctor_kwargs\": {}\n", " },\n", " \"model_cfg\": {\n", " \"name\": \"custom-model\",\n", " \"act_fn\": \"gelu\",\n", - " \"d_model\": 128,\n", - " \"d_head\": 32,\n", - " \"n_layers\": 6,\n", + " \"d_model\": 8,\n", + " \"d_head\": 4,\n", + " \"n_layers\": 2,\n", " \"weight_processing\": {\n", " \"are_layernorms_folded\": false,\n", " \"are_weights_processed\": false\n", " },\n", - " \"n_heads\": 4\n", + " \"n_heads\": 2\n", " },\n", " \"train_cfg\": {\n", " \"name\": \"custom-train\",\n", - " \"optimizer\": \"AdamW\",\n", + " \"optimizer\": \"RMSprop\",\n", " \"optimizer_kwargs\": {\n", - " \"lr\": 0.001\n", + " \"lr\": 0.0001\n", " },\n", - " \"batch_size\": 32,\n", + " \"batch_size\": 16,\n", " \"dataloader_cfg\": {\n", - " \"shuffle\": false,\n", - " \"num_workers\": 8,\n", + " \"shuffle\": true,\n", + " \"num_workers\": 0,\n", " \"drop_last\": false\n", " },\n", " \"intervals\": null,\n", " \"intervals_count\": {\n", " \"print_loss\": 100,\n", - " \"checkpoint\": 20,\n", - " \"eval_fast\": 100,\n", - " \"eval_slow\": 50\n", + 
" \"checkpoint\": 5,\n", + " \"eval_fast\": 10,\n", + " \"eval_slow\": 5\n", " },\n", " \"evals_max_new_tokens\": 8,\n", - " \"validation_dataset_cfg\": 100\n", + " \"validation_dataset_cfg\": null\n", " },\n", " \"pretrainedtokenizer_kwargs\": null,\n", " \"maze_tokenizer\": {\n", " \"tokenization_mode\": \"AOTP_UT_uniform\",\n", - " \"max_grid_size\": 8,\n", - " \"vocab_size\": 75\n", + " \"max_grid_size\": 6,\n", + " \"vocab_size\": 47\n", " }\n", "}\n" ] @@ -265,32 +252,16 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "trying to get the dataset 'custom-hallway-g8-n3.0M-a_dfs-h66295'\n", - "seeing if we can download the dataset...\n", - "no download found, or download failed\n", - "generating dataset...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "generating & solving mazes: 100%|██████████| 3000000/3000000 [1:00:22<00:00, 828.27maze/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "saving dataset to ..\\data\\custom-hallway-g8-n3.0M-a_dfs-h66295.zanj\n", - "Got dataset custom-hallway with 3000000 items. output.cfg.to_fname() = 'custom-hallway-g8-n3.0M-a_dfs-h66295'\n" + "trying to get the dataset 'custom-dataset-g6-n10K-a_dfs-h74621'\n", + "loading dataset from ../data/custom-dataset-g6-n10K-a_dfs-h74621.zanj\n", + "Got dataset custom-dataset with 10000 items. output.cfg.to_fname() = 'custom-dataset-g6-n10K-a_dfs-h74621'\n" ] } ], @@ -305,170 +276,2477 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "2023-10-02 04:50:44 ERROR Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmiv\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" - ] - }, - { - "data": { - "text/html": [ - "wandb version 0.15.11 is available! To upgrade, please run:\n", - " $ pip install wandb --upgrade" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Tracking run with wandb version 0.13.11" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Run data is saved locally in f:\\KNC\\maze-transformer\\notebooks\\wandb\\run-20231002_045046-sn2icf41" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Syncing run peach-dream-10 to Weights & Biases (docs)
" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - " View project at https://wandb.ai/miv/understanding-search" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - " View run at https://wandb.ai/miv/understanding-search/runs/sn2icf41" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-10-02 04:50:47 INFO config ={'__format__': 'ConfigHolder(SerializableDataclass)', 'dataset_cfg': {'__format__': 'MazeDatasetConfig(SerializableDataclass)', 'name': 'custom-hallway', 'seq_len_min': 1, 'seq_len_max': 256, 'seed': 42, 'applied_filters': [{'name': 'collect_generation_meta', 'args': (), 'kwargs': {}}], 'grid_n': 8, 'n_mazes': 3000000, 'maze_ctor': {'__name__': 'gen_dfs', '__module__': 'maze_dataset.generation.generators', '__doc__': ['generate a lattice maze using depth first search, iterative', '', ' # Arguments', ' - `grid_shape: Coord`: the shape of the grid', ' - `lattice_dim: int`: the dimension of the lattice', ' (default: `2`)', ' - `accessible_cells: int | float |None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**', ' (default: `None`)', ' - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * accessible_cells`. if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**', ' (default: `None`)', ' - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.', ' - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.', '', ' # algorithm', ' 1. Choose the initial cell, mark it as visited and push it to the stack', ' 2. While the stack is not empty', ' 1. Pop a cell from the stack and make it a current cell', ' 2. If the current cell has any neighbours which have not been visited', ' 1. Push the current cell to the stack', ' 2. Choose one of the unvisited neighbours', ' 3. Remove the wall between the current cell and the chosen cell', ' 4. Mark the chosen cell as visited and push it to the stack', ' '], 'source_code': [' @staticmethod', ' def gen_dfs(', ' grid_shape: Coord,', ' lattice_dim: int = 2,', ' accessible_cells: int | float | None = None,', ' max_tree_depth: int | float | None = None,', ' do_forks: bool = True,', ' randomized_stack: bool = False,', ' start_coord: Coord | None = None,', ' ) -> LatticeMaze:', ' \"\"\"generate a lattice maze using depth first search, iterative', '', ' # Arguments', ' - `grid_shape: Coord`: the shape of the grid', ' - `lattice_dim: int`: the dimension of the lattice', ' (default: `2`)', ' - `accessible_cells: int | float |None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**', ' (default: `None`)', ' - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * accessible_cells`. if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**', ' (default: `None`)', ' - `do_forks: bool`: whether to allow forks in the maze. 
If `False`, the maze will be have no forks and will be a simple hallway.', ' - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.', '', ' # algorithm', ' 1. Choose the initial cell, mark it as visited and push it to the stack', ' 2. While the stack is not empty', ' 1. Pop a cell from the stack and make it a current cell', ' 2. If the current cell has any neighbours which have not been visited', ' 1. Push the current cell to the stack', ' 2. Choose one of the unvisited neighbours', ' 3. Remove the wall between the current cell and the chosen cell', ' 4. Mark the chosen cell as visited and push it to the stack', ' \"\"\"', '', ' # Default values if no constraints have been passed', ' grid_shape: Coord = np.array(grid_shape)', ' n_total_cells: int = int(np.prod(grid_shape))', '', ' n_accessible_cells: int', ' if accessible_cells is None:', ' n_accessible_cells = n_total_cells', ' elif isinstance(accessible_cells, float):', ' assert (', ' accessible_cells <= 1', ' ), f\"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}\"', '', ' n_accessible_cells = int(accessible_cells * n_total_cells)', ' else:', ' assert isinstance(accessible_cells, int)', ' n_accessible_cells = accessible_cells', '', ' if max_tree_depth is None:', ' max_tree_depth = (', ' 2 * n_total_cells', ' ) # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.', ' elif isinstance(max_tree_depth, float):', ' assert (', ' max_tree_depth <= 1', ' ), f\"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}\"', '', ' max_tree_depth = int(max_tree_depth * np.sum(grid_shape))', '', ' # choose a random start coord', ' start_coord = _random_start_coord(grid_shape, start_coord)', '', ' # initialize the maze with no connections', ' connection_list: ConnectionList = np.zeros(', ' (lattice_dim, grid_shape[0], grid_shape[1]), dtype=np.bool_', ' )', '', ' # initialize the stack with the target coord', ' visited_cells: set[tuple[int, int]] = set()', ' visited_cells.add(tuple(start_coord)) # this wasnt a bug after all lol', ' stack: list[Coord] = [start_coord]', '', ' # initialize tree_depth_counter', ' current_tree_depth: int = 1', '', ' # loop until the stack is empty or n_connected_cells is reached', ' while stack and (len(visited_cells) < n_accessible_cells):', ' # get the current coord from the stack', ' current_coord: Coord', ' if randomized_stack:', ' current_coord = stack.pop(random.randint(0, len(stack) - 1))', ' else:', ' current_coord = stack.pop()', '', ' # filter neighbors by being within grid bounds and being unvisited', ' unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [', ' (neighbor, delta)', ' for neighbor, delta in zip(', ' current_coord + NEIGHBORS_MASK, NEIGHBORS_MASK', ' )', ' if (', ' (tuple(neighbor) not in visited_cells)', ' and (0 <= neighbor[0] < grid_shape[0])', ' and (0 <= neighbor[1] < grid_shape[1])', ' )', ' ]', '', \" # don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)\", ' if unvisited_neighbors_deltas and (', ' current_tree_depth <= max_tree_depth / 2', ' ):', \" # if we want a maze without forks, simply don't add the current coord back to the stack\", ' if do_forks and (len(unvisited_neighbors_deltas) > 1):', ' stack.append(current_coord)', '', ' # 
choose one of the unvisited neighbors', ' chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)', '', ' # add connection', ' dim: int = np.argmax(np.abs(delta))', ' # if positive, down/right from current coord', ' # if negative, up/left from current coord (down/right from neighbor)', ' clist_node: Coord = (', ' current_coord if (delta.sum() > 0) else chosen_neighbor', ' )', ' connection_list[dim, clist_node[0], clist_node[1]] = True', '', ' # add to visited cells and stack', ' visited_cells.add(tuple(chosen_neighbor))', ' stack.append(chosen_neighbor)', '', ' # Update current tree depth', ' current_tree_depth += 1', ' else:', ' current_tree_depth -= 1', '', ' output = LatticeMaze(', ' connection_list=connection_list,', ' generation_meta=dict(', ' func_name=\"gen_dfs\",', ' grid_shape=grid_shape,', ' start_coord=start_coord,', ' n_accessible_cells=int(n_accessible_cells),', ' max_tree_depth=int(max_tree_depth),', \" # oh my god this took so long to track down. its almost 5am and I've spent like 2 hours on this bug\", ' # it was checking that len(visited_cells) == n_accessible_cells, but this means that the maze is', ' # treated as fully connected even when it is most certainly not, causing solving the maze to break', ' fully_connected=bool(len(visited_cells) == n_total_cells),', ' visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},', ' ),', ' )', '', ' return output']}, 'maze_ctor_kwargs': {'do_forks': False}, 'grid_shape': (8, 8)}, 'model_cfg': {'__format__': 'BaseGPTConfig(SerializableDataclass)', 'name': 'custom-model', 'act_fn': 'gelu', 'd_model': 128, 'd_head': 32, 'n_layers': 6, 'weight_processing': {'are_layernorms_folded': False, 'are_weights_processed': False}, 'n_heads': 4}, 'train_cfg': {'__format__': 'TrainConfig(SerializableDataclass)', 'name': 'custom-train', 'evals_max_new_tokens': 8, 'validation_dataset_cfg': 100, 'optimizer': 'AdamW', 'optimizer_kwargs': {'lr': 0.001}, 'batch_size': 32, 'dataloader_cfg': {'shuffle': False, 'num_workers': 8, 'drop_last': False}, 'intervals': None, 'intervals_count': {'print_loss': 100, 'checkpoint': 20, 'eval_fast': 100, 'eval_slow': 50}}, 'name': 'hallway-medium', 'pretrainedtokenizer_kwargs': None, 'maze_tokenizer': {'__format__': 'MazeTokenizer(SerializableDataclass)', 'tokenization_mode': 'AOTP_UT_uniform', 'max_grid_size': 8, 'name': 'maze_tokenizer-AOTP_UT_uniform-g8', 'token_arr': ['', '', '', '', '', '', '', '', '<-->', ';', '', '(0,0)', '(0,1)', '(1,0)', '(1,1)', '(0,2)', '(2,0)', '(1,2)', '(2,1)', '(2,2)', '(0,3)', '(3,0)', '(3,1)', '(2,3)', '(3,2)', '(1,3)', '(3,3)', '(0,4)', '(2,4)', '(4,0)', '(1,4)', '(4,1)', '(4,2)', '(3,4)', '(4,3)', '(4,4)', '(0,5)', '(5,0)', '(5,1)', '(2,5)', '(5,2)', '(5,3)', '(4,5)', '(5,4)', '(1,5)', '(3,5)', '(5,5)', '(0,6)', '(2,6)', '(4,6)', '(6,0)', '(1,6)', '(6,1)', '(6,2)', '(3,6)', '(6,3)', '(6,4)', '(5,6)', '(6,5)', '(6,6)', '(0,7)', '(7,0)', '(7,1)', '(2,7)', '(7,2)', '(7,3)', '(4,7)', '(7,4)', '(7,5)', '(6,7)', '(7,6)', '(1,7)', '(3,7)', '(5,7)', '(7,7)'], 'tokenizer_map': {'': 0, '': 1, '': 2, '': 3, '': 4, '': 5, '': 6, '': 7, '<-->': 8, ';': 9, '': 10, '(0,0)': 11, '(0,1)': 12, '(1,0)': 13, '(1,1)': 14, '(0,2)': 15, '(2,0)': 16, '(1,2)': 17, '(2,1)': 18, '(2,2)': 19, '(0,3)': 20, '(3,0)': 21, '(3,1)': 22, '(2,3)': 23, '(3,2)': 24, '(1,3)': 25, '(3,3)': 26, '(0,4)': 27, '(2,4)': 28, '(4,0)': 29, '(1,4)': 30, '(4,1)': 31, '(4,2)': 32, '(3,4)': 33, '(4,3)': 34, '(4,4)': 35, '(0,5)': 36, '(5,0)': 37, '(5,1)': 38, '(2,5)': 39, '(5,2)': 40, '(5,3)': 41, '(4,5)': 
42, '(5,4)': 43, '(1,5)': 44, '(3,5)': 45, '(5,5)': 46, '(0,6)': 47, '(2,6)': 48, '(4,6)': 49, '(6,0)': 50, '(1,6)': 51, '(6,1)': 52, '(6,2)': 53, '(3,6)': 54, '(6,3)': 55, '(6,4)': 56, '(5,6)': 57, '(6,5)': 58, '(6,6)': 59, '(0,7)': 60, '(7,0)': 61, '(7,1)': 62, '(2,7)': 63, '(7,2)': 64, '(7,3)': 65, '(4,7)': 66, '(7,4)': 67, '(7,5)': 68, '(6,7)': 69, '(7,6)': 70, '(1,7)': 71, '(3,7)': 72, '(5,7)': 73, '(7,7)': 74}, 'vocab_size': 75, 'padding_token_index': 10}}\n", - "2023-10-02 04:50:47 INFO Initialized logger\n", - "2023-10-02 04:50:47 INFO Summary logged, getting dataset\n", - "2023-10-02 04:50:47 INFO passed dataset has matching config, using that\n", - "2023-10-02 04:50:47 INFO finished getting training dataset with 3000000 samples\n", - "2023-10-02 04:50:47 INFO got validation dataset by splitting training dataset into 2999900 train and 100 validation samples\n", - "2023-10-02 04:50:47 INFO Loaded 2999900 sequences\n", - "2023-10-02 04:50:47 INFO Creating dataloader\n", - "2023-10-02 04:50:47 INFO finished dataloader, passing to train()\n", - "2023-10-02 04:50:47 INFO Initializing model\n", - "Moving model to device: cuda\n", - "2023-10-02 04:50:47 INFO Initializing optimizer\n", - "2023-10-02 04:50:47 INFO will train for 93747 batches, evals_enabled=True, with intervals: {'print_loss': 937, 'checkpoint': 4687, 'eval_fast': 937, 'eval_slow': 1874}\n", - "2023-10-02 04:50:47 INFO Starting training\n" - ] - }, - { - "ename": "MemoryError", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mMemoryError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[1;32mIn[8], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m result: TrainingResult \u001b[39m=\u001b[39m train_model(\n\u001b[0;32m 2\u001b[0m \tbase_path\u001b[39m=\u001b[39;49mPATH_DATA,\n\u001b[0;32m 3\u001b[0m cfg\u001b[39m=\u001b[39;49mCFG,\n\u001b[0;32m 4\u001b[0m \twandb_project\u001b[39m=\u001b[39;49mWandbProject\u001b[39m.\u001b[39;49mUNDERSTANDING_SEARCH, \u001b[39m# change this to WandbProject.DEMO_NOTEBOOKS!\u001b[39;49;00m\n\u001b[0;32m 5\u001b[0m \tdo_generate_dataset\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m,\n\u001b[0;32m 6\u001b[0m \tdataset_verbose\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[0;32m 7\u001b[0m dataset\u001b[39m=\u001b[39;49mDATASET,\n\u001b[0;32m 8\u001b[0m )\n", - "File \u001b[1;32mF:\\KNC\\maze-transformer\\maze_transformer\\training\\train_model.py:169\u001b[0m, in \u001b[0;36mtrain_model\u001b[1;34m(base_path, wandb_project, cfg, cfg_file, cfg_names, do_generate_dataset, dataset_verbose, dataset, allow_dataset_override, device, help, **kwargs)\u001b[0m\n\u001b[0;32m 166\u001b[0m dataloader: DataLoader \u001b[39m=\u001b[39m get_dataloader(dataset, cfg, logger)\n\u001b[0;32m 168\u001b[0m logger\u001b[39m.\u001b[39mprogress(\u001b[39m\"\u001b[39m\u001b[39mfinished dataloader, passing to train()\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m--> 169\u001b[0m trained_model: ZanjHookedTransformer \u001b[39m=\u001b[39m train(\n\u001b[0;32m 170\u001b[0m cfg\u001b[39m=\u001b[39;49mcfg,\n\u001b[0;32m 171\u001b[0m dataloader\u001b[39m=\u001b[39;49mdataloader,\n\u001b[0;32m 172\u001b[0m logger\u001b[39m=\u001b[39;49mlogger,\n\u001b[0;32m 173\u001b[0m output_dir\u001b[39m=\u001b[39;49moutput_path,\n\u001b[0;32m 174\u001b[0m device\u001b[39m=\u001b[39;49mdevice,\n\u001b[0;32m 175\u001b[0m val_dataset\u001b[39m=\u001b[39;49mval_dataset,\n\u001b[0;32m 
176\u001b[0m )\n\u001b[0;32m 178\u001b[0m \u001b[39mreturn\u001b[39;00m TrainingResult(\n\u001b[0;32m 179\u001b[0m output_path\u001b[39m=\u001b[39moutput_path,\n\u001b[0;32m 180\u001b[0m model\u001b[39m=\u001b[39mtrained_model,\n\u001b[0;32m 181\u001b[0m )\n", - "File \u001b[1;32mF:\\KNC\\maze-transformer\\maze_transformer\\training\\training.py:133\u001b[0m, in \u001b[0;36mtrain\u001b[1;34m(cfg, dataloader, logger, output_dir, device, val_dataset, zanj, model)\u001b[0m\n\u001b[0;32m 130\u001b[0m model\u001b[39m.\u001b[39mtrain()\n\u001b[0;32m 131\u001b[0m logger\u001b[39m.\u001b[39mprogress(\u001b[39m\"\u001b[39m\u001b[39mStarting training\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m--> 133\u001b[0m \u001b[39mfor\u001b[39;00m iteration, batch \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39;49m(dataloader):\n\u001b[0;32m 134\u001b[0m \u001b[39m# forward pass\u001b[39;00m\n\u001b[0;32m 135\u001b[0m \u001b[39m# ------------------------------\u001b[39;00m\n\u001b[0;32m 136\u001b[0m loss: SingleLoss\n\u001b[0;32m 137\u001b[0m logits: Float[torch\u001b[39m.\u001b[39mTensor, \u001b[39m\"\u001b[39m\u001b[39mbatch pos d_vocab\u001b[39m\u001b[39m\"\u001b[39m]\n", - "File \u001b[1;32mc:\\Users\\mivan\\AppData\\Local\\pypoetry\\Cache\\virtualenvs\\maze-transformer-2cGx2R0F-py3.10\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:441\u001b[0m, in \u001b[0;36mDataLoader.__iter__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 439\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_iterator\n\u001b[0;32m 440\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m--> 441\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_get_iterator()\n", - "File \u001b[1;32mc:\\Users\\mivan\\AppData\\Local\\pypoetry\\Cache\\virtualenvs\\maze-transformer-2cGx2R0F-py3.10\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:388\u001b[0m, in \u001b[0;36mDataLoader._get_iterator\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 386\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m 387\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcheck_worker_number_rationality()\n\u001b[1;32m--> 388\u001b[0m \u001b[39mreturn\u001b[39;00m _MultiProcessingDataLoaderIter(\u001b[39mself\u001b[39;49m)\n", - "File \u001b[1;32mc:\\Users\\mivan\\AppData\\Local\\pypoetry\\Cache\\virtualenvs\\maze-transformer-2cGx2R0F-py3.10\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:1042\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter.__init__\u001b[1;34m(self, loader)\u001b[0m\n\u001b[0;32m 1035\u001b[0m w\u001b[39m.\u001b[39mdaemon \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n\u001b[0;32m 1036\u001b[0m \u001b[39m# NB: Process.start() actually take some time as it needs to\u001b[39;00m\n\u001b[0;32m 1037\u001b[0m \u001b[39m# start a process and pass the arguments over via a pipe.\u001b[39;00m\n\u001b[0;32m 1038\u001b[0m \u001b[39m# Therefore, we only add a worker to self._workers list after\u001b[39;00m\n\u001b[0;32m 1039\u001b[0m \u001b[39m# it started, so that we do not call .join() if program dies\u001b[39;00m\n\u001b[0;32m 1040\u001b[0m \u001b[39m# before it starts, and __del__ tries to join but will get:\u001b[39;00m\n\u001b[0;32m 1041\u001b[0m \u001b[39m# AssertionError: can only join a started process.\u001b[39;00m\n\u001b[1;32m-> 1042\u001b[0m w\u001b[39m.\u001b[39;49mstart()\n\u001b[0;32m 1043\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_index_queues\u001b[39m.\u001b[39mappend(index_queue)\n\u001b[0;32m 1044\u001b[0m 
\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_workers\u001b[39m.\u001b[39mappend(w)\n", - "File \u001b[1;32mC:\\Python\\Python3_10\\lib\\multiprocessing\\process.py:121\u001b[0m, in \u001b[0;36mBaseProcess.start\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 118\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mnot\u001b[39;00m _current_process\u001b[39m.\u001b[39m_config\u001b[39m.\u001b[39mget(\u001b[39m'\u001b[39m\u001b[39mdaemon\u001b[39m\u001b[39m'\u001b[39m), \\\n\u001b[0;32m 119\u001b[0m \u001b[39m'\u001b[39m\u001b[39mdaemonic processes are not allowed to have children\u001b[39m\u001b[39m'\u001b[39m\n\u001b[0;32m 120\u001b[0m _cleanup()\n\u001b[1;32m--> 121\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_popen \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_Popen(\u001b[39mself\u001b[39;49m)\n\u001b[0;32m 122\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sentinel \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_popen\u001b[39m.\u001b[39msentinel\n\u001b[0;32m 123\u001b[0m \u001b[39m# Avoid a refcycle if the target function holds an indirect\u001b[39;00m\n\u001b[0;32m 124\u001b[0m \u001b[39m# reference to the process object (see bpo-30775)\u001b[39;00m\n", - "File \u001b[1;32mC:\\Python\\Python3_10\\lib\\multiprocessing\\context.py:224\u001b[0m, in \u001b[0;36mProcess._Popen\u001b[1;34m(process_obj)\u001b[0m\n\u001b[0;32m 222\u001b[0m \u001b[39m@staticmethod\u001b[39m\n\u001b[0;32m 223\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_Popen\u001b[39m(process_obj):\n\u001b[1;32m--> 224\u001b[0m \u001b[39mreturn\u001b[39;00m _default_context\u001b[39m.\u001b[39;49mget_context()\u001b[39m.\u001b[39;49mProcess\u001b[39m.\u001b[39;49m_Popen(process_obj)\n", - "File \u001b[1;32mC:\\Python\\Python3_10\\lib\\multiprocessing\\context.py:327\u001b[0m, in \u001b[0;36mSpawnProcess._Popen\u001b[1;34m(process_obj)\u001b[0m\n\u001b[0;32m 324\u001b[0m \u001b[39m@staticmethod\u001b[39m\n\u001b[0;32m 325\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_Popen\u001b[39m(process_obj):\n\u001b[0;32m 326\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39m.\u001b[39;00m\u001b[39mpopen_spawn_win32\u001b[39;00m \u001b[39mimport\u001b[39;00m Popen\n\u001b[1;32m--> 327\u001b[0m \u001b[39mreturn\u001b[39;00m Popen(process_obj)\n", - "File \u001b[1;32mC:\\Python\\Python3_10\\lib\\multiprocessing\\popen_spawn_win32.py:93\u001b[0m, in \u001b[0;36mPopen.__init__\u001b[1;34m(self, process_obj)\u001b[0m\n\u001b[0;32m 91\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m 92\u001b[0m reduction\u001b[39m.\u001b[39mdump(prep_data, to_child)\n\u001b[1;32m---> 93\u001b[0m reduction\u001b[39m.\u001b[39;49mdump(process_obj, to_child)\n\u001b[0;32m 94\u001b[0m \u001b[39mfinally\u001b[39;00m:\n\u001b[0;32m 95\u001b[0m set_spawning_popen(\u001b[39mNone\u001b[39;00m)\n", - "File \u001b[1;32mC:\\Python\\Python3_10\\lib\\multiprocessing\\reduction.py:60\u001b[0m, in \u001b[0;36mdump\u001b[1;34m(obj, file, protocol)\u001b[0m\n\u001b[0;32m 58\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdump\u001b[39m(obj, file, protocol\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m):\n\u001b[0;32m 59\u001b[0m \u001b[39m \u001b[39m\u001b[39m'''Replacement for pickle.dump() using ForkingPickler.'''\u001b[39;00m\n\u001b[1;32m---> 60\u001b[0m ForkingPickler(file, protocol)\u001b[39m.\u001b[39;49mdump(obj)\n", - "\u001b[1;31mMemoryError\u001b[0m: " - ] - } - ], - "source": [ - "result: TrainingResult = train_model(\n", - "\tbase_path=PATH_DATA,\n", - " cfg=CFG,\n", - 
"\twandb_project=WandbProject.UNDERSTANDING_SEARCH, # change this to WandbProject.DEMO_NOTEBOOKS!\n", - "\tdo_generate_dataset=False,\n", - "\tdataset_verbose=True,\n", - " dataset=DATASET,\n", - ")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "maze-transformer", - "language": "python", - "name": "maze-transformer" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.1" - }, - "orig_nbformat": 4 + "{\n", + " \"logger_cfg\": {\n", + " \"output_dir\": \"../data/custom_2024-01-26-17-07-30\",\n", + " \"cfg.name\": \"custom\",\n", + " \"data_cfg.name\": \"custom-dataset\",\n", + " \"train_cfg.name\": \"custom-train\",\n", + " \"model_cfg.name\": \"custom-model\",\n", + " \"cfg_summary\": {\n", + " \"name\": \"custom\",\n", + " \"dataset_cfg\": {\n", + " \"name\": \"custom-dataset\",\n", + " \"fname\": \"custom-dataset-g6-n10K-a_dfs-h74621\",\n", + " \"sdc_hash\": 30263437081808734576530068202485318808735944160954955199116767321681194074621,\n", + " \"seed\": 42,\n", + " \"seq_len_min\": 1,\n", + " \"seq_len_max\": 512,\n", + " \"applied_filters\": [],\n", + " \"grid_n\": 6,\n", + " \"grid_shape\": [\n", + " 6,\n", + " 6\n", + " ],\n", + " \"n_mazes\": 10000,\n", + " \"maze_ctor_name\": \"gen_dfs\",\n", + " \"maze_ctor_kwargs\": {}\n", + " },\n", + " \"model_cfg\": {\n", + " \"name\": \"custom-model\",\n", + " \"act_fn\": \"gelu\",\n", + " \"d_model\": 8,\n", + " \"d_head\": 4,\n", + " \"n_layers\": 2,\n", + " \"weight_processing\": {\n", + " \"are_layernorms_folded\": false,\n", + " \"are_weights_processed\": false\n", + " },\n", + " \"n_heads\": 2\n", + " },\n", + " \"train_cfg\": {\n", + " \"name\": \"custom-train\",\n", + " \"optimizer\": \"RMSprop\",\n", + " \"optimizer_kwargs\": {\n", + " \"lr\": 0.0001\n", + " },\n", + " \"batch_size\": 16,\n", + " \"dataloader_cfg\": {\n", + " \"shuffle\": true,\n", + " \"num_workers\": 0,\n", + " \"drop_last\": false\n", + " },\n", + " \"intervals\": null,\n", + " \"intervals_count\": {\n", + " \"print_loss\": 100,\n", + " \"checkpoint\": 5,\n", + " \"eval_fast\": 10,\n", + " \"eval_slow\": 5\n", + " },\n", + " \"evals_max_new_tokens\": 8,\n", + " \"validation_dataset_cfg\": null\n", + " },\n", + " \"pretrainedtokenizer_kwargs\": null,\n", + " \"maze_tokenizer\": {\n", + " \"tokenization_mode\": \"AOTP_UT_uniform\",\n", + " \"max_grid_size\": 6,\n", + " \"vocab_size\": 47\n", + " }\n", + " },\n", + " \"cfg\": {\n", + " \"__format__\": \"ConfigHolder(SerializableDataclass)\",\n", + " \"dataset_cfg\": {\n", + " \"__format__\": \"MazeDatasetConfig(SerializableDataclass)\",\n", + " \"name\": \"custom-dataset\",\n", + " \"seq_len_min\": 1,\n", + " \"seq_len_max\": 512,\n", + " \"seed\": 42,\n", + " \"applied_filters\": [],\n", + " \"grid_n\": 6,\n", + " \"n_mazes\": 10000,\n", + " \"maze_ctor\": {\n", + " \"__name__\": \"gen_dfs\",\n", + " \"__module__\": \"maze_dataset.generation.generators\",\n", + " \"__doc__\": [\n", + " \"generate a lattice maze using depth first search, iterative\",\n", + " \"\",\n", + " \" # Arguments\",\n", + " \" - `grid_shape: Coord`: the shape of the grid\",\n", + " \" - `lattice_dim: int`: the dimension of the lattice\",\n", + " \" (default: `2`)\",\n", + " \" - `accessible_cells: int | float |None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. 
if a float, asserts it is <= 1 and treats it as a proportion of **total cells**\",\n", + " \" (default: `None`)\",\n", + " \" - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * accessible_cells`. if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**\",\n", + " \" (default: `None`)\",\n", + " \" - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.\",\n", + " \" - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.\",\n", + " \"\",\n", + " \" # algorithm\",\n", + " \" 1. Choose the initial cell, mark it as visited and push it to the stack\",\n", + " \" 2. While the stack is not empty\",\n", + " \" 1. Pop a cell from the stack and make it a current cell\",\n", + " \" 2. If the current cell has any neighbours which have not been visited\",\n", + " \" 1. Push the current cell to the stack\",\n", + " \" 2. Choose one of the unvisited neighbours\",\n", + " \" 3. Remove the wall between the current cell and the chosen cell\",\n", + " \" 4. Mark the chosen cell as visited and push it to the stack\",\n", + " \" \"\n", + " ],\n", + " \"source_code\": [\n", + " \" @staticmethod\",\n", + " \" def gen_dfs(\",\n", + " \" grid_shape: Coord,\",\n", + " \" lattice_dim: int = 2,\",\n", + " \" accessible_cells: int | float | None = None,\",\n", + " \" max_tree_depth: int | float | None = None,\",\n", + " \" do_forks: bool = True,\",\n", + " \" randomized_stack: bool = False,\",\n", + " \" start_coord: Coord | None = None,\",\n", + " \" ) -> LatticeMaze:\",\n", + " \" \\\"\\\"\\\"generate a lattice maze using depth first search, iterative\",\n", + " \"\",\n", + " \" # Arguments\",\n", + " \" - `grid_shape: Coord`: the shape of the grid\",\n", + " \" - `lattice_dim: int`: the dimension of the lattice\",\n", + " \" (default: `2`)\",\n", + " \" - `accessible_cells: int | float |None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**\",\n", + " \" (default: `None`)\",\n", + " \" - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * accessible_cells`. if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**\",\n", + " \" (default: `None`)\",\n", + " \" - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.\",\n", + " \" - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.\",\n", + " \"\",\n", + " \" # algorithm\",\n", + " \" 1. Choose the initial cell, mark it as visited and push it to the stack\",\n", + " \" 2. While the stack is not empty\",\n", + " \" 1. Pop a cell from the stack and make it a current cell\",\n", + " \" 2. If the current cell has any neighbours which have not been visited\",\n", + " \" 1. Push the current cell to the stack\",\n", + " \" 2. Choose one of the unvisited neighbours\",\n", + " \" 3. Remove the wall between the current cell and the chosen cell\",\n", + " \" 4. 
Mark the chosen cell as visited and push it to the stack\",\n", + " \" \\\"\\\"\\\"\",\n", + " \"\",\n", + " \" # Default values if no constraints have been passed\",\n", + " \" grid_shape: Coord = np.array(grid_shape)\",\n", + " \" n_total_cells: int = int(np.prod(grid_shape))\",\n", + " \"\",\n", + " \" n_accessible_cells: int\",\n", + " \" if accessible_cells is None:\",\n", + " \" n_accessible_cells = n_total_cells\",\n", + " \" elif isinstance(accessible_cells, float):\",\n", + " \" assert (\",\n", + " \" accessible_cells <= 1\",\n", + " \" ), f\\\"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}\\\"\",\n", + " \"\",\n", + " \" n_accessible_cells = int(accessible_cells * n_total_cells)\",\n", + " \" else:\",\n", + " \" assert isinstance(accessible_cells, int)\",\n", + " \" n_accessible_cells = accessible_cells\",\n", + " \"\",\n", + " \" if max_tree_depth is None:\",\n", + " \" max_tree_depth = (\",\n", + " \" 2 * n_total_cells\",\n", + " \" ) # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.\",\n", + " \" elif isinstance(max_tree_depth, float):\",\n", + " \" assert (\",\n", + " \" max_tree_depth <= 1\",\n", + " \" ), f\\\"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}\\\"\",\n", + " \"\",\n", + " \" max_tree_depth = int(max_tree_depth * np.sum(grid_shape))\",\n", + " \"\",\n", + " \" # choose a random start coord\",\n", + " \" start_coord = _random_start_coord(grid_shape, start_coord)\",\n", + " \"\",\n", + " \" # initialize the maze with no connections\",\n", + " \" connection_list: ConnectionList = np.zeros(\",\n", + " \" (lattice_dim, grid_shape[0], grid_shape[1]), dtype=np.bool_\",\n", + " \" )\",\n", + " \"\",\n", + " \" # initialize the stack with the target coord\",\n", + " \" visited_cells: set[tuple[int, int]] = set()\",\n", + " \" visited_cells.add(tuple(start_coord)) # this wasnt a bug after all lol\",\n", + " \" stack: list[Coord] = [start_coord]\",\n", + " \"\",\n", + " \" # initialize tree_depth_counter\",\n", + " \" current_tree_depth: int = 1\",\n", + " \"\",\n", + " \" # loop until the stack is empty or n_connected_cells is reached\",\n", + " \" while stack and (len(visited_cells) < n_accessible_cells):\",\n", + " \" # get the current coord from the stack\",\n", + " \" current_coord: Coord\",\n", + " \" if randomized_stack:\",\n", + " \" current_coord = stack.pop(random.randint(0, len(stack) - 1))\",\n", + " \" else:\",\n", + " \" current_coord = stack.pop()\",\n", + " \"\",\n", + " \" # filter neighbors by being within grid bounds and being unvisited\",\n", + " \" unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [\",\n", + " \" (neighbor, delta)\",\n", + " \" for neighbor, delta in zip(\",\n", + " \" current_coord + NEIGHBORS_MASK, NEIGHBORS_MASK\",\n", + " \" )\",\n", + " \" if (\",\n", + " \" (tuple(neighbor) not in visited_cells)\",\n", + " \" and (0 <= neighbor[0] < grid_shape[0])\",\n", + " \" and (0 <= neighbor[1] < grid_shape[1])\",\n", + " \" )\",\n", + " \" ]\",\n", + " \"\",\n", + " \" # don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)\",\n", + " \" if unvisited_neighbors_deltas and (\",\n", + " \" current_tree_depth <= max_tree_depth / 2\",\n", + " \" ):\",\n", + " \" # if we want a maze without forks, simply don't add the current coord back to the 
stack\",\n", + " \" if do_forks and (len(unvisited_neighbors_deltas) > 1):\",\n", + " \" stack.append(current_coord)\",\n", + " \"\",\n", + " \" # choose one of the unvisited neighbors\",\n", + " \" chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)\",\n", + " \"\",\n", + " \" # add connection\",\n", + " \" dim: int = np.argmax(np.abs(delta))\",\n", + " \" # if positive, down/right from current coord\",\n", + " \" # if negative, up/left from current coord (down/right from neighbor)\",\n", + " \" clist_node: Coord = (\",\n", + " \" current_coord if (delta.sum() > 0) else chosen_neighbor\",\n", + " \" )\",\n", + " \" connection_list[dim, clist_node[0], clist_node[1]] = True\",\n", + " \"\",\n", + " \" # add to visited cells and stack\",\n", + " \" visited_cells.add(tuple(chosen_neighbor))\",\n", + " \" stack.append(chosen_neighbor)\",\n", + " \"\",\n", + " \" # Update current tree depth\",\n", + " \" current_tree_depth += 1\",\n", + " \" else:\",\n", + " \" current_tree_depth -= 1\",\n", + " \"\",\n", + " \" output = LatticeMaze(\",\n", + " \" connection_list=connection_list,\",\n", + " \" generation_meta=dict(\",\n", + " \" func_name=\\\"gen_dfs\\\",\",\n", + " \" grid_shape=grid_shape,\",\n", + " \" start_coord=start_coord,\",\n", + " \" n_accessible_cells=int(n_accessible_cells),\",\n", + " \" max_tree_depth=int(max_tree_depth),\",\n", + " \" # oh my god this took so long to track down. its almost 5am and I've spent like 2 hours on this bug\",\n", + " \" # it was checking that len(visited_cells) == n_accessible_cells, but this means that the maze is\",\n", + " \" # treated as fully connected even when it is most certainly not, causing solving the maze to break\",\n", + " \" fully_connected=bool(len(visited_cells) == n_total_cells),\",\n", + " \" visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},\",\n", + " \" ),\",\n", + " \" )\",\n", + " \"\",\n", + " \" return output\"\n", + " ]\n", + " },\n", + " \"maze_ctor_kwargs\": {},\n", + " \"grid_shape\": [\n", + " 6,\n", + " 6\n", + " ]\n", + " },\n", + " \"model_cfg\": {\n", + " \"__format__\": \"BaseGPTConfig(SerializableDataclass)\",\n", + " \"name\": \"custom-model\",\n", + " \"act_fn\": \"gelu\",\n", + " \"d_model\": 8,\n", + " \"d_head\": 4,\n", + " \"n_layers\": 2,\n", + " \"weight_processing\": {\n", + " \"are_layernorms_folded\": false,\n", + " \"are_weights_processed\": false\n", + " },\n", + " \"n_heads\": 2\n", + " },\n", + " \"train_cfg\": {\n", + " \"__format__\": \"TrainConfig(SerializableDataclass)\",\n", + " \"name\": \"custom-train\",\n", + " \"evals_max_new_tokens\": 8,\n", + " \"validation_dataset_cfg\": null,\n", + " \"optimizer\": \"RMSprop\",\n", + " \"optimizer_kwargs\": {\n", + " \"lr\": 0.0001\n", + " },\n", + " \"batch_size\": 16,\n", + " \"dataloader_cfg\": {\n", + " \"shuffle\": true,\n", + " \"num_workers\": 0,\n", + " \"drop_last\": false\n", + " },\n", + " \"intervals\": null,\n", + " \"intervals_count\": {\n", + " \"print_loss\": 100,\n", + " \"checkpoint\": 5,\n", + " \"eval_fast\": 10,\n", + " \"eval_slow\": 5\n", + " }\n", + " },\n", + " \"name\": \"custom\",\n", + " \"pretrainedtokenizer_kwargs\": null,\n", + " \"maze_tokenizer\": {\n", + " \"__format__\": \"MazeTokenizer(SerializableDataclass)\",\n", + " \"tokenization_mode\": \"AOTP_UT_uniform\",\n", + " \"max_grid_size\": 6,\n", + " \"name\": \"maze_tokenizer-AOTP_UT_uniform-g6\",\n", + " \"token_arr\": [\n", + " \"\",\n", + " \"\",\n", + " \"\",\n", + " \"\",\n", + " \"\",\n", + " \"\",\n", + " \"\",\n", + " 
\"\",\n", + " \"<-->\",\n", + " \";\",\n", + " \"\",\n", + " \"(0,0)\",\n", + " \"(0,1)\",\n", + " \"(1,0)\",\n", + " \"(1,1)\",\n", + " \"(0,2)\",\n", + " \"(2,0)\",\n", + " \"(1,2)\",\n", + " \"(2,1)\",\n", + " \"(2,2)\",\n", + " \"(0,3)\",\n", + " \"(3,0)\",\n", + " \"(3,1)\",\n", + " \"(2,3)\",\n", + " \"(3,2)\",\n", + " \"(1,3)\",\n", + " \"(3,3)\",\n", + " \"(0,4)\",\n", + " \"(2,4)\",\n", + " \"(4,0)\",\n", + " \"(1,4)\",\n", + " \"(4,1)\",\n", + " \"(4,2)\",\n", + " \"(3,4)\",\n", + " \"(4,3)\",\n", + " \"(4,4)\",\n", + " \"(0,5)\",\n", + " \"(5,0)\",\n", + " \"(5,1)\",\n", + " \"(2,5)\",\n", + " \"(5,2)\",\n", + " \"(5,3)\",\n", + " \"(4,5)\",\n", + " \"(5,4)\",\n", + " \"(1,5)\",\n", + " \"(3,5)\",\n", + " \"(5,5)\"\n", + " ],\n", + " \"tokenizer_map\": {\n", + " \"\": 0,\n", + " \"\": 1,\n", + " \"\": 2,\n", + " \"\": 3,\n", + " \"\": 4,\n", + " \"\": 5,\n", + " \"\": 6,\n", + " \"\": 7,\n", + " \"<-->\": 8,\n", + " \";\": 9,\n", + " \"\": 10,\n", + " \"(0,0)\": 11,\n", + " \"(0,1)\": 12,\n", + " \"(1,0)\": 13,\n", + " \"(1,1)\": 14,\n", + " \"(0,2)\": 15,\n", + " \"(2,0)\": 16,\n", + " \"(1,2)\": 17,\n", + " \"(2,1)\": 18,\n", + " \"(2,2)\": 19,\n", + " \"(0,3)\": 20,\n", + " \"(3,0)\": 21,\n", + " \"(3,1)\": 22,\n", + " \"(2,3)\": 23,\n", + " \"(3,2)\": 24,\n", + " \"(1,3)\": 25,\n", + " \"(3,3)\": 26,\n", + " \"(0,4)\": 27,\n", + " \"(2,4)\": 28,\n", + " \"(4,0)\": 29,\n", + " \"(1,4)\": 30,\n", + " \"(4,1)\": 31,\n", + " \"(4,2)\": 32,\n", + " \"(3,4)\": 33,\n", + " \"(4,3)\": 34,\n", + " \"(4,4)\": 35,\n", + " \"(0,5)\": 36,\n", + " \"(5,0)\": 37,\n", + " \"(5,1)\": 38,\n", + " \"(2,5)\": 39,\n", + " \"(5,2)\": 40,\n", + " \"(5,3)\": 41,\n", + " \"(4,5)\": 42,\n", + " \"(5,4)\": 43,\n", + " \"(1,5)\": 44,\n", + " \"(3,5)\": 45,\n", + " \"(5,5)\": 46\n", + " },\n", + " \"vocab_size\": 47,\n", + " \"padding_token_index\": 10\n", + " },\n", + " \"_tokenizer\": \"None\"\n", + " }\n", + " }\n", + "}\n", + "Summary logged, getting dataset\n", + "passed dataset has matching config, using that\n", + "finished getting training dataset with 10000 samples\n", + "Loaded 10000 sequences\n", + "Creating dataloader\n", + "finished dataloader, passing to train()\n", + "Initializing model\n", + "Moving model to device: cpu\n", + "{\n", + " \"device\": \"cpu\",\n", + " \"model.device\": \"cpu\"\n", + "}\n", + "Initializing optimizer\n", + "{\n", + " \"model_n_params\": 1536\n", + "}\n", + "{\n", + " \"n_batches\": 625,\n", + " \"n_samples\": 10000,\n", + " \"intervals\": {\n", + " \"print_loss\": 6,\n", + " \"checkpoint\": 125,\n", + " \"eval_fast\": Infinity,\n", + " \"eval_slow\": Infinity\n", + " }\n", + "}\n", + "will train for 625 batches, evals_enabled=False, with intervals: {'print_loss': 6, 'checkpoint': 125, 'eval_fast': inf, 'eval_slow': inf}\n", + "Starting training\n", + "{\n", + " \"loss\": 4.145845413208008\n", + "}\n", + "iteration 0/625: loss=4.146\n", + "Saving model checkpoint to ../data/custom_2024-01-26-17-07-30/checkpoints/model.iter_0.zanj\n", + "../data/custom_2024-01-26-17-07-30/checkpoints/model.iter_0.zanj\n", + "{\n", + " \"loss\": 4.090298652648926\n", + "}\n", + "{\n", + " \"loss\": 4.042042255401611\n", + "}\n", + "{\n", + " \"loss\": 4.065291881561279\n", + "}\n", + "{\n", + " \"loss\": 4.051484107971191\n", + "}\n", + "{\n", + " \"loss\": 4.080485820770264\n", + "}\n", + "{\n", + " \"loss\": 4.032511234283447\n", + "}\n", + "iteration 6/625: loss=4.033\n", + "{\n", + " \"loss\": 4.0358500480651855\n", + "}\n", + "{\n", + " \"loss\": 4.030624866485596\n", + 
"}\n", + "{\n", + " \"loss\": 4.025818824768066\n", + "}\n", + "{\n", + " \"loss\": 4.013055324554443\n", + "}\n", + "{\n", + " \"loss\": 4.024381637573242\n", + "}\n", + "{\n", + " \"loss\": 3.9660725593566895\n", + "}\n", + "iteration 12/625: loss=3.966\n", + "{\n", + " \"loss\": 4.025335311889648\n", + "}\n", + "{\n", + " \"loss\": 3.976882219314575\n", + "}\n", + "{\n", + " \"loss\": 3.9642093181610107\n", + "}\n", + "{\n", + " \"loss\": 3.987898349761963\n", + "}\n", + "{\n", + " \"loss\": 3.9674484729766846\n", + "}\n", + "{\n", + " \"loss\": 3.9408040046691895\n", + "}\n", + "iteration 18/625: loss=3.941\n", + "{\n", + " \"loss\": 3.9355008602142334\n", + "}\n", + "{\n", + " \"loss\": 3.9518754482269287\n", + "}\n", + "{\n", + " \"loss\": 3.990654945373535\n", + "}\n", + "{\n", + " \"loss\": 3.942714214324951\n", + "}\n", + "{\n", + " \"loss\": 3.9597270488739014\n", + "}\n", + "{\n", + " \"loss\": 3.958251953125\n", + "}\n", + "iteration 24/625: loss=3.958\n", + "{\n", + " \"loss\": 3.9349920749664307\n", + "}\n", + "{\n", + " \"loss\": 3.9327073097229004\n", + "}\n", + "{\n", + " \"loss\": 3.9494831562042236\n", + "}\n", + "{\n", + " \"loss\": 3.947068452835083\n", + "}\n", + "{\n", + " \"loss\": 3.9052605628967285\n", + "}\n", + "{\n", + " \"loss\": 3.8726985454559326\n", + "}\n", + "iteration 30/625: loss=3.873\n", + "{\n", + " \"loss\": 3.9395759105682373\n", + "}\n", + "{\n", + " \"loss\": 3.908691167831421\n", + "}\n", + "{\n", + " \"loss\": 3.9320783615112305\n", + "}\n", + "{\n", + " \"loss\": 3.9372875690460205\n", + "}\n", + "{\n", + " \"loss\": 3.929636240005493\n", + "}\n", + "{\n", + " \"loss\": 3.905517578125\n", + "}\n", + "iteration 36/625: loss=3.906\n", + "{\n", + " \"loss\": 3.9328231811523438\n", + "}\n", + "{\n", + " \"loss\": 3.904160499572754\n", + "}\n", + "{\n", + " \"loss\": 3.90518856048584\n", + "}\n", + "{\n", + " \"loss\": 3.9624059200286865\n", + "}\n", + "{\n", + " \"loss\": 3.8829054832458496\n", + "}\n", + "{\n", + " \"loss\": 3.8683793544769287\n", + "}\n", + "iteration 42/625: loss=3.868\n", + "{\n", + " \"loss\": 3.9059863090515137\n", + "}\n", + "{\n", + " \"loss\": 3.895970106124878\n", + "}\n", + "{\n", + " \"loss\": 3.89255690574646\n", + "}\n", + "{\n", + " \"loss\": 3.9100496768951416\n", + "}\n", + "{\n", + " \"loss\": 3.868455648422241\n", + "}\n", + "{\n", + " \"loss\": 3.876054048538208\n", + "}\n", + "iteration 48/625: loss=3.876\n", + "{\n", + " \"loss\": 3.868861675262451\n", + "}\n", + "{\n", + " \"loss\": 3.878199577331543\n", + "}\n", + "{\n", + " \"loss\": 3.8640007972717285\n", + "}\n", + "{\n", + " \"loss\": 3.8807713985443115\n", + "}\n", + "{\n", + " \"loss\": 3.823042392730713\n", + "}\n", + "{\n", + " \"loss\": 3.8687191009521484\n", + "}\n", + "iteration 54/625: loss=3.869\n", + "{\n", + " \"loss\": 3.8719592094421387\n", + "}\n", + "{\n", + " \"loss\": 3.8490495681762695\n", + "}\n", + "{\n", + " \"loss\": 3.867501735687256\n", + "}\n", + "{\n", + " \"loss\": 3.8567469120025635\n", + "}\n", + "{\n", + " \"loss\": 3.8736319541931152\n", + "}\n", + "{\n", + " \"loss\": 3.896731376647949\n", + "}\n", + "iteration 60/625: loss=3.897\n", + "{\n", + " \"loss\": 3.865617275238037\n", + "}\n", + "{\n", + " \"loss\": 3.8121097087860107\n", + "}\n", + "{\n", + " \"loss\": 3.8652961254119873\n", + "}\n", + "{\n", + " \"loss\": 3.8064985275268555\n", + "}\n", + "{\n", + " \"loss\": 3.866581678390503\n", + "}\n", + "{\n", + " \"loss\": 3.8052070140838623\n", + "}\n", + "iteration 66/625: loss=3.805\n", + "{\n", + " \"loss\": 
3.8451247215270996\n", + "}\n", + "{\n", + " \"loss\": 3.846853733062744\n", + "}\n", + "{\n", + " \"loss\": 3.8454911708831787\n", + "}\n", + "{\n", + " \"loss\": 3.8833065032958984\n", + "}\n", + "{\n", + " \"loss\": 3.7974774837493896\n", + "}\n", + "{\n", + " \"loss\": 3.81105637550354\n", + "}\n", + "iteration 72/625: loss=3.811\n", + "{\n", + " \"loss\": 3.783374309539795\n", + "}\n", + "{\n", + " \"loss\": 3.9005792140960693\n", + "}\n", + "{\n", + " \"loss\": 3.8105275630950928\n", + "}\n", + "{\n", + " \"loss\": 3.8295047283172607\n", + "}\n", + "{\n", + " \"loss\": 3.8742551803588867\n", + "}\n", + "{\n", + " \"loss\": 3.8290231227874756\n", + "}\n", + "iteration 78/625: loss=3.829\n", + "{\n", + " \"loss\": 3.800978183746338\n", + "}\n", + "{\n", + " \"loss\": 3.7828478813171387\n", + "}\n", + "{\n", + " \"loss\": 3.7965502738952637\n", + "}\n", + "{\n", + " \"loss\": 3.828066110610962\n", + "}\n", + "{\n", + " \"loss\": 3.810558795928955\n", + "}\n", + "{\n", + " \"loss\": 3.7995452880859375\n", + "}\n", + "iteration 84/625: loss=3.800\n", + "{\n", + " \"loss\": 3.838773250579834\n", + "}\n", + "{\n", + " \"loss\": 3.7926597595214844\n", + "}\n", + "{\n", + " \"loss\": 3.8162167072296143\n", + "}\n", + "{\n", + " \"loss\": 3.813662528991699\n", + "}\n", + "{\n", + " \"loss\": 3.7697629928588867\n", + "}\n", + "{\n", + " \"loss\": 3.8062360286712646\n", + "}\n", + "iteration 90/625: loss=3.806\n", + "{\n", + " \"loss\": 3.806680917739868\n", + "}\n", + "{\n", + " \"loss\": 3.8160486221313477\n", + "}\n", + "{\n", + " \"loss\": 3.7416698932647705\n", + "}\n", + "{\n", + " \"loss\": 3.7349443435668945\n", + "}\n", + "{\n", + " \"loss\": 3.790437936782837\n", + "}\n", + "{\n", + " \"loss\": 3.7562358379364014\n", + "}\n", + "iteration 96/625: loss=3.756\n", + "{\n", + " \"loss\": 3.7964577674865723\n", + "}\n", + "{\n", + " \"loss\": 3.766005039215088\n", + "}\n", + "{\n", + " \"loss\": 3.7669262886047363\n", + "}\n", + "{\n", + " \"loss\": 3.8161680698394775\n", + "}\n", + "{\n", + " \"loss\": 3.702213764190674\n", + "}\n", + "{\n", + " \"loss\": 3.7755918502807617\n", + "}\n", + "iteration 102/625: loss=3.776\n", + "{\n", + " \"loss\": 3.8101937770843506\n", + "}\n", + "{\n", + " \"loss\": 3.7848360538482666\n", + "}\n", + "{\n", + " \"loss\": 3.768131732940674\n", + "}\n", + "{\n", + " \"loss\": 3.762727975845337\n", + "}\n", + "{\n", + " \"loss\": 3.754617929458618\n", + "}\n", + "{\n", + " \"loss\": 3.749685764312744\n", + "}\n", + "iteration 108/625: loss=3.750\n", + "{\n", + " \"loss\": 3.776214361190796\n", + "}\n", + "{\n", + " \"loss\": 3.7450432777404785\n", + "}\n", + "{\n", + " \"loss\": 3.771043062210083\n", + "}\n", + "{\n", + " \"loss\": 3.7456398010253906\n", + "}\n", + "{\n", + " \"loss\": 3.7395594120025635\n", + "}\n", + "{\n", + " \"loss\": 3.7372689247131348\n", + "}\n", + "iteration 114/625: loss=3.737\n", + "{\n", + " \"loss\": 3.776867628097534\n", + "}\n", + "{\n", + " \"loss\": 3.7638683319091797\n", + "}\n", + "{\n", + " \"loss\": 3.7403507232666016\n", + "}\n", + "{\n", + " \"loss\": 3.7880120277404785\n", + "}\n", + "{\n", + " \"loss\": 3.770235300064087\n", + "}\n", + "{\n", + " \"loss\": 3.771207094192505\n", + "}\n", + "iteration 120/625: loss=3.771\n", + "{\n", + " \"loss\": 3.7122116088867188\n", + "}\n", + "{\n", + " \"loss\": 3.725720167160034\n", + "}\n", + "{\n", + " \"loss\": 3.7478432655334473\n", + "}\n", + "{\n", + " \"loss\": 3.7195987701416016\n", + "}\n", + "{\n", + " \"loss\": 3.7561354637145996\n", + "}\n", + "Saving model checkpoint 
to ../data/custom_2024-01-26-17-07-30/checkpoints/model.iter_125.zanj\n", + "../data/custom_2024-01-26-17-07-30/checkpoints/model.iter_125.zanj\n", + "{\n", + " \"loss\": 3.725956439971924\n", + "}\n", + "iteration 126/625: loss=3.726\n", + "{\n", + " \"loss\": 3.75323486328125\n", + "}\n", + "{\n", + " \"loss\": 3.7052669525146484\n", + "}\n", + "{\n", + " \"loss\": 3.762143135070801\n", + "}\n", + "{\n", + " \"loss\": 3.7057151794433594\n", + "}\n", + "{\n", + " \"loss\": 3.725200891494751\n", + "}\n", + "{\n", + " \"loss\": 3.6983983516693115\n", + "}\n", + "iteration 132/625: loss=3.698\n", + "{\n", + " \"loss\": 3.779067277908325\n", + "}\n", + "{\n", + " \"loss\": 3.7259557247161865\n", + "}\n", + "{\n", + " \"loss\": 3.728921890258789\n", + "}\n", + "{\n", + " \"loss\": 3.731539249420166\n", + "}\n", + "{\n", + " \"loss\": 3.70192813873291\n", + "}\n", + "{\n", + " \"loss\": 3.7398624420166016\n", + "}\n", + "iteration 138/625: loss=3.740\n", + "{\n", + " \"loss\": 3.725560188293457\n", + "}\n", + "{\n", + " \"loss\": 3.7320871353149414\n", + "}\n", + "{\n", + " \"loss\": 3.711059808731079\n", + "}\n", + "{\n", + " \"loss\": 3.6905457973480225\n", + "}\n", + "{\n", + " \"loss\": 3.7459778785705566\n", + "}\n", + "{\n", + " \"loss\": 3.728306770324707\n", + "}\n", + "iteration 144/625: loss=3.728\n", + "{\n", + " \"loss\": 3.678739309310913\n", + "}\n", + "{\n", + " \"loss\": 3.7131574153900146\n", + "}\n", + "{\n", + " \"loss\": 3.719914674758911\n", + "}\n", + "{\n", + " \"loss\": 3.71771240234375\n", + "}\n", + "{\n", + " \"loss\": 3.7404749393463135\n", + "}\n", + "{\n", + " \"loss\": 3.720479726791382\n", + "}\n", + "iteration 150/625: loss=3.720\n", + "{\n", + " \"loss\": 3.6947784423828125\n", + "}\n", + "{\n", + " \"loss\": 3.672067403793335\n", + "}\n", + "{\n", + " \"loss\": 3.714975118637085\n", + "}\n", + "{\n", + " \"loss\": 3.664724111557007\n", + "}\n", + "{\n", + " \"loss\": 3.6324193477630615\n", + "}\n", + "{\n", + " \"loss\": 3.6797707080841064\n", + "}\n", + "iteration 156/625: loss=3.680\n", + "{\n", + " \"loss\": 3.6846160888671875\n", + "}\n", + "{\n", + " \"loss\": 3.704012870788574\n", + "}\n", + "{\n", + " \"loss\": 3.6912505626678467\n", + "}\n", + "{\n", + " \"loss\": 3.6645476818084717\n", + "}\n", + "{\n", + " \"loss\": 3.710772752761841\n", + "}\n", + "{\n", + " \"loss\": 3.6693687438964844\n", + "}\n", + "iteration 162/625: loss=3.669\n", + "{\n", + " \"loss\": 3.6792848110198975\n", + "}\n", + "{\n", + " \"loss\": 3.6849517822265625\n", + "}\n", + "{\n", + " \"loss\": 3.685225248336792\n", + "}\n", + "{\n", + " \"loss\": 3.6591010093688965\n", + "}\n", + "{\n", + " \"loss\": 3.653116464614868\n", + "}\n", + "{\n", + " \"loss\": 3.6980695724487305\n", + "}\n", + "iteration 168/625: loss=3.698\n", + "{\n", + " \"loss\": 3.689469814300537\n", + "}\n", + "{\n", + " \"loss\": 3.6920344829559326\n", + "}\n", + "{\n", + " \"loss\": 3.6612255573272705\n", + "}\n", + "{\n", + " \"loss\": 3.6608269214630127\n", + "}\n", + "{\n", + " \"loss\": 3.6531434059143066\n", + "}\n", + "{\n", + " \"loss\": 3.675853967666626\n", + "}\n", + "iteration 174/625: loss=3.676\n", + "{\n", + " \"loss\": 3.6956088542938232\n", + "}\n", + "{\n", + " \"loss\": 3.7043049335479736\n", + "}\n", + "{\n", + " \"loss\": 3.7047536373138428\n", + "}\n", + "{\n", + " \"loss\": 3.666978597640991\n", + "}\n", + "{\n", + " \"loss\": 3.6446926593780518\n", + "}\n", + "{\n", + " \"loss\": 3.685283899307251\n", + "}\n", + "iteration 180/625: loss=3.685\n", + "{\n", + " \"loss\": 
3.6662838459014893\n", + "}\n", + "{\n", + " \"loss\": 3.6601884365081787\n", + "}\n", + "{\n", + " \"loss\": 3.6706175804138184\n", + "}\n", + "{\n", + " \"loss\": 3.6529171466827393\n", + "}\n", + "{\n", + " \"loss\": 3.6476707458496094\n", + "}\n", + "{\n", + " \"loss\": 3.6399025917053223\n", + "}\n", + "iteration 186/625: loss=3.640\n", + "{\n", + " \"loss\": 3.634774684906006\n", + "}\n", + "{\n", + " \"loss\": 3.6840593814849854\n", + "}\n", + "{\n", + " \"loss\": 3.6319735050201416\n", + "}\n", + "{\n", + " \"loss\": 3.625915288925171\n", + "}\n", + "{\n", + " \"loss\": 3.682969570159912\n", + "}\n", + "{\n", + " \"loss\": 3.610685348510742\n", + "}\n", + "iteration 192/625: loss=3.611\n", + "{\n", + " \"loss\": 3.642317771911621\n", + "}\n", + "{\n", + " \"loss\": 3.6446373462677\n", + "}\n", + "{\n", + " \"loss\": 3.646674633026123\n", + "}\n", + "{\n", + " \"loss\": 3.635993719100952\n", + "}\n", + "{\n", + " \"loss\": 3.673973798751831\n", + "}\n", + "{\n", + " \"loss\": 3.610124349594116\n", + "}\n", + "iteration 198/625: loss=3.610\n", + "{\n", + " \"loss\": 3.658254623413086\n", + "}\n", + "{\n", + " \"loss\": 3.598567008972168\n", + "}\n", + "{\n", + " \"loss\": 3.606065273284912\n", + "}\n", + "{\n", + " \"loss\": 3.6174495220184326\n", + "}\n", + "{\n", + " \"loss\": 3.629673719406128\n", + "}\n", + "{\n", + " \"loss\": 3.6077795028686523\n", + "}\n", + "iteration 204/625: loss=3.608\n", + "{\n", + " \"loss\": 3.591167449951172\n", + "}\n", + "{\n", + " \"loss\": 3.6513991355895996\n", + "}\n", + "{\n", + " \"loss\": 3.6198103427886963\n", + "}\n", + "{\n", + " \"loss\": 3.650005340576172\n", + "}\n", + "{\n", + " \"loss\": 3.6053307056427\n", + "}\n", + "{\n", + " \"loss\": 3.6474227905273438\n", + "}\n", + "iteration 210/625: loss=3.647\n", + "{\n", + " \"loss\": 3.5669491291046143\n", + "}\n", + "{\n", + " \"loss\": 3.642470121383667\n", + "}\n", + "{\n", + " \"loss\": 3.5641608238220215\n", + "}\n", + "{\n", + " \"loss\": 3.6208417415618896\n", + "}\n", + "{\n", + " \"loss\": 3.650099277496338\n", + "}\n", + "{\n", + " \"loss\": 3.585646867752075\n", + "}\n", + "iteration 216/625: loss=3.586\n", + "{\n", + " \"loss\": 3.604459524154663\n", + "}\n", + "{\n", + " \"loss\": 3.5727243423461914\n", + "}\n", + "{\n", + " \"loss\": 3.617661476135254\n", + "}\n", + "{\n", + " \"loss\": 3.5823872089385986\n", + "}\n", + "{\n", + " \"loss\": 3.5778603553771973\n", + "}\n", + "{\n", + " \"loss\": 3.572519540786743\n", + "}\n", + "iteration 222/625: loss=3.573\n", + "{\n", + " \"loss\": 3.5884761810302734\n", + "}\n", + "{\n", + " \"loss\": 3.618865728378296\n", + "}\n", + "{\n", + " \"loss\": 3.5182416439056396\n", + "}\n", + "{\n", + " \"loss\": 3.5659565925598145\n", + "}\n", + "{\n", + " \"loss\": 3.555257558822632\n", + "}\n", + "{\n", + " \"loss\": 3.5684943199157715\n", + "}\n", + "iteration 228/625: loss=3.568\n", + "{\n", + " \"loss\": 3.599360942840576\n", + "}\n", + "{\n", + " \"loss\": 3.5706636905670166\n", + "}\n", + "{\n", + " \"loss\": 3.564620018005371\n", + "}\n", + "{\n", + " \"loss\": 3.6170847415924072\n", + "}\n", + "{\n", + " \"loss\": 3.5960912704467773\n", + "}\n", + "{\n", + " \"loss\": 3.5908167362213135\n", + "}\n", + "iteration 234/625: loss=3.591\n", + "{\n", + " \"loss\": 3.6147122383117676\n", + "}\n", + "{\n", + " \"loss\": 3.6076581478118896\n", + "}\n", + "{\n", + " \"loss\": 3.564779043197632\n", + "}\n", + "{\n", + " \"loss\": 3.63839054107666\n", + "}\n", + "{\n", + " \"loss\": 3.563035488128662\n", + "}\n", + "{\n", + " \"loss\": 
3.587372064590454\n", + "}\n", + "iteration 240/625: loss=3.587\n", + "{\n", + " \"loss\": 3.6065943241119385\n", + "}\n", + "{\n", + " \"loss\": 3.5902259349823\n", + "}\n", + "{\n", + " \"loss\": 3.5971500873565674\n", + "}\n", + "{\n", + " \"loss\": 3.5634138584136963\n", + "}\n", + "{\n", + " \"loss\": 3.567080020904541\n", + "}\n", + "{\n", + " \"loss\": 3.5710599422454834\n", + "}\n", + "iteration 246/625: loss=3.571\n", + "{\n", + " \"loss\": 3.5993292331695557\n", + "}\n", + "{\n", + " \"loss\": 3.6017613410949707\n", + "}\n", + "{\n", + " \"loss\": 3.5757882595062256\n", + "}\n", + "{\n", + " \"loss\": 3.5465891361236572\n", + "}\n", + "Saving model checkpoint to ../data/custom_2024-01-26-17-07-30/checkpoints/model.iter_250.zanj\n", + "../data/custom_2024-01-26-17-07-30/checkpoints/model.iter_250.zanj\n", + "{\n", + " \"loss\": 3.5793814659118652\n", + "}\n", + "{\n", + " \"loss\": 3.582940101623535\n", + "}\n", + "iteration 252/625: loss=3.583\n", + "{\n", + " \"loss\": 3.552014112472534\n", + "}\n", + "{\n", + " \"loss\": 3.571335792541504\n", + "}\n", + "{\n", + " \"loss\": 3.5893428325653076\n", + "}\n", + "{\n", + " \"loss\": 3.574636220932007\n", + "}\n", + "{\n", + " \"loss\": 3.559948682785034\n", + "}\n", + "{\n", + " \"loss\": 3.5834758281707764\n", + "}\n", + "iteration 258/625: loss=3.583\n", + "{\n", + " \"loss\": 3.554215908050537\n", + "}\n", + "{\n", + " \"loss\": 3.558847427368164\n", + "}\n", + "{\n", + " \"loss\": 3.584425449371338\n", + "}\n", + "{\n", + " \"loss\": 3.543684959411621\n", + "}\n", + "{\n", + " \"loss\": 3.5651254653930664\n", + "}\n", + "{\n", + " \"loss\": 3.5355935096740723\n", + "}\n", + "iteration 264/625: loss=3.536\n", + "{\n", + " \"loss\": 3.523069143295288\n", + "}\n", + "{\n", + " \"loss\": 3.5814030170440674\n", + "}\n", + "{\n", + " \"loss\": 3.5656020641326904\n", + "}\n", + "{\n", + " \"loss\": 3.5658321380615234\n", + "}\n", + "{\n", + " \"loss\": 3.492781639099121\n", + "}\n", + "{\n", + " \"loss\": 3.5331077575683594\n", + "}\n", + "iteration 270/625: loss=3.533\n", + "{\n", + " \"loss\": 3.5576202869415283\n", + "}\n", + "{\n", + " \"loss\": 3.5550408363342285\n", + "}\n", + "{\n", + " \"loss\": 3.535871744155884\n", + "}\n", + "{\n", + " \"loss\": 3.5877583026885986\n", + "}\n", + "{\n", + " \"loss\": 3.5213778018951416\n", + "}\n", + "{\n", + " \"loss\": 3.534348249435425\n", + "}\n", + "iteration 276/625: loss=3.534\n", + "{\n", + " \"loss\": 3.5657341480255127\n", + "}\n", + "{\n", + " \"loss\": 3.4933602809906006\n", + "}\n", + "{\n", + " \"loss\": 3.5339643955230713\n", + "}\n", + "{\n", + " \"loss\": 3.5067875385284424\n", + "}\n", + "{\n", + " \"loss\": 3.5079421997070312\n", + "}\n", + "{\n", + " \"loss\": 3.522272825241089\n", + "}\n", + "iteration 282/625: loss=3.522\n", + "{\n", + " \"loss\": 3.504040002822876\n", + "}\n", + "{\n", + " \"loss\": 3.5006015300750732\n", + "}\n", + "{\n", + " \"loss\": 3.549135684967041\n", + "}\n", + "{\n", + " \"loss\": 3.5503053665161133\n", + "}\n", + "{\n", + " \"loss\": 3.538738965988159\n", + "}\n", + "{\n", + " \"loss\": 3.497903347015381\n", + "}\n", + "iteration 288/625: loss=3.498\n", + "{\n", + " \"loss\": 3.533748149871826\n", + "}\n", + "{\n", + " \"loss\": 3.5166547298431396\n", + "}\n", + "{\n", + " \"loss\": 3.537839651107788\n", + "}\n", + "{\n", + " \"loss\": 3.5503244400024414\n", + "}\n", + "{\n", + " \"loss\": 3.474510431289673\n", + "}\n", + "{\n", + " \"loss\": 3.534044027328491\n", + "}\n", + "iteration 294/625: loss=3.534\n", + "{\n", + " \"loss\": 
3.4889495372772217\n", + "}\n", + "{\n", + " \"loss\": 3.473132371902466\n", + "}\n", + "{\n", + " \"loss\": 3.5154452323913574\n", + "}\n", + "{\n", + " \"loss\": 3.557323455810547\n", + "}\n", + "{\n", + " \"loss\": 3.5052053928375244\n", + "}\n", + "{\n", + " \"loss\": 3.466235399246216\n", + "}\n", + "iteration 300/625: loss=3.466\n", + "{\n", + " \"loss\": 3.490253448486328\n", + "}\n", + "{\n", + " \"loss\": 3.503431558609009\n", + "}\n", + "{\n", + " \"loss\": 3.514857769012451\n", + "}\n", + "{\n", + " \"loss\": 3.489647388458252\n", + "}\n", + "{\n", + " \"loss\": 3.5147719383239746\n", + "}\n", + "{\n", + " \"loss\": 3.565143346786499\n", + "}\n", + "iteration 306/625: loss=3.565\n", + "{\n", + " \"loss\": 3.54084849357605\n", + "}\n", + "{\n", + " \"loss\": 3.456279993057251\n", + "}\n", + "{\n", + " \"loss\": 3.505070209503174\n", + "}\n", + "{\n", + " \"loss\": 3.535806179046631\n", + "}\n", + "{\n", + " \"loss\": 3.5115747451782227\n", + "}\n", + "{\n", + " \"loss\": 3.4868581295013428\n", + "}\n", + "iteration 312/625: loss=3.487\n", + "{\n", + " \"loss\": 3.517411231994629\n", + "}\n", + "{\n", + " \"loss\": 3.491832733154297\n", + "}\n", + "{\n", + " \"loss\": 3.497812271118164\n", + "}\n", + "{\n", + " \"loss\": 3.4866867065429688\n", + "}\n", + "{\n", + " \"loss\": 3.4111194610595703\n", + "}\n", + "{\n", + " \"loss\": 3.50197172164917\n", + "}\n", + "iteration 318/625: loss=3.502\n", + "{\n", + " \"loss\": 3.480678081512451\n", + "}\n", + "{\n", + " \"loss\": 3.5105457305908203\n", + "}\n", + "{\n", + " \"loss\": 3.504516363143921\n", + "}\n", + "{\n", + " \"loss\": 3.4774861335754395\n", + "}\n", + "{\n", + " \"loss\": 3.497036933898926\n", + "}\n", + "{\n", + " \"loss\": 3.4337422847747803\n", + "}\n", + "iteration 324/625: loss=3.434\n", + "{\n", + " \"loss\": 3.452423334121704\n", + "}\n", + "{\n", + " \"loss\": 3.4866013526916504\n", + "}\n", + "{\n", + " \"loss\": 3.5163846015930176\n", + "}\n", + "{\n", + " \"loss\": 3.49133563041687\n", + "}\n", + "{\n", + " \"loss\": 3.5247466564178467\n", + "}\n", + "{\n", + " \"loss\": 3.447157144546509\n", + "}\n", + "iteration 330/625: loss=3.447\n", + "{\n", + " \"loss\": 3.5415685176849365\n", + "}\n", + "{\n", + " \"loss\": 3.4548070430755615\n", + "}\n", + "{\n", + " \"loss\": 3.509962558746338\n", + "}\n", + "{\n", + " \"loss\": 3.473855495452881\n", + "}\n", + "{\n", + " \"loss\": 3.482057809829712\n", + "}\n", + "{\n", + " \"loss\": 3.481217861175537\n", + "}\n", + "iteration 336/625: loss=3.481\n", + "{\n", + " \"loss\": 3.4592766761779785\n", + "}\n", + "{\n", + " \"loss\": 3.494074821472168\n", + "}\n", + "{\n", + " \"loss\": 3.468418836593628\n", + "}\n", + "{\n", + " \"loss\": 3.498950719833374\n", + "}\n", + "{\n", + " \"loss\": 3.4550371170043945\n", + "}\n", + "{\n", + " \"loss\": 3.4814453125\n", + "}\n", + "iteration 342/625: loss=3.481\n", + "{\n", + " \"loss\": 3.483079433441162\n", + "}\n", + "{\n", + " \"loss\": 3.490959882736206\n", + "}\n", + "{\n", + " \"loss\": 3.4057750701904297\n", + "}\n", + "{\n", + " \"loss\": 3.487997055053711\n", + "}\n", + "{\n", + " \"loss\": 3.4428021907806396\n", + "}\n", + "{\n", + " \"loss\": 3.477642297744751\n", + "}\n", + "iteration 348/625: loss=3.478\n", + "{\n", + " \"loss\": 3.474191427230835\n", + "}\n", + "{\n", + " \"loss\": 3.453618049621582\n", + "}\n", + "{\n", + " \"loss\": 3.4900968074798584\n", + "}\n", + "{\n", + " \"loss\": 3.4885547161102295\n", + "}\n", + "{\n", + " \"loss\": 3.422675132751465\n", + "}\n", + "{\n", + " \"loss\": 
3.5005037784576416\n", + "}\n", + "iteration 354/625: loss=3.501\n", + "{\n", + " \"loss\": 3.4255549907684326\n", + "}\n", + "{\n", + " \"loss\": 3.4251420497894287\n", + "}\n", + "{\n", + " \"loss\": 3.449300765991211\n", + "}\n", + "{\n", + " \"loss\": 3.4667623043060303\n", + "}\n", + "{\n", + " \"loss\": 3.422987222671509\n", + "}\n", + "{\n", + " \"loss\": 3.4083847999572754\n", + "}\n", + "iteration 360/625: loss=3.408\n", + "{\n", + " \"loss\": 3.3724794387817383\n", + "}\n", + "{\n", + " \"loss\": 3.408320665359497\n", + "}\n", + "{\n", + " \"loss\": 3.4054837226867676\n", + "}\n", + "{\n", + " \"loss\": 3.3907241821289062\n", + "}\n", + "{\n", + " \"loss\": 3.466801404953003\n", + "}\n", + "{\n", + " \"loss\": 3.4486517906188965\n", + "}\n", + "iteration 366/625: loss=3.449\n", + "{\n", + " \"loss\": 3.4531421661376953\n", + "}\n", + "{\n", + " \"loss\": 3.452338218688965\n", + "}\n", + "{\n", + " \"loss\": 3.4136343002319336\n", + "}\n", + "{\n", + " \"loss\": 3.4186222553253174\n", + "}\n", + "{\n", + " \"loss\": 3.3868367671966553\n", + "}\n", + "{\n", + " \"loss\": 3.407825231552124\n", + "}\n", + "iteration 372/625: loss=3.408\n", + "{\n", + " \"loss\": 3.4601082801818848\n", + "}\n", + "{\n", + " \"loss\": 3.460134744644165\n", + "}\n", + "{\n", + " \"loss\": 3.4357261657714844\n", + "}\n", + "Saving model checkpoint to ../data/custom_2024-01-26-17-07-30/checkpoints/model.iter_375.zanj\n", + "../data/custom_2024-01-26-17-07-30/checkpoints/model.iter_375.zanj\n", + "{\n", + " \"loss\": 3.413161277770996\n", + "}\n", + "{\n", + " \"loss\": 3.380629062652588\n", + "}\n", + "{\n", + " \"loss\": 3.4016001224517822\n", + "}\n", + "iteration 378/625: loss=3.402\n", + "{\n", + " \"loss\": 3.4652349948883057\n", + "}\n", + "{\n", + " \"loss\": 3.4496147632598877\n", + "}\n", + "{\n", + " \"loss\": 3.4553184509277344\n", + "}\n", + "{\n", + " \"loss\": 3.370265245437622\n", + "}\n", + "{\n", + " \"loss\": 3.483074188232422\n", + "}\n", + "{\n", + " \"loss\": 3.4160759449005127\n", + "}\n", + "iteration 384/625: loss=3.416\n", + "{\n", + " \"loss\": 3.4300057888031006\n", + "}\n", + "{\n", + " \"loss\": 3.4795784950256348\n", + "}\n", + "{\n", + " \"loss\": 3.377732753753662\n", + "}\n", + "{\n", + " \"loss\": 3.404430627822876\n", + "}\n", + "{\n", + " \"loss\": 3.4231934547424316\n", + "}\n", + "{\n", + " \"loss\": 3.3738455772399902\n", + "}\n", + "iteration 390/625: loss=3.374\n", + "{\n", + " \"loss\": 3.4204890727996826\n", + "}\n", + "{\n", + " \"loss\": 3.374424695968628\n", + "}\n", + "{\n", + " \"loss\": 3.4298744201660156\n", + "}\n", + "{\n", + " \"loss\": 3.410593271255493\n", + "}\n", + "{\n", + " \"loss\": 3.4193458557128906\n", + "}\n", + "{\n", + " \"loss\": 3.448474168777466\n", + "}\n", + "iteration 396/625: loss=3.448\n", + "{\n", + " \"loss\": 3.3871846199035645\n", + "}\n", + "{\n", + " \"loss\": 3.3581504821777344\n", + "}\n", + "{\n", + " \"loss\": 3.417536735534668\n", + "}\n", + "{\n", + " \"loss\": 3.4475836753845215\n", + "}\n", + "{\n", + " \"loss\": 3.443405866622925\n", + "}\n", + "{\n", + " \"loss\": 3.3446266651153564\n", + "}\n", + "iteration 402/625: loss=3.345\n", + "{\n", + " \"loss\": 3.3935582637786865\n", + "}\n", + "{\n", + " \"loss\": 3.427907705307007\n", + "}\n", + "{\n", + " \"loss\": 3.411377191543579\n", + "}\n", + "{\n", + " \"loss\": 3.407881259918213\n", + "}\n", + "{\n", + " \"loss\": 3.4133989810943604\n", + "}\n", + "{\n", + " \"loss\": 3.3903441429138184\n", + "}\n", + "iteration 408/625: loss=3.390\n", + "{\n", + " \"loss\": 
3.3988046646118164\n", + "}\n", + "{\n", + " \"loss\": 3.4096291065216064\n", + "}\n", + "{\n", + " \"loss\": 3.375596046447754\n", + "}\n", + "{\n", + " \"loss\": 3.357426643371582\n", + "}\n", + "{\n", + " \"loss\": 3.3525779247283936\n", + "}\n", + "{\n", + " \"loss\": 3.3953049182891846\n", + "}\n", + "iteration 414/625: loss=3.395\n", + "{\n", + " \"loss\": 3.403193712234497\n", + "}\n", + "{\n", + " \"loss\": 3.406960964202881\n", + "}\n", + "{\n", + " \"loss\": 3.4293837547302246\n", + "}\n", + "{\n", + " \"loss\": 3.3330914974212646\n", + "}\n", + "{\n", + " \"loss\": 3.360473871231079\n", + "}\n", + "{\n", + " \"loss\": 3.3816893100738525\n", + "}\n", + "iteration 420/625: loss=3.382\n", + "{\n", + " \"loss\": 3.4100005626678467\n", + "}\n", + "{\n", + " \"loss\": 3.3450844287872314\n", + "}\n", + "{\n", + " \"loss\": 3.37136173248291\n", + "}\n", + "{\n", + " \"loss\": 3.3751778602600098\n", + "}\n", + "{\n", + " \"loss\": 3.3192124366760254\n", + "}\n", + "{\n", + " \"loss\": 3.321544885635376\n", + "}\n", + "iteration 426/625: loss=3.322\n", + "{\n", + " \"loss\": 3.399653196334839\n", + "}\n", + "{\n", + " \"loss\": 3.377204418182373\n", + "}\n", + "{\n", + " \"loss\": 3.4096527099609375\n", + "}\n", + "{\n", + " \"loss\": 3.358639717102051\n", + "}\n", + "{\n", + " \"loss\": 3.3819382190704346\n", + "}\n", + "{\n", + " \"loss\": 3.3932108879089355\n", + "}\n", + "iteration 432/625: loss=3.393\n", + "{\n", + " \"loss\": 3.3844285011291504\n", + "}\n", + "{\n", + " \"loss\": 3.36179780960083\n", + "}\n", + "{\n", + " \"loss\": 3.362386465072632\n", + "}\n", + "{\n", + " \"loss\": 3.3509767055511475\n", + "}\n", + "{\n", + " \"loss\": 3.3593146800994873\n", + "}\n", + "{\n", + " \"loss\": 3.411891222000122\n", + "}\n", + "iteration 438/625: loss=3.412\n", + "{\n", + " \"loss\": 3.3491740226745605\n", + "}\n", + "{\n", + " \"loss\": 3.359356641769409\n", + "}\n", + "{\n", + " \"loss\": 3.337552785873413\n", + "}\n", + "{\n", + " \"loss\": 3.3312859535217285\n", + "}\n", + "{\n", + " \"loss\": 3.3474347591400146\n", + "}\n", + "{\n", + " \"loss\": 3.3581364154815674\n", + "}\n", + "iteration 444/625: loss=3.358\n", + "{\n", + " \"loss\": 3.3299036026000977\n", + "}\n", + "{\n", + " \"loss\": 3.3426268100738525\n", + "}\n", + "{\n", + " \"loss\": 3.3809268474578857\n", + "}\n", + "{\n", + " \"loss\": 3.3525519371032715\n", + "}\n", + "{\n", + " \"loss\": 3.3297677040100098\n", + "}\n", + "{\n", + " \"loss\": 3.3716702461242676\n", + "}\n", + "iteration 450/625: loss=3.372\n", + "{\n", + " \"loss\": 3.343858480453491\n", + "}\n", + "{\n", + " \"loss\": 3.361025810241699\n", + "}\n", + "{\n", + " \"loss\": 3.3318982124328613\n", + "}\n", + "{\n", + " \"loss\": 3.3817684650421143\n", + "}\n", + "{\n", + " \"loss\": 3.3639705181121826\n", + "}\n", + "{\n", + " \"loss\": 3.3484086990356445\n", + "}\n", + "iteration 456/625: loss=3.348\n", + "{\n", + " \"loss\": 3.2939934730529785\n", + "}\n", + "{\n", + " \"loss\": 3.3207147121429443\n", + "}\n", + "{\n", + " \"loss\": 3.325993537902832\n", + "}\n", + "{\n", + " \"loss\": 3.3636460304260254\n", + "}\n", + "{\n", + " \"loss\": 3.3334782123565674\n", + "}\n", + "{\n", + " \"loss\": 3.36348295211792\n", + "}\n", + "iteration 462/625: loss=3.363\n", + "{\n", + " \"loss\": 3.320568323135376\n", + "}\n", + "{\n", + " \"loss\": 3.288194417953491\n", + "}\n", + "{\n", + " \"loss\": 3.3226523399353027\n", + "}\n", + "{\n", + " \"loss\": 3.3174355030059814\n", + "}\n", + "{\n", + " \"loss\": 3.3270134925842285\n", + "}\n", + "{\n", + " \"loss\": 
3.338449239730835\n", + "}\n", + "iteration 468/625: loss=3.338\n", + "{\n", + " \"loss\": 3.3544342517852783\n", + "}\n", + "{\n", + " \"loss\": 3.3198888301849365\n", + "}\n", + "{\n", + " \"loss\": 3.2892091274261475\n", + "}\n", + "{\n", + " \"loss\": 3.3591461181640625\n", + "}\n", + "{\n", + " \"loss\": 3.3585684299468994\n", + "}\n", + "{\n", + " \"loss\": 3.332545042037964\n", + "}\n", + "iteration 474/625: loss=3.333\n", + "{\n", + " \"loss\": 3.358856439590454\n", + "}\n", + "{\n", + " \"loss\": 3.3644022941589355\n", + "}\n", + "{\n", + " \"loss\": 3.2785587310791016\n", + "}\n", + "{\n", + " \"loss\": 3.342656373977661\n", + "}\n", + "{\n", + " \"loss\": 3.320143461227417\n", + "}\n", + "{\n", + " \"loss\": 3.2368345260620117\n", + "}\n", + "iteration 480/625: loss=3.237\n", + "{\n", + " \"loss\": 3.3405699729919434\n", + "}\n", + "{\n", + " \"loss\": 3.2913925647735596\n", + "}\n", + "{\n", + " \"loss\": 3.342508316040039\n", + "}\n", + "{\n", + " \"loss\": 3.3408985137939453\n", + "}\n", + "{\n", + " \"loss\": 3.3632891178131104\n", + "}\n", + "{\n", + " \"loss\": 3.2672674655914307\n", + "}\n", + "iteration 486/625: loss=3.267\n", + "{\n", + " \"loss\": 3.304842948913574\n", + "}\n", + "{\n", + " \"loss\": 3.3040428161621094\n", + "}\n", + "{\n", + " \"loss\": 3.3107056617736816\n", + "}\n", + "{\n", + " \"loss\": 3.296785593032837\n", + "}\n", + "{\n", + " \"loss\": 3.3505702018737793\n", + "}\n", + "{\n", + " \"loss\": 3.2878408432006836\n", + "}\n", + "iteration 492/625: loss=3.288\n", + "{\n", + " \"loss\": 3.2432310581207275\n", + "}\n", + "{\n", + " \"loss\": 3.283074140548706\n", + "}\n", + "{\n", + " \"loss\": 3.3396217823028564\n", + "}\n", + "{\n", + " \"loss\": 3.3540444374084473\n", + "}\n", + "{\n", + " \"loss\": 3.2355384826660156\n", + "}\n", + "{\n", + " \"loss\": 3.320284366607666\n", + "}\n", + "iteration 498/625: loss=3.320\n", + "{\n", + " \"loss\": 3.2541277408599854\n", + "}\n", + "{\n", + " \"loss\": 3.2973036766052246\n", + "}\n", + "Saving model checkpoint to ../data/custom_2024-01-26-17-07-30/checkpoints/model.iter_500.zanj\n", + "../data/custom_2024-01-26-17-07-30/checkpoints/model.iter_500.zanj\n", + "{\n", + " \"loss\": 3.2913644313812256\n", + "}\n", + "{\n", + " \"loss\": 3.289907455444336\n", + "}\n", + "{\n", + " \"loss\": 3.2824289798736572\n", + "}\n", + "{\n", + " \"loss\": 3.322770833969116\n", + "}\n", + "iteration 504/625: loss=3.323\n", + "{\n", + " \"loss\": 3.2661194801330566\n", + "}\n", + "{\n", + " \"loss\": 3.274456262588501\n", + "}\n", + "{\n", + " \"loss\": 3.3232667446136475\n", + "}\n", + "{\n", + " \"loss\": 3.231204032897949\n", + "}\n", + "{\n", + " \"loss\": 3.2885568141937256\n", + "}\n", + "{\n", + " \"loss\": 3.297837257385254\n", + "}\n", + "iteration 510/625: loss=3.298\n", + "{\n", + " \"loss\": 3.2619142532348633\n", + "}\n", + "{\n", + " \"loss\": 3.2104337215423584\n", + "}\n", + "{\n", + " \"loss\": 3.2802133560180664\n", + "}\n", + "{\n", + " \"loss\": 3.2808961868286133\n", + "}\n", + "{\n", + " \"loss\": 3.2775797843933105\n", + "}\n", + "{\n", + " \"loss\": 3.266408681869507\n", + "}\n", + "iteration 516/625: loss=3.266\n", + "{\n", + " \"loss\": 3.299525499343872\n", + "}\n", + "{\n", + " \"loss\": 3.1960971355438232\n", + "}\n", + "{\n", + " \"loss\": 3.2297277450561523\n", + "}\n", + "{\n", + " \"loss\": 3.2547783851623535\n", + "}\n", + "{\n", + " \"loss\": 3.3103413581848145\n", + "}\n", + "{\n", + " \"loss\": 3.3113367557525635\n", + "}\n", + "iteration 522/625: loss=3.311\n", + "{\n", + " \"loss\": 
3.279250144958496\n", + "}\n", + "{\n", + " \"loss\": 3.2500522136688232\n", + "}\n", + "{\n", + " \"loss\": 3.31075119972229\n", + "}\n", + "{\n", + " \"loss\": 3.2506184577941895\n", + "}\n", + "{\n", + " \"loss\": 3.2806665897369385\n", + "}\n", + "{\n", + " \"loss\": 3.2879269123077393\n", + "}\n", + "iteration 528/625: loss=3.288\n", + "{\n", + " \"loss\": 3.30780291557312\n", + "}\n", + "{\n", + " \"loss\": 3.199627637863159\n", + "}\n", + "{\n", + " \"loss\": 3.2415099143981934\n", + "}\n", + "{\n", + " \"loss\": 3.3130781650543213\n", + "}\n", + "{\n", + " \"loss\": 3.2908084392547607\n", + "}\n", + "{\n", + " \"loss\": 3.2460033893585205\n", + "}\n", + "iteration 534/625: loss=3.246\n", + "{\n", + " \"loss\": 3.1655216217041016\n", + "}\n", + "{\n", + " \"loss\": 3.280019998550415\n", + "}\n", + "{\n", + " \"loss\": 3.2473294734954834\n", + "}\n", + "{\n", + " \"loss\": 3.220569610595703\n", + "}\n", + "{\n", + " \"loss\": 3.285362958908081\n", + "}\n", + "{\n", + " \"loss\": 3.256585121154785\n", + "}\n", + "iteration 540/625: loss=3.257\n", + "{\n", + " \"loss\": 3.2261807918548584\n", + "}\n", + "{\n", + " \"loss\": 3.2280471324920654\n", + "}\n", + "{\n", + " \"loss\": 3.21641206741333\n", + "}\n", + "{\n", + " \"loss\": 3.2843263149261475\n", + "}\n", + "{\n", + " \"loss\": 3.2770743370056152\n", + "}\n", + "{\n", + " \"loss\": 3.2083640098571777\n", + "}\n", + "iteration 546/625: loss=3.208\n", + "{\n", + " \"loss\": 3.23732852935791\n", + "}\n", + "{\n", + " \"loss\": 3.2782509326934814\n", + "}\n", + "{\n", + " \"loss\": 3.2230064868927\n", + "}\n", + "{\n", + " \"loss\": 3.2422330379486084\n", + "}\n", + "{\n", + " \"loss\": 3.265923023223877\n", + "}\n", + "{\n", + " \"loss\": 3.1814773082733154\n", + "}\n", + "iteration 552/625: loss=3.181\n", + "{\n", + " \"loss\": 3.2399916648864746\n", + "}\n", + "{\n", + " \"loss\": 3.238511562347412\n", + "}\n", + "{\n", + " \"loss\": 3.2494843006134033\n", + "}\n", + "{\n", + " \"loss\": 3.2363786697387695\n", + "}\n", + "{\n", + " \"loss\": 3.2375986576080322\n", + "}\n", + "{\n", + " \"loss\": 3.2542572021484375\n", + "}\n", + "iteration 558/625: loss=3.254\n", + "{\n", + " \"loss\": 3.2486767768859863\n", + "}\n", + "{\n", + " \"loss\": 3.2123193740844727\n", + "}\n", + "{\n", + " \"loss\": 3.219663381576538\n", + "}\n", + "{\n", + " \"loss\": 3.191136360168457\n", + "}\n", + "{\n", + " \"loss\": 3.2624711990356445\n", + "}\n", + "{\n", + " \"loss\": 3.1900157928466797\n", + "}\n", + "iteration 564/625: loss=3.190\n", + "{\n", + " \"loss\": 3.2229647636413574\n", + "}\n", + "{\n", + " \"loss\": 3.2395780086517334\n", + "}\n", + "{\n", + " \"loss\": 3.2095651626586914\n", + "}\n", + "{\n", + " \"loss\": 3.2494213581085205\n", + "}\n", + "{\n", + " \"loss\": 3.24226713180542\n", + "}\n", + "{\n", + " \"loss\": 3.172736644744873\n", + "}\n", + "iteration 570/625: loss=3.173\n", + "{\n", + " \"loss\": 3.208057403564453\n", + "}\n", + "{\n", + " \"loss\": 3.182804822921753\n", + "}\n", + "{\n", + " \"loss\": 3.269094228744507\n", + "}\n", + "{\n", + " \"loss\": 3.2078652381896973\n", + "}\n", + "{\n", + " \"loss\": 3.23347544670105\n", + "}\n", + "{\n", + " \"loss\": 3.2304015159606934\n", + "}\n", + "iteration 576/625: loss=3.230\n", + "{\n", + " \"loss\": 3.197012424468994\n", + "}\n", + "{\n", + " \"loss\": 3.1724846363067627\n", + "}\n", + "{\n", + " \"loss\": 3.211442232131958\n", + "}\n", + "{\n", + " \"loss\": 3.2226929664611816\n", + "}\n", + "{\n", + " \"loss\": 3.2213873863220215\n", + "}\n", + "{\n", + " \"loss\": 
3.1255741119384766\n", + "}\n", + "iteration 582/625: loss=3.126\n", + "{\n", + " \"loss\": 3.2052719593048096\n", + "}\n", + "{\n", + " \"loss\": 3.176476240158081\n", + "}\n", + "{\n", + " \"loss\": 3.18984055519104\n", + "}\n", + "{\n", + " \"loss\": 3.223639965057373\n", + "}\n", + "{\n", + " \"loss\": 3.265749931335449\n", + "}\n", + "{\n", + " \"loss\": 3.2228267192840576\n", + "}\n", + "iteration 588/625: loss=3.223\n", + "{\n", + " \"loss\": 3.205139398574829\n", + "}\n", + "{\n", + " \"loss\": 3.258054733276367\n", + "}\n", + "{\n", + " \"loss\": 3.1943752765655518\n", + "}\n", + "{\n", + " \"loss\": 3.1747851371765137\n", + "}\n", + "{\n", + " \"loss\": 3.1349353790283203\n", + "}\n", + "{\n", + " \"loss\": 3.1982340812683105\n", + "}\n", + "iteration 594/625: loss=3.198\n", + "{\n", + " \"loss\": 3.1673989295959473\n", + "}\n", + "{\n", + " \"loss\": 3.1701011657714844\n", + "}\n", + "{\n", + " \"loss\": 3.1498870849609375\n", + "}\n", + "{\n", + " \"loss\": 3.1923727989196777\n", + "}\n", + "{\n", + " \"loss\": 3.1492722034454346\n", + "}\n", + "{\n", + " \"loss\": 3.173740863800049\n", + "}\n", + "iteration 600/625: loss=3.174\n", + "{\n", + " \"loss\": 3.2300286293029785\n", + "}\n", + "{\n", + " \"loss\": 3.171640634536743\n", + "}\n", + "{\n", + " \"loss\": 3.1125521659851074\n", + "}\n", + "{\n", + " \"loss\": 3.160708427429199\n", + "}\n", + "{\n", + " \"loss\": 3.110842704772949\n", + "}\n", + "{\n", + " \"loss\": 3.1477622985839844\n", + "}\n", + "iteration 606/625: loss=3.148\n", + "{\n", + " \"loss\": 3.1445677280426025\n", + "}\n", + "{\n", + " \"loss\": 3.128345489501953\n", + "}\n", + "{\n", + " \"loss\": 3.2135703563690186\n", + "}\n", + "{\n", + " \"loss\": 3.1802139282226562\n", + "}\n", + "{\n", + " \"loss\": 3.207259178161621\n", + "}\n", + "{\n", + " \"loss\": 3.143918752670288\n", + "}\n", + "iteration 612/625: loss=3.144\n", + "{\n", + " \"loss\": 3.1434335708618164\n", + "}\n", + "{\n", + " \"loss\": 3.1368257999420166\n", + "}\n", + "{\n", + " \"loss\": 3.205310344696045\n", + "}\n", + "{\n", + " \"loss\": 3.1917953491210938\n", + "}\n", + "{\n", + " \"loss\": 3.162038564682007\n", + "}\n", + "{\n", + " \"loss\": 3.2053184509277344\n", + "}\n", + "iteration 618/625: loss=3.205\n", + "{\n", + " \"loss\": 3.13569712638855\n", + "}\n", + "{\n", + " \"loss\": 3.1539762020111084\n", + "}\n", + "{\n", + " \"loss\": 3.143233299255371\n", + "}\n", + "{\n", + " \"loss\": 3.21600341796875\n", + "}\n", + "{\n", + " \"loss\": 3.1591880321502686\n", + "}\n", + "{\n", + " \"loss\": 3.147446870803833\n", + "}\n", + "iteration 624/625: loss=3.147\n", + "Saving final model to ../data/custom_2024-01-26-17-07-30/model.final.zanj\n", + "../data/custom_2024-01-26-17-07-30/model.final.zanj\n", + "Done training!\n" + ] + } + ], + "source": [ + "result: TrainingResult = train_model(\n", + "\tbase_path=PATH_DATA,\n", + " cfg=CFG,\n", + "\twandb_project=None, # change this to WandbProject.DEMO_NOTEBOOKS!\n", + "\tdo_generate_dataset=False,\n", + "\tdataset_verbose=True,\n", + " dataset=DATASET,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } From 
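The logger-optional behavior that this patch introduces, and that the next patch lints and tests, reduces to one small dispatch helper. The sketch below is illustrative only: `DummyLogger`, `make_log`, and the inlined `pprint_summary` are hypothetical stand-ins, not maze-transformer APIs. It shows the same fallback logic the patch relies on: route a message to the named logger method when a logger exists, otherwise print to stdout.

import json
from typing import Optional


class DummyLogger:
    """hypothetical stand-in for WandbLogger, for illustration only"""

    def progress(self, msg: str) -> None:
        print(f"[wandb] {msg}")

    def summary(self, msg: dict, **kwargs) -> None:
        print(f"[wandb] summary: {msg}")


def pprint_summary(msg: dict) -> None:
    # stand-in for muutils.mlutils.pprint_summary: pretty-print a dict
    print(json.dumps(msg, indent=2))


def make_log(logger: Optional[DummyLogger]):
    # same dispatch as the log() helper in this patch: call the logger method
    # named by log_type when a logger exists, otherwise fall back to printing
    def log(msg, log_type: str = "progress", **kwargs):
        if logger is not None:
            getattr(logger, log_type)(msg, **kwargs)
        elif isinstance(msg, dict):
            pprint_summary(msg)
        else:
            print(msg)

    return log


# offline run, as in the notebook above (wandb_project=None): plain stdout
log = make_log(None)
log("Initialized without logger")
log({"model_n_params": 1536}, log_type="summary")

Passing `wandb_project=None` to `train_model()` selects this fallback branch, which is exactly what the notebook run above exercised.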
From 4d56e31ce9b3c84900812fae6b14cbf327d30b81 Mon Sep 17 00:00:00 2001
From: naveenarun
Date: Fri, 26 Jan 2024 18:07:22 -0600
Subject: [PATCH 3/4] lint and add unit test for training without wandb

---
 maze_transformer/training/train_model.py       | 42 +++++++++----------
 maze_transformer/training/training.py          | 27 ++++++------
 .../training/test_offline_training.py          | 18 ++++++++
 3 files changed, 50 insertions(+), 37 deletions(-)
 create mode 100644 tests/unit/maze_transformer/training/test_offline_training.py

diff --git a/maze_transformer/training/train_model.py b/maze_transformer/training/train_model.py
index 4dab03f5..d423c33d 100644
--- a/maze_transformer/training/train_model.py
+++ b/maze_transformer/training/train_model.py
@@ -59,7 +59,7 @@ def train_model(
     - model config names: {model_cfg_names}
     - train config names: {train_cfg_names}
     """
-    USES_LOGGER : bool = (wandb_project is not None)
+    USES_LOGGER: bool = wandb_project is not None
 
     if help:
         print(train_model.__doc__)
@@ -87,17 +87,17 @@ def train_model(
 
     # set up logger
     logger_cfg_dict = dict(
-            logger_cfg={
-                "output_dir": output_path.as_posix(),
-                "cfg.name": cfg.name,
-                "data_cfg.name": cfg.dataset_cfg.name,
-                "train_cfg.name": cfg.train_cfg.name,
-                "model_cfg.name": cfg.model_cfg.name,
-                "cfg_summary": cfg.summary(),
-                "cfg": cfg.serialize(),
-            },
-    )
-
+        logger_cfg={
+            "output_dir": output_path.as_posix(),
+            "cfg.name": cfg.name,
+            "data_cfg.name": cfg.dataset_cfg.name,
+            "train_cfg.name": cfg.train_cfg.name,
+            "model_cfg.name": cfg.model_cfg.name,
+            "cfg_summary": cfg.summary(),
+            "cfg": cfg.serialize(),
+        },
+    )
+
     # Set up logger if wanb project is specified
     if USES_LOGGER:
         logger: WandbLogger = WandbLogger.create(
@@ -108,8 +108,8 @@ def train_model(
         logger.progress("Initialized logger")
     else:
         logger = None
-
-    def log(msg: str | dict, log_type: str = 'progress', **kwargs):
+
+    def log(msg: str | dict, log_type: str = "progress", **kwargs):
         # Convenience function to let training routine work whether or not
         # logger exists
         if logger:
@@ -121,7 +121,7 @@ def log(msg: str | dict, log_type: str = 'progress', **kwargs):
         else:
             print(msg)
 
-    log(logger_cfg_dict, log_type='summary')
+    log(logger_cfg_dict, log_type="summary")
     log("Summary logged, getting dataset")
 
     # load dataset
@@ -137,15 +137,15 @@ def log(msg: str | dict, log_type: str = 'progress', **kwargs):
             log(f"passed dataset has matching config, using that")
         else:
             if allow_dataset_override:
-                log(f"passed dataset has different config than cfg.dataset_cfg, but allow_dataset_override is True, so using passed dataset")
+                log(
+                    f"passed dataset has different config than cfg.dataset_cfg, but allow_dataset_override is True, so using passed dataset"
+                )
             else:
                 raise ValueError(
                     f"dataset has different config than cfg.dataset_cfg, and allow_dataset_override is False"
                 )
 
-    log(
-        f"finished getting training dataset with {len(dataset)} samples"
-    )
+    log(f"finished getting training dataset with {len(dataset)} samples")
 
     # validation dataset, if applicable
     val_dataset: MazeDataset | None = None
@@ -178,9 +178,7 @@ def log(msg: str | dict, log_type: str = 'progress', **kwargs):
             local_base_path=base_path,
             verbose=dataset_verbose,
         )
-        log(
-            f"got custom validation dataset with {len(val_dataset)} samples"
-        )
+        log(f"got custom validation dataset with {len(val_dataset)} samples")
 
     # get dataloader and then train
     dataloader: DataLoader = get_dataloader(dataset, cfg, logger)
diff --git a/maze_transformer/training/training.py b/maze_transformer/training/training.py
index 92af184a..430ebcbc 100644
--- a/maze_transformer/training/training.py
+++ b/maze_transformer/training/training.py
@@ -6,8 +6,8 @@
 from jaxtyping import Float
 from maze_dataset import MazeDataset, SolvedMaze
 from maze_dataset.tokenization import MazeTokenizer
-from muutils.statcounter import StatCounter
 from muutils.mlutils import pprint_summary
+from muutils.statcounter import StatCounter
 from torch.utils.data import DataLoader
 from transformer_lens.HookedTransformer import SingleLoss
 from zanj import ZANJ
@@ -33,7 +33,7 @@ def log_progress(msg):
             logger.progress(msg)
         else:
             print(msg)
-
+
     if len(dataset) == 0:
         raise ValueError(f"Dataset is empty: {len(dataset) = }")
     log_progress(f"Loaded {len(dataset)} sequences")
@@ -67,8 +67,7 @@ def train(
     zanj: ZANJ | None = None,
     model: ZanjHookedTransformer | None = None,
 ) -> ZanjHookedTransformer:
-
-    def log(msg: str | dict, log_type: str = 'progress', **kwargs):
+    def log(msg: str | dict, log_type: str = "progress", **kwargs):
         # Convenience function to let training routine work whether or not
         # logger exists
         if logger:
@@ -79,12 +78,12 @@ def log(msg: str | dict, log_type: str = 'progress', **kwargs):
             pprint_summary(msg)
         else:
             print(msg)
-
+
     # initialize
     # ==============================
     if zanj is None:
         zanj = ZANJ()
-
+
     # init model & optimizer
     if model is None:
         log(f"Initializing model")
@@ -93,14 +92,14 @@ def log(msg: str | dict, log_type: str = 'progress', **kwargs):
     else:
         log("Using existing model")
 
-    log({"device": str(device), "model.device": model.cfg.device}, log_type='summary')
+    log({"device": str(device), "model.device": model.cfg.device}, log_type="summary")
     log("Initializing optimizer")
     optimizer: torch.optim.Optimizer = cfg.train_cfg.optimizer(
         model.parameters(),
         **cfg.train_cfg.optimizer_kwargs,
     )
-    log(dict(model_n_params=model.cfg.n_params), log_type='summary')
+    log(dict(model_n_params=model.cfg.n_params), log_type="summary")
 
     # add wandb run url to model
     if logger:
@@ -140,7 +139,7 @@ def log(msg: str | dict, log_type: str = 'progress', **kwargs):
     }
     log(
         {"n_batches": n_batches, "n_samples": n_samples, "intervals": intervals},
-        log_type='summary'
+        log_type="summary",
     )
     log(
         f"will train for {n_batches} batches, {evals_enabled=}, with intervals: {intervals}"
@@ -186,12 +185,10 @@ def log(msg: str | dict, log_type: str = 'progress', **kwargs):
                     max_new_tokens=cfg.train_cfg.evals_max_new_tokens,
                 )
                 metrics.update(scores)
-            log(metrics, log_type='log_metric_hist')
+            log(metrics, log_type="log_metric_hist")
 
         if iteration % intervals["print_loss"] == 0:
-            log(
-                f"iteration {iteration}/{n_batches}: loss={loss.item():.3f}"
-            )
+            log(f"iteration {iteration}/{n_batches}: loss={loss.item():.3f}")
 
         del loss
@@ -207,7 +204,7 @@ def log(msg: str | dict, log_type: str = 'progress', **kwargs):
             zanj.save(model, model_save_path)
             log(
                 model_save_path,
-                log_type = 'upload_model',
+                log_type="upload_model",
                 aliases=["latest", f"iter-{iteration}"],
             )
 
@@ -216,7 +213,7 @@ def log(msg: str | dict, log_type: str = 'progress', **kwargs):
     final_model_path: Path = output_dir / TRAIN_SAVE_FILES.model_final_zanj
     log(f"Saving final model to {final_model_path.as_posix()}")
     zanj.save(model, final_model_path)
-    log(final_model_path, log_type='upload_model', aliases=["latest", "final"])
+    log(final_model_path, log_type="upload_model", aliases=["latest", "final"])
 
     log("Done training!")
 
diff --git a/tests/unit/maze_transformer/training/test_offline_training.py b/tests/unit/maze_transformer/training/test_offline_training.py
new file mode 100644
index 00000000..3703106f
--- /dev/null
+++ b/tests/unit/maze_transformer/training/test_offline_training.py
@@ -0,0 +1,18 @@
+from maze_transformer.training.config import ConfigHolder, ZanjHookedTransformer
+from maze_transformer.training.train_model import TrainingResult, train_model
+
+
+def test_train_model():
+    cfg: ConfigHolder = ConfigHolder.get_config_multisource(
+        cfg_names=("test-g3-n5-a_dfs-h75556", "nano-v1", "test-v1"),
+    )
+    cfg.dataset_cfg.n_mazes = 10
+    result: TrainingResult = train_model(
+        base_path="tests/_temp/test_train_model",
+        wandb_project=None,
+        cfg=cfg,
+        do_generate_dataset=True,
+    )
+
+    assert isinstance(result.model, ZanjHookedTransformer)
+    assert result.model.zanj_model_config == cfg

From 08f60a1fea3a2d0018d54697ed2a90ddfd50e79f Mon Sep 17 00:00:00 2001
From: naveenarun
Date: Sat, 27 Jan 2024 15:56:51 -0600
Subject: [PATCH 4/4] Ran black v24.1.0 on maze_transformers/ and tests/

---
 maze_transformer/evaluation/path_evals.py      |  3 +--
 .../mechinterp/direct_logit_attribution.py     | 19 ++++++-------
 .../mechinterp/logit_attrib_task.py            |  6 ++---
 maze_transformer/mechinterp/logit_diff.py      | 18 ++++++-------
 maze_transformer/mechinterp/plot_attention.py  |  6 ++---
 .../mechinterp/residual_stream_structure.py    | 14 +++++-----
 maze_transformer/training/config.py            | 27 ++++++++++---------
 maze_transformer/training/train_save_files.py  | 10 +++----
 tests/integration/test_eval_model.py           |  1 +
 .../unit/maze_transformer/test_tokenizers.py   |  7 ++---
 .../training/test_model_loading_old.py         |  1 +
 11 files changed, 57 insertions(+), 55 deletions(-)

diff --git a/maze_transformer/evaluation/path_evals.py b/maze_transformer/evaluation/path_evals.py
index 98eff231..361d02c0 100644
--- a/maze_transformer/evaluation/path_evals.py
+++ b/maze_transformer/evaluation/path_evals.py
@@ -23,8 +23,7 @@ def __call__(
         maze: LatticeMaze | None = None,
         solution: CoordArray | None = None,
         prediction: CoordArray | None = None,
-    ) -> float:
-        ...
+    ) -> float: ...
 
 
 def path_as_segments_iter(path: CoordArray) -> typing.Iterable[tuple]:
 
diff --git a/maze_transformer/mechinterp/direct_logit_attribution.py b/maze_transformer/mechinterp/direct_logit_attribution.py
index 98c13d19..f0ce80ef 100644
--- a/maze_transformer/mechinterp/direct_logit_attribution.py
+++ b/maze_transformer/mechinterp/direct_logit_attribution.py
@@ -126,8 +126,9 @@ def plot_direct_logit_attribution(
     answer_tokens: Int[torch.Tensor, "n_mazes"],
     do_neurons: bool = False,
     show: bool = True,
-    layer_index_normalization: typing.Callable[[float, int], float]
-    | None = lambda contrib, layer_idx: contrib,
+    layer_index_normalization: (
+        typing.Callable[[float, int], float] | None
+    ) = lambda contrib, layer_idx: contrib,
 ) -> tuple[plt.Figure, plt.Axes, dict[str, Float[np.ndarray, "layer head/neuron"]]]:
     """compute, process, and plot direct logit attribution
@@ -135,13 +136,13 @@ def plot_direct_logit_attribution(
     by default, its the identity map for contribs:
     `layer_index_normalization: typing.Callable[[float, int], float]|None = lambda contrib, layer_idx: contrib`
     """
-    dla_data: dict[
-        str, Float[np.ndarray, "layer head/neuron"]
-    ] = compute_direct_logit_attribution(
-        model=model,
-        cache=cache,
-        answer_tokens=answer_tokens,
-        do_neurons=do_neurons,
+    dla_data: dict[str, Float[np.ndarray, "layer head/neuron"]] = (
+        compute_direct_logit_attribution(
+            model=model,
+            cache=cache,
+            answer_tokens=answer_tokens,
+            do_neurons=do_neurons,
+        )
     )
     if layer_index_normalization is not None:
         dla_data = {
diff --git a/maze_transformer/mechinterp/logit_attrib_task.py b/maze_transformer/mechinterp/logit_attrib_task.py
index 6ef67363..2b6600e2 100644
--- a/maze_transformer/mechinterp/logit_attrib_task.py
+++ b/maze_transformer/mechinterp/logit_attrib_task.py
@@ -22,8 +22,7 @@ def get_token_first_index(search_token: str, token_list: list[str]) -> int:
 class DLAProtocol(typing.Protocol):
     """should take a dataset's tokens, and return a tuple of (prompts, targets)"""
 
-    def __call__(self, dataset_tokens: list[list[str]], **kwargs) -> TaskSetup:
-        ...
+    def __call__(self, dataset_tokens: list[list[str]], **kwargs) -> TaskSetup: ...
 
 
 class DLAProtocolFixed(typing.Protocol):
@@ -32,8 +31,7 @@ class DLAProtocolFixed(typing.Protocol):
     this variant signifies it's ready to be used -- no keyword arguments are needed
     """
 
-    def __call__(self, dataset_tokens: list[list[str]]) -> TaskSetup:
-        ...
+    def __call__(self, dataset_tokens: list[list[str]]) -> TaskSetup: ...
 
 
 def token_after_fixed_start_token(
 
diff --git a/maze_transformer/mechinterp/logit_diff.py b/maze_transformer/mechinterp/logit_diff.py
index 87976659..9474bbba 100644
--- a/maze_transformer/mechinterp/logit_diff.py
+++ b/maze_transformer/mechinterp/logit_diff.py
@@ -86,9 +86,9 @@ def logit_diff_residual_stream(
     vocab_tensor: Float[torch.Tensor, "d_vocab"] = torch.arange(
         d_vocab, dtype=torch.long
     )
-    vocab_residual_directions: Float[
-        torch.Tensor, "d_vocab d_model"
-    ] = model.tokens_to_residual_directions(vocab_tensor)
+    vocab_residual_directions: Float[torch.Tensor, "d_vocab d_model"] = (
+        model.tokens_to_residual_directions(vocab_tensor)
+    )
     # get embedding of answer tokens
     answer_residual_directions = vocab_residual_directions[tokens_correct]
     # get the directional difference between logits and corrent and logits on {all other tokens, comparison tokens}
@@ -108,12 +108,12 @@ def logit_diff_residual_stream(
     ][:, -1, :]
 
     # scaling the values in residual stream with layer norm
-    scaled_final_token_residual_stream: Float[
-        torch.Tensor, "samples d_model"
-    ] = cache.apply_ln_to_stack(
-        final_token_residual_stream,
-        layer=-1,
-        pos_slice=-1,
+    scaled_final_token_residual_stream: Float[torch.Tensor, "samples d_model"] = (
+        cache.apply_ln_to_stack(
+            final_token_residual_stream,
+            layer=-1,
+            pos_slice=-1,
+        )
     )
 
     # measure similarity between the logit diff directions and the residual stream at final layer directions
diff --git a/maze_transformer/mechinterp/plot_attention.py b/maze_transformer/mechinterp/plot_attention.py
index c38c4954..c39abfac 100644
--- a/maze_transformer/mechinterp/plot_attention.py
+++ b/maze_transformer/mechinterp/plot_attention.py
@@ -289,9 +289,9 @@ def mazeplot_attention(
         node_values=node_values,
         color_map=cmap,
         target_token_coord=target_coord,
-        preceeding_tokens_coords=[final_prompt_coord]
-        if final_prompt_coord is not None
-        else None,
+        preceeding_tokens_coords=(
+            [final_prompt_coord] if final_prompt_coord is not None else None
+        ),
         colormap_center=colormap_center_val,
         colormap_max=colormap_max,
         hide_colorbar=hide_colorbar,
diff --git a/maze_transformer/mechinterp/residual_stream_structure.py b/maze_transformer/mechinterp/residual_stream_structure.py
index 1396e0e3..6dffb196 100644
--- a/maze_transformer/mechinterp/residual_stream_structure.py
+++ b/maze_transformer/mechinterp/residual_stream_structure.py
@@ -68,9 +68,11 @@ def process_tokens_for_pca(tokenizer: MazeTokenizer) -> list[TokenPlottingInfo]:
         tokenizer.token_arr,
         tokens_coords,
         [
-            coordinate_to_color(coord, max_val=max_coord)
-            if isinstance(coord, tuple)
-            else (0.0, 1.0, 0.0)
+            (
+                coordinate_to_color(coord, max_val=max_coord)
+                if isinstance(coord, tuple)
+                else (0.0, 1.0, 0.0)
+            )
             for coord in tokens_coords
         ],
     )
@@ -249,9 +251,9 @@ def compute_distances_and_correlation(
     # embedding_distances /= embedding_distances.max()
 
     # Convert the distances to a square matrix
-    embedding_distances_matrix: Float[
-        np.ndarray, "n_coord_tokens n_coord_tokens"
-    ] = squareform(embedding_distances)
+    embedding_distances_matrix: Float[np.ndarray, "n_coord_tokens n_coord_tokens"] = (
+        squareform(embedding_distances)
+    )
 
     # Calculate the correlation between the embedding and coordinate distances
     coordinate_coordinates: Float[np.ndarray, "n_coord_tokens 2"] = np.array(
diff --git a/maze_transformer/training/config.py b/maze_transformer/training/config.py
index 26e27365..e05f1a52 100644
--- a/maze_transformer/training/config.py
+++ b/maze_transformer/training/config.py
@@ -214,7 +214,9 @@ def get_intervals(
             )
         except ValueError as e:
-            _debug_vals: str = f"{dataset_n_samples=}, {use_defaults_if_missing=}, {mod_batch_size=},\n{self.intervals=},\n{self.intervals_count=}"
+            _debug_vals: str = (
+                f"{dataset_n_samples=}, {use_defaults_if_missing=}, {mod_batch_size=},\n{self.intervals=},\n{self.intervals_count=}"
+            )
             raise ValueError(f"{_debug_vals}\ntriggered error:\n{e}") from e
 
         # disable if set to 0 or negative
@@ -235,9 +237,9 @@ def get_intervals(
         # actually return the intervals
         if mod_batch_size:
             return {
-                k: max(1, v // self.batch_size)
-                if isinstance(v, int)
-                else v  # if float, leave it as is since its float("inf")
+                k: (
+                    max(1, v // self.batch_size) if isinstance(v, int) else v
+                )  # if float, leave it as is since its float("inf")
                 for k, v in intervals_new.items()
             }
         else:
@@ -459,9 +461,11 @@ def summary(self) -> str:
             "model_cfg": self.model_cfg.summary(),
             "train_cfg": self.train_cfg.summary(),
             "pretrainedtokenizer_kwargs": self.pretrainedtokenizer_kwargs,
-            "maze_tokenizer": self.maze_tokenizer.summary()
-            if self.maze_tokenizer is not None
-            else None,
+            "maze_tokenizer": (
+                self.maze_tokenizer.summary()
+                if self.maze_tokenizer is not None
+                else None
+            ),
         }
 
     @property
@@ -655,12 +659,9 @@ def _load_state_dict_wrapper(
             self.zanj_model_config.model_cfg.weight_processing["are_layernorms_folded"]
             or fold_ln
         )
-        self.zanj_model_config.model_cfg.weight_processing[
-            "are_weights_processed"
-        ] = self.zanj_model_config.model_cfg.weight_processing[
-            "are_weights_processed"
-        ] or (
-            not recover_exact
+        self.zanj_model_config.model_cfg.weight_processing["are_weights_processed"] = (
+            self.zanj_model_config.model_cfg.weight_processing["are_weights_processed"]
+            or (not recover_exact)
         )
 
         self.load_and_process_state_dict(
diff --git a/maze_transformer/training/train_save_files.py b/maze_transformer/training/train_save_files.py
index 8cc833a7..fc1a388a 100644
--- a/maze_transformer/training/train_save_files.py
+++ b/maze_transformer/training/train_save_files.py
@@ -20,12 +20,10 @@ class TRAIN_SAVE_FILES:
     config_holder: str = "config.json"
     checkpoints: str = "checkpoints"
     log: str = "log.jsonl"
-    model_checkpt_zanj: Callable[
-        [int], str
-    ] = lambda iteration: f"model.iter_{iteration}.zanj"
+    model_checkpt_zanj: Callable[[int], str] = (
+        lambda iteration: f"model.iter_{iteration}.zanj"
+    )
     model_final_zanj: str = "model.final.zanj"
-    model_run_dir: Callable[
-        [ConfigHolder], str
-    ] = (
-        lambda cfg: f"{sanitize_fname(cfg.name)}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
-    )
+    model_run_dir: Callable[[ConfigHolder], str] = (
+        lambda cfg: f"{sanitize_fname(cfg.name)}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
+    )
diff --git a/tests/integration/test_eval_model.py b/tests/integration/test_eval_model.py
index 7f2a327b..b97ed9e4 100644
--- a/tests/integration/test_eval_model.py
+++ b/tests/integration/test_eval_model.py
@@ -5,6 +5,7 @@
 a HookedTransformer with folding etc., as they would be
 from just applying the model to the input
 """
+
 import warnings
 from pathlib import Path
diff --git a/tests/unit/maze_transformer/test_tokenizers.py b/tests/unit/maze_transformer/test_tokenizers.py
index 8b717234..a6e7a854 100644
--- a/tests/unit/maze_transformer/test_tokenizers.py
+++ b/tests/unit/maze_transformer/test_tokenizers.py
@@ -4,6 +4,7 @@
 We may want a separate set of tests for different tokenization schemes
 """
+
 from itertools import product
 
 import torch
@@ -81,11 +82,11 @@ def test_tokenization_encoding(
 )
 def test_to_ascii(tok_mode):
     # Check that the ascii encoding works for multiple different inputs
-    maze_str_tokens: list[
-        str
-    ] = """<ADJLIST_START> (1,1) <--> (2,1) ; (2,0) <--> (1,0) ;
-    (0,1) <--> (0,0) ;
+    maze_str_tokens: list[str] = (
+        """<ADJLIST_START> (1,1) <--> (2,1) ; (2,0) <--> (1,0) ; (0,1) <--> (0,0) ;
     (2,2) <--> (2,1) ; (2,0) <--> (2,1) ; (0,2) <--> (1,2) ; (0,0) <--> (1,0) ; (0,2) <--> (0,1) ; <ADJLIST_END> <ORIGIN_START> (0,0) <ORIGIN_END> <TARGET_START> (2,1) <TARGET_END> <PATH_START> (0,0) (1,0) (2,0) (2,1) <PATH_END>""".split()
+    )
 
     target: list[str] = [
         "#######",
diff --git a/tests/unit/maze_transformer/training/test_model_loading_old.py b/tests/unit/maze_transformer/training/test_model_loading_old.py
index d289510b..0e32ca72 100644
--- a/tests/unit/maze_transformer/training/test_model_loading_old.py
+++ b/tests/unit/maze_transformer/training/test_model_loading_old.py
@@ -1,6 +1,7 @@
 """
 test loading of old style models
 """
+
 import json
 
 import pytest
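With the logger made optional and the regression test in place, training can run fully offline. A minimal driver script in the spirit of `test_offline_training.py` might look like the following; the config names are the ones the new unit test uses, while `base_path` here is an arbitrary local output directory (an assumption, not something the patches prescribe).

from maze_transformer.training.config import ConfigHolder
from maze_transformer.training.train_model import TrainingResult, train_model

# same multi-source config as the new unit test in patch 3
cfg: ConfigHolder = ConfigHolder.get_config_multisource(
    cfg_names=("test-g3-n5-a_dfs-h75556", "nano-v1", "test-v1"),
)
cfg.dataset_cfg.n_mazes = 10  # keep the generated dataset tiny

# wandb_project=None routes all progress/summary output to stdout
result: TrainingResult = train_model(
    base_path="data/offline_demo",  # hypothetical local output directory
    wandb_project=None,
    cfg=cfg,
    do_generate_dataset=True,
)
print(type(result.model).__name__)  # expected: ZanjHookedTransformer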