From 77f9647d88d71cd7c69ebea11c22c4644948fc75 Mon Sep 17 00:00:00 2001
From: Callum McLean
Date: Thu, 26 Oct 2023 10:23:32 +0000
Subject: [PATCH 1/4] Fix `samples_seen` logging

---
 graphium/cli/train_finetune_test.py | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/graphium/cli/train_finetune_test.py b/graphium/cli/train_finetune_test.py
index 8cddcd63c..9d0ea38c3 100644
--- a/graphium/cli/train_finetune_test.py
+++ b/graphium/cli/train_finetune_test.py
@@ -108,6 +108,23 @@ def get_training_batch_size(cfg):
     # Return default value if an error occurred
     return 1
 
+def get_device_iterations(cfg):
+    try:
+        ipu_config = cfg.get("accelerator", {}).get("ipu_config", [])
+        for item in ipu_config:
+            if "deviceIterations" in item:
+                # Extract the number between parentheses
+                start = item.find("(") + 1
+                end = item.find(")")
+                if start != 0 and end != -1:
+                    return int(item[start:end])
+    except Exception as e:
+        print(f"An error occurred: {e}")
+
+    # Return default value if deviceIterations is not found or an error occurred
+    return 1
+
+
 def run_training_finetuning_testing(cfg: DictConfig) -> None:
     """
     The main (pre-)training and fine-tuning loop.
@@ -144,8 +161,9 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
     replicas = get_replication_factor(cfg)
     gradient_acc = get_gradient_accumulation_factor(cfg)
     micro_bs = get_training_batch_size(cfg)
+    device_iterations = get_training_device_iterations(cfg)
 
-    global_bs = replicas * gradient_acc * micro_bs
+    global_bs = replicas * gradient_acc * micro_bs * device_iterations
 
     # Initialize wandb only on first rank
     if os.environ.get("RANK", "0") == "0":

From b41eaaf2e5aa37ead8fe1a373e3c61418a320fe7 Mon Sep 17 00:00:00 2001
From: Callum McLean
Date: Thu, 26 Oct 2023 10:26:14 +0000
Subject: [PATCH 2/4] `black` linting

---
 graphium/cli/train_finetune_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/graphium/cli/train_finetune_test.py b/graphium/cli/train_finetune_test.py
index 9d0ea38c3..349dc77e3 100644
--- a/graphium/cli/train_finetune_test.py
+++ b/graphium/cli/train_finetune_test.py
@@ -108,6 +108,7 @@ def get_training_batch_size(cfg):
     # Return default value if an error occurred
     return 1
 
+
 def get_device_iterations(cfg):
     try:
         ipu_config = cfg.get("accelerator", {}).get("ipu_config", [])
@@ -125,7 +126,6 @@ def get_device_iterations(cfg):
     return 1
 
 
-
 def run_training_finetuning_testing(cfg: DictConfig) -> None:
     """
     The main (pre-)training and fine-tuning loop.
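
For context, the helper introduced in PATCH 1/4 only inspects the `accelerator.ipu_config` list of option strings. Below is a minimal standalone sketch of that parsing; the sample config entries `"deviceIterations(16)"` and `"replicationFactor(4)"` are illustrative assumptions about the string format, not values taken from a real Graphium config.

    # Standalone sketch of the parsing added in PATCH 1/4; the sample cfg is illustrative only.
    def get_device_iterations(cfg: dict) -> int:
        try:
            ipu_config = cfg.get("accelerator", {}).get("ipu_config", [])
            for item in ipu_config:
                if "deviceIterations" in item:
                    # Extract the number between parentheses, e.g. "deviceIterations(16)" -> 16
                    start = item.find("(") + 1
                    end = item.find(")")
                    if start != 0 and end != -1:
                        return int(item[start:end])
        except Exception as e:
            print(f"An error occurred: {e}")
        # Default when deviceIterations is absent or parsing fails
        return 1

    cfg = {"accelerator": {"ipu_config": ["deviceIterations(16)", "replicationFactor(4)"]}}
    assert get_device_iterations(cfg) == 16   # parsed from the option string
    assert get_device_iterations({}) == 1     # falls back to the default
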
From 6832a09dd510a20fb304cac6290afbd9f7ee7332 Mon Sep 17 00:00:00 2001
From: Callum McLean
Date: Thu, 26 Oct 2023 13:01:35 +0000
Subject: [PATCH 3/4] Fix batch size and gradient accumulation functions

---
 graphium/cli/train_finetune_test.py | 39 +++++++++++++++++----------------------
 1 file changed, 17 insertions(+), 22 deletions(-)

diff --git a/graphium/cli/train_finetune_test.py b/graphium/cli/train_finetune_test.py
index 349dc77e3..827a18cb3 100644
--- a/graphium/cli/train_finetune_test.py
+++ b/graphium/cli/train_finetune_test.py
@@ -70,15 +70,12 @@ def get_replication_factor(cfg):
 
 
 def get_gradient_accumulation_factor(cfg):
+    """
+    WARNING: This MUST be called after accelerator overrides have been applied
+    """
     try:
         # Navigate through the nested dictionaries and get the gradient accumulation factor
-        grad_accumulation_factor = (
-            cfg.get("accelerator", {})
-            .get("config_override", {})
-            .get("trainer", {})
-            .get("trainer", {})
-            .get("accumulate_grad_batches", 1)
-        )
+        grad_accumulation_factor = cfg.get("trainer", {}).get("trainer", {}).get("accumulate_grad_batches", 1)
 
         # Ensure that the extracted value is an integer
         return int(grad_accumulation_factor)
@@ -90,15 +87,12 @@ def get_gradient_accumulation_factor(cfg):
 
 
 def get_training_batch_size(cfg):
+    """
+    WARNING: This MUST be called after accelerator overrides have been applied
+    """
     try:
         # Navigate through the nested dictionaries and get the training batch size
-        batch_size_training = (
-            cfg.get("accelerator", {})
-            .get("config_override", {})
-            .get("datamodule", {})
-            .get("args", {})
-            .get("batch_size_training", 1)
-        )
+        batch_size_training = cfg.get("datamodule", {}).get("args", {}).get("batch_size_training", 1)
 
         # Ensure that the extracted value is an integer
         return int(batch_size_training)
@@ -109,7 +103,7 @@ def get_training_batch_size(cfg):
     return 1
 
 
-def get_device_iterations(cfg):
+def get_training_device_iterations(cfg):
     try:
         ipu_config = cfg.get("accelerator", {}).get("ipu_config", [])
         for item in ipu_config:
@@ -158,13 +152,6 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
 
     st = timeit.default_timer()
 
-    replicas = get_replication_factor(cfg)
-    gradient_acc = get_gradient_accumulation_factor(cfg)
-    micro_bs = get_training_batch_size(cfg)
-    device_iterations = get_training_device_iterations(cfg)
-
-    global_bs = replicas * gradient_acc * micro_bs * device_iterations
-
     # Initialize wandb only on first rank
     if os.environ.get("RANK", "0") == "0":
         # Disable wandb if the user is not logged in.
@@ -203,6 +190,14 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
     ## Metrics
     metrics = load_metrics(cfg)
 
+    # Note: these MUST be called after `cfg, accelerator = load_accelerator(cfg)`
+    replicas = get_replication_factor(cfg)
+    gradient_acc = get_gradient_accumulation_factor(cfg)
+    micro_bs = get_training_batch_size(cfg)
+    device_iterations = get_training_device_iterations(cfg)
+
+    global_bs = replicas * gradient_acc * micro_bs * device_iterations
+
     ## Predictor
     predictor = load_predictor(
         config=cfg,

From cc6adba267042093736491904c93475b4b6bc80a Mon Sep 17 00:00:00 2001
From: Callum McLean
Date: Thu, 26 Oct 2023 14:33:30 +0000
Subject: [PATCH 4/4] Clarify override comment

---
 graphium/cli/train_finetune_test.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/graphium/cli/train_finetune_test.py b/graphium/cli/train_finetune_test.py
index 827a18cb3..35fa993b8 100644
--- a/graphium/cli/train_finetune_test.py
+++ b/graphium/cli/train_finetune_test.py
@@ -72,6 +72,7 @@ def get_replication_factor(cfg):
 def get_gradient_accumulation_factor(cfg):
     """
     WARNING: This MUST be called after accelerator overrides have been applied
+    (i.e. after `load_accelerator` has been called)
     """
     try:
         # Navigate through the nested dictionaries and get the gradient accumulation factor
@@ -89,6 +90,7 @@ def get_gradient_accumulation_factor(cfg):
 def get_training_batch_size(cfg):
     """
     WARNING: This MUST be called after accelerator overrides have been applied
+    (i.e. after `load_accelerator` has been called)
     """
     try:
         # Navigate through the nested dictionaries and get the training batch size
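
Taken together, the series makes `samples_seen` grow by the true global batch size: replication factor × gradient accumulation × micro batch size × device iterations, with all factors read only after `load_accelerator(cfg)` has applied the accelerator overrides. A worked example of that formula with purely illustrative numbers (not taken from any shipped config):

    # Illustrative values only; the formula is the one the series introduces.
    replicas = 4            # get_replication_factor(cfg)
    gradient_acc = 2        # get_gradient_accumulation_factor(cfg)
    micro_bs = 16           # get_training_batch_size(cfg)
    device_iterations = 8   # get_training_device_iterations(cfg)

    global_bs = replicas * gradient_acc * micro_bs * device_iterations
    print(global_bs)  # 1024 samples counted per step when logging samples_seen
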