main_oc20.py
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import copy
import logging
import os
import sys
import time
from pathlib import Path

import submitit

from ocpmodels.common import distutils
from ocpmodels.common.flags import flags
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import (
    build_config,
    create_grid,
    save_experiment_log,
    setup_imports,
    setup_logging,
)

import nets
import oc20.trainer
import oc20.trainer.dist_setup
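
# `nets` and `oc20.trainer` appear unused but are imported for their side
# effects: in the usual ocpmodels pattern, importing them registers model and
# trainer classes with `registry`, so the name lookups below can resolve them.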


class Runner(submitit.helpers.Checkpointable):
    def __init__(self):
        self.config = None

    def __call__(self, config):
        setup_logging()
        self.config = copy.deepcopy(config)

        # `args` is the argparse namespace parsed at module level below.
        if args.distributed:
            # distutils.setup(config)
            oc20.trainer.dist_setup.setup(config)

        try:
            setup_imports()
            # Look up the trainer class by name (default: "energy") and
            # instantiate it from the config.
            self.trainer = registry.get_trainer_class(
                config.get("trainer", "energy")
            )(
                task=config["task"],
                model=config["model"],
                dataset=config["dataset"],
                optimizer=config["optim"],
                identifier=config["identifier"],
                timestamp_id=config.get("timestamp_id", None),
                run_dir=config.get("run_dir", "./"),
                is_debug=config.get("is_debug", False),
                print_every=config.get("print_every", 10),
                seed=config.get("seed", 0),
                logger=config.get("logger", "tensorboard"),
                local_rank=config["local_rank"],
                amp=config.get("amp", False),
                cpu=config.get("cpu", False),
                slurm=config.get("slurm", {}),
                noddp=config.get("noddp", False),
            )

            # Overwrite mode when dataset statistics are requested.
            if config.get("compute_stats", False):
                config["mode"] = "compute_stats"

            self.task = registry.get_task_class(config["mode"])(self.config)
            self.task.setup(self.trainer)
            start_time = time.time()
            self.task.run()
            distutils.synchronize()
            if distutils.is_master():
                logging.info(f"Total time taken: {time.time() - start_time}")
        finally:
            if args.distributed:
                distutils.cleanup()

    def checkpoint(self, *args, **kwargs):
        # Called by submitit on preemption/timeout: save training state and
        # requeue a fresh Runner that resumes from the saved checkpoint.
        new_runner = Runner()
        self.trainer.save(checkpoint_file="checkpoint.pt", training_state=True)
        self.config["checkpoint"] = self.task.chkpt_path
        self.config["timestamp_id"] = self.trainer.timestamp_id
        if self.trainer.logger is not None:
            self.trainer.logger.mark_preempting()
        return submitit.helpers.DelayedSubmission(new_runner, self.config)
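
# Sketch of the submitit preemption protocol assumed above (not part of this
# script): because Runner subclasses submitit.helpers.Checkpointable, submitit
# calls checkpoint() with the original call arguments when the SLURM job is
# preempted or times out, and resubmits the returned DelayedSubmission. The
# requeued Runner then resumes from "checkpoint.pt" via config["checkpoint"].
# A minimal standalone analogue, using only the public submitit API:
#
#     import submitit
#
#     class Counter(submitit.helpers.Checkpointable):
#         def __init__(self, start=0):
#             self.n = start
#
#         def __call__(self, steps):
#             for _ in range(steps):
#                 self.n += 1  # work that may be interrupted
#             return self.n
#
#         def checkpoint(self, steps):
#             # Resubmit a new callable carrying the progress made so far.
#             return submitit.helpers.DelayedSubmission(
#                 Counter(self.n), steps - self.n
#             )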
if __name__ == "__main__":
setup_logging()
parser = flags.get_parser()
args, override_args = parser.parse_known_args()
config = build_config(args, override_args)
if args.submit: # Run on cluster
slurm_add_params = config.get(
"slurm", None
) # additional slurm arguments
if args.sweep_yml: # Run grid search
configs = create_grid(config, args.sweep_yml)
else:
configs = [config]
logging.info(f"Submitting {len(configs)} jobs")
executor = submitit.AutoExecutor(
folder=args.logdir / "%j", slurm_max_num_timeout=3
)
executor.update_parameters(
name=args.identifier,
mem_gb=args.slurm_mem,
timeout_min=args.slurm_timeout * 60,
slurm_partition=args.slurm_partition,
gpus_per_node=args.num_gpus,
cpus_per_task=(config["optim"]["num_workers"] + 1),
tasks_per_node=(args.num_gpus if args.distributed else 1),
nodes=args.num_nodes,
slurm_additional_parameters=slurm_add_params,
)
for config in configs:
config["slurm"] = copy.deepcopy(executor.parameters)
config["slurm"]["folder"] = str(executor.folder)
jobs = executor.map_array(Runner(), configs)
logging.info(
f"Submitted jobs: {', '.join([job.job_id for job in jobs])}"
)
log_file = save_experiment_log(args, jobs, configs)
logging.info(f"Experiment log saved to: {log_file}")
else: # Run locally
Runner()(config)
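
# Example invocations (a sketch: these flags come from ocpmodels.common.flags,
# and the YAML path is a placeholder; substitute a config from this repo):
#
#   Local, single process:
#       python main_oc20.py --mode train --config-yml configs/example.yml
#
#   SLURM cluster via submitit, distributed over 8 GPUs on one node:
#       python main_oc20.py --mode train --config-yml configs/example.yml \
#           --submit --distributed --num-gpus 8 --num-nodes 1 \
#           --identifier my-run --slurm-partition <partition>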