diff --git a/graphium/config/_loader.py b/graphium/config/_loader.py index 7ae6d0b0a..d33489283 100644 --- a/graphium/config/_loader.py +++ b/graphium/config/_loader.py @@ -371,7 +371,7 @@ def load_trainer( cfg_trainer = deepcopy(config["trainer"]) # Define the IPU plugin if required - strategy = config["trainer"]["trainer"].get("strategy", "auto") + strategy = cfg_trainer["trainer"].pop("strategy", "auto") if accelerator_type == "ipu": ipu_opts, ipu_inference_opts = _get_ipu_opts(config)