-
Notifications
You must be signed in to change notification settings - Fork 4
/
train.py
138 lines (113 loc) · 4.23 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import argparse
import json
import os
import sys
from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning, NumbaWarning
import warnings
warnings.simplefilter('ignore', category=NumbaDeprecationWarning)
warnings.simplefilter('ignore', category=NumbaWarning)
import numpy as np
import torch
import yaml
from det3d.datasets import build_dataset
from det3d.models import build_detector
from det3d.torchie import Config
from det3d.torchie.apis import (
build_optimizer,
get_root_logger,
init_dist,
set_random_seed,
train_detector,
)
def parse_args():
    """Parse command-line options for detector training.

    Also reconciles the process-local rank between the command line and the
    ``LOCAL_RANK`` environment variable. ``torch.distributed.launch`` passes
    the rank as ``--local_rank`` (PyTorch < 2.0) or ``--local-rank``
    (PyTorch >= 2.0). ``torchrun`` only exports ``LOCAL_RANK``. After this
    call, ``args.local_rank`` and ``os.environ["LOCAL_RANK"]`` always agree.

    Returns:
        argparse.Namespace: the parsed arguments.

    Raises:
        SystemExit: on invalid command-line arguments (argparse behaviour).
        ValueError: if ``LOCAL_RANK`` is set to a non-integer value.
    """
    parser = argparse.ArgumentParser(description="Train a detector")
    parser.add_argument("config", help="train config file path")
    parser.add_argument("--work_dir", help="the dir to save logs and models")
    parser.add_argument("--resume_from", help="the checkpoint file to resume from")
    parser.add_argument(
        "--validate",
        action="store_true",
        help="whether to evaluate the checkpoint during training",
    )
    parser.add_argument(
        "--gpus",
        type=int,
        default=1,
        help="number of gpus to use (only applicable to non-distributed training)",
    )
    parser.add_argument("--seed", type=int, default=None, help="random seed")
    parser.add_argument(
        "--launcher",
        choices=["none", "pytorch", "slurm", "mpi"],
        default="none",
        help="job launcher",
    )
    # Both spellings are accepted and stored as ``args.local_rank``; the
    # hyphenated form is what torch.distributed.launch passes since PyTorch 2.0.
    parser.add_argument("--local_rank", "--local-rank", type=int, default=0)
    parser.add_argument(
        "--autoscale-lr",
        action="store_true",
        help="automatically scale lr with the number of gpus",
    )
    parser.add_argument('--pretrained_model', type=str, default=None, help='pretrained_model')
    args = parser.parse_args()

    if "LOCAL_RANK" in os.environ:
        # torchrun only exports LOCAL_RANK (no CLI flag), so the environment
        # is authoritative; otherwise every rank would believe it is rank 0.
        args.local_rank = int(os.environ["LOCAL_RANK"])
    else:
        os.environ["LOCAL_RANK"] = str(args.local_rank)
    return args
def _backup_sources(backup_dir, logger):
    """Best-effort copy of the current working tree into ``backup_dir`` via rsync.

    Data, output, work-dir and compiled-op directories are excluded. A missing
    ``rsync`` binary or a non-zero exit is logged as a warning rather than
    raised, so the backup can never abort training. Failure to create
    ``backup_dir`` itself still raises ``OSError``.
    """
    import subprocess  # only needed for this optional backup step

    os.makedirs(backup_dir, exist_ok=True)
    # List form (no shell): paths with spaces or shell metacharacters are safe.
    cmd = [
        "rsync",
        "-ar",
        "--exclude=data",
        "--exclude=output",
        "--exclude=work_dirs",
        "--exclude=det3d/ops",
        "--exclude=det3d/models/utils/pointnet2",
        "./",
        backup_dir,
    ]
    try:
        result = subprocess.run(cmd, check=False)
    except OSError as err:
        logger.warning(f"Source backup skipped, could not run rsync: {err}")
        return
    if result.returncode != 0:
        logger.warning(f"Source backup via rsync exited with code {result.returncode}")
        return
    logger.info(f"Backup source files to {backup_dir}")


def main():
    """Train a detector from the config file named on the command line.

    Steps:
    1. Apply the ``--work_dir``, ``--resume_from`` and ``--pretrained_model``
       overrides to the config.
    2. Initialise an NCCL process group when ``WORLD_SIZE`` > 1.
    3. Back up the source tree on local rank 0.
    4. Build the model and dataset(s), then hand off to ``train_detector``.

    NOTE(review): ``--launcher`` and ``--gpus`` are parsed but not consulted
    here. Whether to run distributed is decided solely by ``WORLD_SIZE``.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    cfg.local_rank = args.local_rank

    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.pretrained_model is not None:
        cfg.load_from = args.pretrained_model

    # A missing WORLD_SIZE counts as a single process (non-distributed).
    distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        cfg.gpus = torch.distributed.get_world_size()

    # init logger before other steps
    logger = get_root_logger(cfg.log_level)
    logger.info(f"Distributed training: {distributed}")
    logger.info(f"torch.backends.cudnn.benchmark: {torch.backends.cudnn.benchmark}")

    if args.autoscale_lr:
        # The flag is accepted by the parser but no LR scaling exists in this
        # script; say so instead of silently ignoring it.
        logger.warning("--autoscale-lr is not implemented in this script; the learning rate is unchanged")

    if args.local_rank == 0:
        # copy important files to backup (only once per node)
        _backup_sources(os.path.join(cfg.work_dir, "det3d"), logger)

    # set random seeds
    if args.seed is not None:
        logger.info(f"Set random seed to {args.seed}")
        set_random_seed(args.seed)

    model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

    datasets = [build_dataset(cfg.data.train)]
    # A two-entry workflow also gets the val split (presumably train + val phases).
    if len(cfg.workflow) == 2:
        datasets.append(build_dataset(cfg.data.val))

    if cfg.checkpoint_config is not None:
        # store the config file content and class names in checkpoints as metadata
        cfg.checkpoint_config.meta = dict(config=cfg.text, CLASSES=datasets[0].CLASSES)

    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES

    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger,
    )


if __name__ == "__main__":
    main()