diff --git a/graphium/cli/train_finetune_test.py b/graphium/cli/train_finetune_test.py
index 8cddcd63c..35fa993b8 100644
--- a/graphium/cli/train_finetune_test.py
+++ b/graphium/cli/train_finetune_test.py
@@ -70,15 +70,13 @@ def get_replication_factor(cfg):
 
 
 def get_gradient_accumulation_factor(cfg):
+    """
+    WARNING: This MUST be called after accelerator overrides have been applied
+    (i.e. after `load_accelerator` has been called)
+    """
     try:
         # Navigate through the nested dictionaries and get the gradient accumulation factor
-        grad_accumulation_factor = (
-            cfg.get("accelerator", {})
-            .get("config_override", {})
-            .get("trainer", {})
-            .get("trainer", {})
-            .get("accumulate_grad_batches", 1)
-        )
+        grad_accumulation_factor = cfg.get("trainer", {}).get("trainer", {}).get("accumulate_grad_batches", 1)
 
         # Ensure that the extracted value is an integer
         return int(grad_accumulation_factor)
@@ -90,15 +88,13 @@ def get_gradient_accumulation_factor(cfg):
 
 
 def get_training_batch_size(cfg):
+    """
+    WARNING: This MUST be called after accelerator overrides have been applied
+    (i.e. after `load_accelerator` has been called)
+    """
     try:
         # Navigate through the nested dictionaries and get the training batch size
-        batch_size_training = (
-            cfg.get("accelerator", {})
-            .get("config_override", {})
-            .get("datamodule", {})
-            .get("args", {})
-            .get("batch_size_training", 1)
-        )
+        batch_size_training = cfg.get("datamodule", {}).get("args", {}).get("batch_size_training", 1)
 
         # Ensure that the extracted value is an integer
         return int(batch_size_training)
@@ -109,6 +105,23 @@ def get_training_batch_size(cfg):
         return 1
 
 
+def get_training_device_iterations(cfg):
+    try:
+        ipu_config = cfg.get("accelerator", {}).get("ipu_config", [])
+        for item in ipu_config:
+            if "deviceIterations" in item:
+                # Extract the number between parentheses
+                start = item.find("(") + 1
+                end = item.find(")")
+                if start != 0 and end != -1:
+                    return int(item[start:end])
+    except Exception as e:
+        print(f"An error occurred: {e}")
+
+    # Return default value if deviceIterations is not found or an error occurred
+    return 1
+
+
 def run_training_finetuning_testing(cfg: DictConfig) -> None:
     """
     The main (pre-)training and fine-tuning loop.
@@ -141,12 +154,6 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
 
     st = timeit.default_timer()
 
-    replicas = get_replication_factor(cfg)
-    gradient_acc = get_gradient_accumulation_factor(cfg)
-    micro_bs = get_training_batch_size(cfg)
-
-    global_bs = replicas * gradient_acc * micro_bs
-
     # Initialize wandb only on first rank
     if os.environ.get("RANK", "0") == "0":
         # Disable wandb if the user is not logged in.
@@ -185,6 +192,14 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
 
     ## Metrics
     metrics = load_metrics(cfg)
+    # Note: these MUST be called after `cfg, accelerator = load_accelerator(cfg)`
+    replicas = get_replication_factor(cfg)
+    gradient_acc = get_gradient_accumulation_factor(cfg)
+    micro_bs = get_training_batch_size(cfg)
+    device_iterations = get_training_device_iterations(cfg)
+
+    global_bs = replicas * gradient_acc * micro_bs * device_iterations
+
     ## Predictor
     predictor = load_predictor(
         config=cfg,
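
A minimal usage sketch of the new helpers, assuming the accelerator's `ipu_config` entries take the string form `deviceIterations(<n>)` as implied by the parsing above; the config values below are hypothetical and only illustrate how the global batch size factors are extracted once `load_accelerator` has applied its overrides:

    from omegaconf import OmegaConf

    from graphium.cli.train_finetune_test import (
        get_gradient_accumulation_factor,
        get_training_batch_size,
        get_training_device_iterations,
    )

    # Hypothetical config shaped like the result of the accelerator overrides;
    # the actual values come from the Graphium YAML configs, not from here.
    cfg = OmegaConf.create(
        {
            "accelerator": {"ipu_config": ["deviceIterations(16)"]},
            "trainer": {"trainer": {"accumulate_grad_batches": 2}},
            "datamodule": {"args": {"batch_size_training": 8}},
        }
    )

    gradient_acc = get_gradient_accumulation_factor(cfg)    # -> 2
    micro_bs = get_training_batch_size(cfg)                 # -> 8
    device_iterations = get_training_device_iterations(cfg) # -> 16
    # With `replicas` from get_replication_factor(cfg), the effective global
    # batch size is replicas * gradient_acc * micro_bs * device_iterations.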