diff --git a/graphium/config/_loader.py b/graphium/config/_loader.py index 3c7a654e9..d049a6f4e 100644 --- a/graphium/config/_loader.py +++ b/graphium/config/_loader.py @@ -372,7 +372,7 @@ def load_trainer( cfg_trainer = deepcopy(config["trainer"]) # Define the IPU plugin if required - strategy = "auto" + strategy = config["trainer"]["trainer"].get("strategy", "auto") if accelerator_type == "ipu": ipu_opts, ipu_inference_opts = _get_ipu_opts(config) @@ -385,6 +385,9 @@ def load_trainer( precision=config["trainer"]["trainer"].get("precision"), ) + if strategy != "auto": + raise ValueError("IPUs selected, but strategy is not set to 'auto'") + from lightning_graphcore import IPUStrategy strategy = IPUStrategy(training_opts=training_opts, inference_opts=inference_opts)