Merge pull request #483 from datamol-io/fix-samples-seen-logging
Fix `samples_seen` logging
callumm-graphcore committed Oct 26, 2023
2 parents 983bf6c + cc6adba commit e5fa686
Showing 1 changed file with 35 additions and 20 deletions.
graphium/cli/train_finetune_test.py — 55 changes: 35 additions & 20 deletions
@@ -70,15 +70,13 @@ def get_replication_factor(cfg):
 
 
 def get_gradient_accumulation_factor(cfg):
+    """
+    WARNING: This MUST be called after accelerator overrides have been applied
+    (i.e. after `load_accelerator` has been called)
+    """
     try:
         # Navigate through the nested dictionaries and get the gradient accumulation factor
-        grad_accumulation_factor = (
-            cfg.get("accelerator", {})
-            .get("config_override", {})
-            .get("trainer", {})
-            .get("trainer", {})
-            .get("accumulate_grad_batches", 1)
-        )
+        grad_accumulation_factor = cfg.get("trainer", {}).get("trainer", {}).get("accumulate_grad_batches", 1)
 
         # Ensure that the extracted value is an integer
         return int(grad_accumulation_factor)
@@ -90,15 +88,13 @@ def get_gradient_accumulation_factor(cfg):
 
 
 def get_training_batch_size(cfg):
+    """
+    WARNING: This MUST be called after accelerator overrides have been applied
+    (i.e. after `load_accelerator` has been called)
+    """
     try:
         # Navigate through the nested dictionaries and get the training batch size
-        batch_size_training = (
-            cfg.get("accelerator", {})
-            .get("config_override", {})
-            .get("datamodule", {})
-            .get("args", {})
-            .get("batch_size_training", 1)
-        )
+        batch_size_training = cfg.get("datamodule", {}).get("args", {}).get("batch_size_training", 1)
 
         # Ensure that the extracted value is an integer
         return int(batch_size_training)
@@ -109,6 +105,23 @@ def get_training_batch_size(cfg):
         return 1
 
 
+def get_training_device_iterations(cfg):
+    try:
+        ipu_config = cfg.get("accelerator", {}).get("ipu_config", [])
+        for item in ipu_config:
+            if "deviceIterations" in item:
+                # Extract the number between parentheses
+                start = item.find("(") + 1
+                end = item.find(")")
+                if start != 0 and end != -1:
+                    return int(item[start:end])
+    except Exception as e:
+        print(f"An error occurred: {e}")
+
+    # Return default value if deviceIterations is not found or an error occurred
+    return 1
+
+
 def run_training_finetuning_testing(cfg: DictConfig) -> None:
     """
     The main (pre-)training and fine-tuning loop.
@@ -141,12 +154,6 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
 
     st = timeit.default_timer()
 
-    replicas = get_replication_factor(cfg)
-    gradient_acc = get_gradient_accumulation_factor(cfg)
-    micro_bs = get_training_batch_size(cfg)
-
-    global_bs = replicas * gradient_acc * micro_bs
-
     # Initialize wandb only on first rank
     if os.environ.get("RANK", "0") == "0":
         # Disable wandb if the user is not logged in.
@@ -185,6 +192,14 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
     ## Metrics
     metrics = load_metrics(cfg)
 
+    # Note: these MUST be called after `cfg, accelerator = load_accelerator(cfg)`
+    replicas = get_replication_factor(cfg)
+    gradient_acc = get_gradient_accumulation_factor(cfg)
+    micro_bs = get_training_batch_size(cfg)
+    device_iterations = get_training_device_iterations(cfg)
+
+    global_bs = replicas * gradient_acc * micro_bs * device_iterations
+
     ## Predictor
     predictor = load_predictor(
         config=cfg,

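For context on what the patch computes: the global batch size used for `samples_seen` logging is the product of the replication factor, the gradient-accumulation factor, the micro batch size, and the IPU device-iteration count, and all four helpers must read the config only after `load_accelerator` has applied its overrides. The following is a minimal, self-contained sketch of that arithmetic and of the `deviceIterations(...)` string parsing; the `cfg` dictionary, the `parse_device_iterations` helper name, and the numeric values are made-up illustrations, not code from the repository.

# Sketch only (not the repository code): how the four factors combine into the
# global batch size that drives `samples_seen`.

def parse_device_iterations(ipu_config):
    # Extract N from an entry like "deviceIterations(16)"; default to 1 if absent.
    for item in ipu_config:
        if "deviceIterations" in item:
            start = item.find("(") + 1
            end = item.find(")")
            if start != 0 and end != -1:
                return int(item[start:end])
    return 1

# Hypothetical config after accelerator overrides have been applied.
cfg = {
    "accelerator": {"ipu_config": ["deviceIterations(16)"]},
    "trainer": {"trainer": {"accumulate_grad_batches": 2}},
    "datamodule": {"args": {"batch_size_training": 8}},
}

replicas = 4  # stand-in for get_replication_factor(cfg)
gradient_acc = cfg["trainer"]["trainer"]["accumulate_grad_batches"]            # 2
micro_bs = cfg["datamodule"]["args"]["batch_size_training"]                    # 8
device_iterations = parse_device_iterations(cfg["accelerator"]["ipu_config"])  # 16

global_bs = replicas * gradient_acc * micro_bs * device_iterations
print(global_bs)  # 4 * 2 * 8 * 16 = 1024

With these assumed values, each logged step accounts for 1024 samples, which is the quantity the corrected `samples_seen` counter accumulates.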