Skip to content

Commit

Permalink
Updating get_checkpoint_path
Browse files Browse the repository at this point in the history
  • Loading branch information
WenkelF committed Aug 31, 2023
1 parent 8f1ddfb commit 3cf2fb5
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions graphium/config/_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,15 +589,15 @@ def get_checkpoint_path(config: Union[omegaconf.DictConfig, Dict[str, Any]]) ->

cfg_trainer = config["trainer"]

if "model_checkpoint" in cfg_trainer.keys():
dirpath = cfg_trainer["model_checkpoint"]["dirpath"] + str(cfg_trainer["seed"]) + "/"
filename = config.get("ckpt_name_for_testing", "last") + ".ckpt"
else:
raise ValueError("Empty checkpoint section in config file")
path = config.get("ckpt_name_for_testing", "last.ckpt")
if path in GRAPHIUM_PRETRAINED_MODELS_DICT or fs.exists(path):
return path

checkpoint_path = fs.join(dirpath, filename)
if "model_checkpoint" in cfg_trainer.keys():
dirpath = cfg_trainer["model_checkpoint"]["dirpath"]
path = fs.join(dirpath, path)

if not fs.exists(checkpoint_path):
raise ValueError(f"Checkpoint path `{checkpoint_path}` does not exist")
if not fs.exists(path):
raise ValueError(f"Checkpoint path `{path}` does not exist")

return checkpoint_path
return path

0 comments on commit 3cf2fb5

Please sign in to comment.