forked from mit-han-lab/bevfusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexport.py
executable file
·109 lines (91 loc) · 3.27 KB
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import argparse
import os
import time
import warnings
import mmcv
import onnx
import torch
from mmcv import Config, DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint, wrap_fp16_model
from mmdet3d.apis import single_gpu_test
from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_model
from mmdet.apis import multi_gpu_test, set_random_seed
from mmdet.datasets import replace_ImageToTensor
from onnxsim import simplify
from tqdm import tqdm
def parse_args(argv=None):
    """Parse the command-line arguments of the ONNX export script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. The default ``None`` makes argparse read
            ``sys.argv[1:]``, which is the behavior existing callers get.

    Returns:
        argparse.Namespace with three attributes: ``config`` (config file
        path), ``checkpoint`` (checkpoint file) and ``cfg_options`` (the
        value produced by ``DictAction`` for ``--cfg-options``, or ``None``
        when the flag is not given).
    """
    # The description is what ``--help`` prints; it must describe this
    # script (an exporter), not the test/eval tool it was derived from.
    parser = argparse.ArgumentParser(description="Export a model to ONNX")
    parser.add_argument("config", help="test config file path")
    parser.add_argument("checkpoint", help="checkpoint file")
    parser.add_argument(
        "--cfg-options",
        nargs="+",
        action=DictAction,
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file. If the value to "
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        "Note that the quotation marks are necessary and that no white space "
        "is allowed.",
    )
    return parser.parse_args(argv)
def main():
    """Export the configured model to ``model.onnx`` and simplify the graph.

    Steps, in order:
      1. Load the config named on the command line and merge any
         ``--cfg-options`` overrides into it.
      2. Build the test dataset/data loader and take its first batch as the
         example input for the export.
      3. Build the model, load the checkpoint on CPU and switch to eval mode.
      4. Replace ``model.forward`` with ``model.forward_test`` pre-bound to
         the batch's ``metas`` and ``rescale=True`` so the exporter only has
         to supply the image input.
      5. Run ``torch.onnx.export`` to ``model.onnx``, then run ``onnxsim``
         on the result and overwrite the file with the simplified graph if
         the simplifier's own validation passed.

    Raises:
        RuntimeError: If the test data loader yields no batch, because there
            is then no example input to export with.
    """
    # Function-scope import: the original imported ``partial`` locally too,
    # so the file-level import block stays untouched.
    from functools import partial

    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set cudnn_benchmark
    torch.backends.cudnn.benchmark = True

    cfg.model.pretrained = None
    # cfg.data.test may be a single dataset config or a list of them.
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=False,
        shuffle=False,
    )

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_model(cfg.model, test_cfg=cfg.get("test_cfg"))
    checkpoint = load_checkpoint(model, args.checkpoint, map_location="cpu")
    # old versions did not save class info in checkpoints, this workaround is
    # for backward compatibility
    if "CLASSES" in checkpoint.get("meta", {}):
        model.CLASSES = checkpoint["meta"]["CLASSES"]
    else:
        model.CLASSES = dataset.CLASSES

    model.eval()

    output_path = "model.onnx"

    with torch.no_grad():
        # Only the first batch is used as the example input. An empty loader
        # previously made the script exit silently without writing a file.
        data = next(iter(data_loader), None)
        if data is None:
            raise RuntimeError(
                "The test data loader produced no samples; cannot build an "
                "example input for the ONNX export."
            )

        # Repeat the first image tensor six times along dim 0.
        # NOTE(review): presumably one copy per camera view -- confirm.
        img = [torch.cat([data["img"][0].data[0]] * 6, dim=0)]
        metas = data["metas"][0].data

        # Bind the non-tensor arguments so the exporter only feeds ``img``.
        model.forward = partial(
            model.forward_test,
            metas=metas,
            rescale=True,
        )

        torch.onnx.export(
            model,
            img,
            output_path,
            input_names=["input"],
            opset_version=13,
            do_constant_folding=True,
        )

    # Use distinct names for the ONNX graph so the torch ``model`` is not
    # shadowed by an object of a different type.
    onnx_model = onnx.load(output_path)
    simplified_model, check_ok = simplify(onnx_model)
    if check_ok:
        onnx.save(simplified_model, output_path)
    else:
        # simplify() reports whether the simplified graph passed its own
        # validation; do not overwrite a good export with a graph that
        # failed it.
        warnings.warn(
            "onnxsim could not validate the simplified model; keeping the "
            "unsimplified export in " + output_path
        )
# Run the export only when this file is executed as a script, so importing
# the module does not trigger dataset/model construction.
if __name__ == "__main__":
    main()