Skip to content

Commit

Permalink
Merge pull request #375 from datamol-io/hyperparameter_change
Browse files Browse the repository at this point in the history
Hyperparameter change fr toymix baseline configs
  • Loading branch information
DomInvivo authored Jul 2, 2023
2 parents 095147a + 8c705b9 commit e318ff0
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ accelerator:
max_num_edges_per_graph: 80
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 44 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 80
max_num_nodes_per_graph: 100 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 200
# Data handling-related
batch_size_training: 5
batch_size_inference: 5
batch_size_inference: 2
predictor:
optim_kwargs:
loss_scaling: 1024
Expand Down Expand Up @@ -259,7 +259,7 @@ predictor:
zinc: mae_ipu
random_seed: *seed
optim_kwargs:
lr: 1.e-3 # warmup can be scheduled using torch_scheduler_kwargs
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
Expand Down
10 changes: 3 additions & 7 deletions expts/run_validation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
load_trainer,
save_params_to_wandb,
load_accelerator,
load_yaml_config,
)
from graphium.utils.safe_run import SafeRun
from graphium.utils.command_line_utils import update_config, get_anchors_and_aliases


# WandB
Expand Down Expand Up @@ -113,13 +113,9 @@ def main(cfg: DictConfig, run_name: str = "main", add_date_time: bool = True) ->
parser = argparse.ArgumentParser()
parser.add_argument("--config", help="Path to the config file", default=None)

args, unknown = parser.parse_known_args()
# Optionally parse the config with the command line
args, unknown_args = parser.parse_known_args()
if args.config is not None:
CONFIG_FILE = args.config
cfg = load_yaml_config(CONFIG_FILE, MAIN_DIR, unknown_args)

with open(os.path.join(MAIN_DIR, CONFIG_FILE), "r") as f:
cfg = yaml.safe_load(f)
refs = get_anchors_and_aliases(CONFIG_FILE)
cfg = update_config(cfg, unknown, refs)
main(cfg)
2 changes: 1 addition & 1 deletion requirements_ipu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,4 @@ fastparquet
torch-scatter==2.1.0
torch-sparse==0.6.15
torchvision==0.14.1+cpu
lightning-graphcore @ git+https://github.com/Lightning-AI/lightning-Graphcore
lightning-graphcore @ git+https://github.com/Lightning-AI/lightning-Graphcore

0 comments on commit e318ff0

Please sign in to comment.