-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathbase.py
36 lines (26 loc) · 1002 Bytes
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import json
from .nnet.py_factory import NetworkFactory
class Base(object):
    """Inference wrapper binding a dataset object, a network and a test function.

    Calling an instance runs ``func(db, nnet, image_copy, *args, **kwargs)`` and
    re-keys the per-category result by class name via ``db.cls2name``.
    """

    def __init__(self, db, nnet, func, model=None):
        """Store the collaborators and prepare the network for inference.

        Args:
            db: Object exposing ``configs["categories"]`` and ``cls2name(j)``.
            nnet: Network wrapper providing ``cuda()`` and ``eval_mode()``, plus
                ``load_pretrained_params(model)`` when ``model`` is given.
            func: Callable invoked as ``func(db, nnet, image, *args, **kwargs)``.
            model: Optional value forwarded to ``nnet.load_pretrained_params``
                (presumably a path to pretrained weights -- TODO confirm).
                When ``None``, nothing is loaded.
        """
        super().__init__()
        self._db = db
        self._nnet = nnet
        self._func = func

        network = self._nnet
        # Pretrained parameters are loaded before cuda()/eval_mode() are called.
        if model is not None:
            network.load_pretrained_params(model)

        # NOTE(review): presumably moves the network to the GPU and switches it
        # to evaluation mode -- the actual semantics live in ``nnet``; confirm.
        network.cuda()
        network.eval_mode()

    def _inference(self, image, *args, **kwargs):
        """Run the test function on a copy of ``image`` and return its result.

        ``func`` receives ``image.copy()`` rather than the caller's object,
        presumably so it may modify the image freely.
        """
        image_copy = image.copy()
        return self._func(self._db, self._nnet, image_copy, *args, **kwargs)

    def __call__(self, image, *args, **kwargs):
        """Run inference on ``image`` and return results keyed by class name.

        Returns:
            dict mapping ``db.cls2name(j)`` to ``bboxes[j]`` for every category
            id ``j`` from 1 to ``db.configs["categories"]`` inclusive. Any other
            entries in the inference result are dropped.
        """
        num_categories = self._db.configs["categories"]
        bboxes = self._inference(image, *args, **kwargs)

        # Category ids are 1-based here; presumably id 0 is reserved (e.g. for
        # background) -- TODO confirm against the db implementation.
        detections = {}
        for cls_id in range(1, num_categories + 1):
            name = self._db.cls2name(cls_id)
            detections[name] = bboxes[cls_id]
        return detections
def load_cfg(cfg_file):
    """Read a JSON configuration file and return its ``system`` and ``db`` sections.

    Args:
        cfg_file: Path to the JSON configuration file.

    Returns:
        A ``(cfg_sys, cfg_db)`` tuple holding the values stored under the
        top-level ``"system"`` and ``"db"`` keys.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the ``"system"`` or ``"db"`` key is missing.
    """
    # JSON text is UTF-8 by specification (RFC 8259). Pin the encoding so the
    # result does not depend on the platform's locale default.
    with open(cfg_file, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    cfg_sys = cfg["system"]
    cfg_db = cfg["db"]
    return cfg_sys, cfg_db
def load_nnet(cfg_sys, model):
    """Construct and return ``NetworkFactory(cfg_sys, model)``.

    Args:
        cfg_sys: System configuration passed to ``NetworkFactory`` (presumably
            the ``system`` section returned by ``load_cfg`` -- TODO confirm).
        model: Model argument passed through to ``NetworkFactory`` unchanged.
    """
    network = NetworkFactory(cfg_sys, model)
    return network