From 85ee878589ac16880808432614446ed396a7cae5 Mon Sep 17 00:00:00 2001 From: mivanit Date: Thu, 22 Aug 2024 01:04:54 -0600 Subject: [PATCH] format --- maze_transformer/training/config.py | 10 ---------- maze_transformer/training/train_model.py | 3 --- 2 files changed, 13 deletions(-) diff --git a/maze_transformer/training/config.py b/maze_transformer/training/config.py index 43522d3..01dced1 100644 --- a/maze_transformer/training/config.py +++ b/maze_transformer/training/config.py @@ -589,25 +589,17 @@ def get_config_multisource( dataset_cfg_name, model_cfg_name, train_cfg_name = cfg_names # assemble the collective name name = f"multsrc_{dataset_cfg_name}_{model_cfg_name}_{train_cfg_name}" - print(f"gcm 591: {dataset_cfg_name = }") else: # 4 names if collective name, unpack it dataset_cfg_name, model_cfg_name, train_cfg_name, name = cfg_names - print(f"gcm 595: {dataset_cfg_name = }") - try: # try to actually assemble the configuration by looking up names in dicts - print(f"gcm 599: {dataset_cfg_name = }") - for k, v in MAZE_DATASET_CONFIGS.items(): - print(f"{k}: {v.summary()}") - config = ConfigHolder( name=name, dataset_cfg=copy.deepcopy(MAZE_DATASET_CONFIGS[dataset_cfg_name]), model_cfg=copy.deepcopy(GPT_CONFIGS[model_cfg_name]), train_cfg=copy.deepcopy(TRAINING_CONFIGS[train_cfg_name]), ) - print(f"gcm 612: {config.dataset_cfg.summary() = }") except KeyError as e: # exception handling for missing keys case raise KeyError( @@ -620,14 +612,12 @@ def get_config_multisource( raise ValueError( "Must provide exactly one of cfg, cfg_file, or cfg_names. this state should be unreachable btw." ) - print(f"gcm 604: {config.dataset_cfg.summary() = }") # update config with kwargs if kwargs_in: kwargs_dict: dict = kwargs_to_nested_dict( kwargs_in, sep=".", strip_prefix="cfg.", when_unknown_prefix="raise" ) config.update_from_nested_dict(kwargs_dict) - print(f"gcm 611: {config.dataset_cfg.summary() = }") return config diff --git a/maze_transformer/training/train_model.py b/maze_transformer/training/train_model.py index 16d705e..bfc8d7c 100644 --- a/maze_transformer/training/train_model.py +++ b/maze_transformer/training/train_model.py @@ -62,7 +62,6 @@ def train_model( - model config names: {model_cfg_names} - train config names: {train_cfg_names} """ - print(cfg.dataset_cfg.summary()) if help: print(train_model.__doc__) return @@ -110,7 +109,6 @@ def train_model( logger.progress("Summary logged, getting dataset") # load dataset - print(cfg.dataset_cfg.summary()) if dataset is None: dataset = MazeDataset.from_config( cfg=cfg.dataset_cfg, @@ -151,7 +149,6 @@ def train_model( ) logger.progress(f"finished getting training dataset with {len(dataset)} samples") - print(f"{len(dataset) = }") # validation dataset, if applicable val_dataset: MazeDataset | None = None if cfg.train_cfg.validation_dataset_cfg is not None: