forked from plemeri/InSPyReNet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbenchmark.py
More file actions
55 lines (43 loc) · 1.53 KB
/
benchmark.py
File metadata and controls
55 lines (43 loc) · 1.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import argparse
import torch
import os
import sys
from thop import profile, clever_format
import warnings
warnings.filterwarnings("ignore")
# Make the directory above this script importable so the project packages
# imported below can be found. The path is resolved through abspath first:
# when __file__ is a bare relative filename, os.path.split(__file__)[0] is ''
# and the parent would also collapse to '', i.e. the current directory.
filepath = os.path.dirname(os.path.abspath(__file__))
repopath = os.path.dirname(filepath)
sys.path.append(repopath)
from lib import *
from lib.optim import *
from data.dataloader import *
from utils.misc import *
def _args():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='configs/InSPyReNet_SwinB.yaml')
parser.add_argument('--input_size', type=int, nargs='+', default=[384, 384])
parser.add_argument('--verbose', action='store_true', default=False)
return parser.parse_args()
def benchmark(opt, args):
    """Print MACs, parameter count and mean forward-pass time of a model.

    Builds the model named by ``opt.Model.name``, moves it to the GPU,
    counts MACs and parameters with ``profile`` on a single random input,
    then times ten forward passes with CUDA events and prints the results.

    Args:
        opt: Loaded configuration. ``opt.Model.name`` selects the model
            class and ``opt.Model`` is unpacked as its keyword arguments.
        args: Parsed CLI namespace. ``args.input_size`` supplies the
            trailing dimensions of the random input tensor.
    """
    num_runs = 10  # number of timed forward passes; also the divisor below

    # NOTE(review): eval() executes whatever string the config contains.
    # Only run this script with trusted config files.
    model = Simplify(eval(opt.Model.name)(**opt.Model))
    model = model.cuda()
    # Modules start in training mode; switch to inference mode so layers
    # such as dropout and batch norm behave as they do at deployment.
    model.eval()

    # Random batch of one 3-channel image (named to avoid shadowing the
    # builtin ``input``).
    dummy_input = torch.rand(1, 3, *args.input_size)
    dummy_input = dummy_input.cuda()

    macs, params = profile(model, inputs=(dummy_input, ), verbose=False)
    macs, params = clever_format([macs, params], "%.3f")

    with torch.no_grad():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        for _ in range(num_runs):
            model(dummy_input)
        end.record()

        # Waits for everything to finish running, so that both events have
        # completed before elapsed_time() is queried.
        torch.cuda.synchronize()

    print('Model:', opt.Model.name)
    print('MACs:', macs, 'Params:', params)
    # elapsed_time() is in milliseconds; report the mean per forward pass.
    print('Throughput:', start.elapsed_time(end) / num_runs, 'msec')
if __name__ == '__main__':
    # Script entry point: read the CLI flags, load the configuration file
    # they point at, and run the benchmark with both.
    cli_args = _args()
    benchmark(load_config(cli_args.config), cli_args)