forked from osmr/imgclsmob
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path eval_gl.py
93 lines (78 loc) · 2.97 KB
/
eval_gl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import argparse
from common.logger_utils import initialize_logging
from gluon.utils import prepare_mx_context, prepare_model
from gluon.utils import get_composite_metric
from gluon.cls_eval_utils import add_eval_cls_parser_arguments, test
from gluon.cls_eval_utils import get_dataset_metainfo
from gluon.cls_eval_utils import get_batch_fn
from gluon.cls_eval_utils import get_val_data_source
def parse_args():
    """
    Build the command-line parser and parse the script arguments.

    Parsing happens in two passes:

    1. A bootstrap pass reads only ``--dataset`` and ``--work-dir``. Any other
       options are ignored, because they have not been registered yet.
    2. The selected dataset's metainfo registers its own options, the generic
       evaluation options are added, and the full command line is parsed strictly.

    The bootstrap options live on a separate pre-parser without ``-h/--help``.
    If they sat on the main parser, ``--help`` would fire during the first pass
    and print an incomplete option list before exiting.

    Returns
    -------
    argparse.Namespace
        Parsed arguments.
    """
    # Bootstrap parser: no help action, so "-h/--help" is left for the full parser.
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument(
        "--dataset",
        type=str,
        default="ImageNet1K_rec",
        help="dataset name. options are ImageNet1K, ImageNet1K_rec, CUB_200_2011, CIFAR10, CIFAR100, SVHN")
    pre_parser.add_argument(
        "--work-dir",
        type=str,
        default=os.path.join("..", "imgclsmob_data"),
        help="path to working directory only for dataset root path preset")
    # Unknown (not yet registered) options are returned separately and ignored here.
    boot_args, _ = pre_parser.parse_known_args()

    # Full parser inherits --dataset / --work-dir, so they still appear in --help.
    parser = argparse.ArgumentParser(
        description="Evaluate a model for image classification (Gluon)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[pre_parser])
    dataset_metainfo = get_dataset_metainfo(dataset_name=boot_args.dataset)
    dataset_metainfo.add_dataset_parser_arguments(
        parser=parser,
        work_dir_path=boot_args.work_dir)
    add_eval_cls_parser_arguments(parser)

    # Every option is now registered, so parse strictly.
    args = parser.parse_args()
    return args
def main():
    """
    Evaluate an image-classification model with Gluon (script entry point).

    The steps are:

    1. Parse the command line and initialize logging.
    2. Validate that there is something to evaluate.
    3. Build the compute context and the model.
    4. Prepare the validation data source for the selected dataset.
    5. Run the evaluation via ``test`` (weight count always; FLOPs on request).

    Raises
    ------
    ValueError
        If ``use_pretrained`` is not set, no ``resume`` checkpoint path is given,
        and ``calc_flops_only`` is not set.
    """
    args = parse_args()

    # The values returned by initialize_logging are not needed by this script.
    initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    resume_path = args.resume.strip()
    # A weights source (pretrained or a resume checkpoint) is required unless only
    # FLOPs are requested. Use an explicit check rather than `assert` (which
    # `python -O` strips), and do it before the context/model/data setup below.
    if not (args.use_pretrained or resume_path or args.calc_flops_only):
        raise ValueError(
            "Nothing to evaluate: set `use_pretrained`, pass a checkpoint path via "
            "`resume`, or enable `calc_flops_only`")

    ctx, batch_size = prepare_mx_context(
        num_gpus=args.num_gpus,
        batch_size=args.batch_size)

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=resume_path,
        dtype=args.dtype,
        classes=args.num_classes,
        in_channels=args.in_channels,
        # Hybridization is skipped when FLOPs are to be counted.
        do_hybridize=(not args.calc_flops),
        ctx=ctx)
    # The evaluation loop needs the model's expected input size.
    assert hasattr(net, "in_size")
    input_image_size = net.in_size

    ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset)
    ds_metainfo.update(args=args)

    val_data = get_val_data_source(
        ds_metainfo=ds_metainfo,
        batch_size=batch_size,
        num_workers=args.num_workers)
    batch_fn = get_batch_fn(use_imgrec=ds_metainfo.use_imgrec)

    test(
        net=net,
        val_data=val_data,
        batch_fn=batch_fn,
        data_source_needs_reset=ds_metainfo.use_imgrec,
        val_metric=get_composite_metric(ds_metainfo.val_metric_names),
        dtype=args.dtype,
        ctx=ctx,
        input_image_size=input_image_size,
        in_channels=args.in_channels,
        # Always report the weight count, even when appending to an existing log file.
        calc_weight_count=True,
        calc_flops=args.calc_flops,
        calc_flops_only=args.calc_flops_only,
        extended_log=True)
# Script entry point: run the evaluation only when executed directly, not when imported.
if __name__ == "__main__":
    main()