Skip to content

Commit caebd90

Browse files
committed
list model
1 parent fa3292a commit caebd90

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

tests/mrt/list_model.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import argparse
2+
from os import path
3+
import os
4+
5+
import gluoncv as cv
6+
7+
model_names_quantized_default = []
8+
9+
def get_model_names_quantized(fpath, model_names_quantized):
10+
for fname in os.listdir(fpath):
11+
nfpath = path.join(fpath, fname)
12+
if path.isdir(nfpath):
13+
get_model_names_quantized(nfpath, model_names_quantized)
14+
else:
15+
model_names_quantized.append(path.splitext(fname)[0])
16+
17+
dir_path = os.path.dirname(os.path.realpath(__file__))
18+
model_zoo = path.join(dir_path, "model_zoo")
19+
get_model_names_quantized(model_zoo, model_names_quantized_default)
20+
parser = argparse.ArgumentParser("")
21+
parser.add_argument("-p", "--prefixes", nargs="*", type=str, default=[])
22+
parser.add_argument("-sq", "--show-quantized", action="store_true")
23+
parser.add_argument("-sqo", "--show-quantized-only", action="store_true")
24+
parser.add_argument("-mnq", "--model-names-quantized", nargs="*", type=str, default=model_names_quantized_default)
25+
26+
if __name__ == "__main__":
27+
args = parser.parse_args()
28+
model_names_quantized = args.model_names_quantized
29+
if args.show_quantized_only:
30+
for model_name in model_names_quantized:
31+
print(model_name)
32+
else:
33+
prefixes = set(args.prefixes)
34+
show_quantized = args.show_quantized
35+
supported_models = set(cv.model_zoo.get_model_list())
36+
for model_name in cv.model_zoo.pretrained_model_list():
37+
if model_name not in supported_models:
38+
continue
39+
if not show_quantized and model_name in model_names_quantized:
40+
continue
41+
if prefixes:
42+
for prefix in prefixes:
43+
if model_name.startswith(prefix):
44+
print(model_name)
45+
else:
46+
print(model_name)

0 commit comments

Comments
 (0)