21
21
# Downloads all model files again if manifest file is not present
22
22
MANIFEST_FILE = "model_manifest.json"
23
23
24
+ # Valid paths for model-saving specification
25
+ VALID_PATHS = ("script" , "trace" , "torchscript" , "pytorch" , "all" )
26
+
27
+ # Key models selected for benchmarking with their respective paths
24
28
BENCHMARK_MODELS = {
25
- "vgg16" : {"model" : models .vgg16 (weights = None ), "path" : "script" },
26
- "resnet50" : {"model" : models .resnet50 (weights = None ), "path" : "script" },
29
+ "vgg16" : {"model" : models .vgg16 (pretrained = True ), "path" : ["script" , "pytorch" ]},
30
+ "resnet50" : {
31
+ "model" : models .resnet50 (weights = None ),
32
+ "path" : ["script" , "pytorch" ],
33
+ },
27
34
"efficientnet_b0" : {
28
35
"model" : timm .create_model ("efficientnet_b0" , pretrained = True ),
29
- "path" : "script" ,
36
+ "path" : [ "script" , "pytorch" ] ,
30
37
},
31
38
"vit" : {
32
39
"model" : timm .create_model ("vit_base_patch16_224" , pretrained = True ),
@@ -40,18 +47,26 @@ def get(n, m, manifest):
40
47
print ("Downloading {}" .format (n ))
41
48
traced_filename = "models/" + n + "_traced.jit.pt"
42
49
script_filename = "models/" + n + "_scripted.jit.pt"
50
+ pytorch_filename = "models/" + n + "_pytorch.pt"
43
51
x = torch .ones ((1 , 3 , 300 , 300 )).cuda ()
44
- if n == "bert-base-uncased " :
52
+ if n == "bert_base_uncased " :
45
53
traced_model = m ["model" ]
46
54
torch .jit .save (traced_model , traced_filename )
47
55
manifest .update ({n : [traced_filename ]})
48
56
else :
49
57
m ["model" ] = m ["model" ].eval ().cuda ()
50
- if m ["path" ] == "both" or m ["path" ] == "trace" :
58
+
59
+ # Get all desired model save specifications as list
60
+ paths = [m ["path" ]] if isinstance (m ["path" ], str ) else m ["path" ]
61
+
62
+ # Depending on specified model save specifications, save desired model formats
63
+ if any (path in ("all" , "torchscript" , "trace" ) for path in paths ):
64
+ # (TorchScript) Traced model
51
65
trace_model = torch .jit .trace (m ["model" ], [x ])
52
66
torch .jit .save (trace_model , traced_filename )
53
67
manifest .update ({n : [traced_filename ]})
54
- if m ["path" ] == "both" or m ["path" ] == "script" :
68
+ if any (path in ("all" , "torchscript" , "script" ) for path in paths ):
69
+ # (TorchScript) Scripted model
55
70
script_model = torch .jit .script (m ["model" ])
56
71
torch .jit .save (script_model , script_filename )
57
72
if n in manifest .keys ():
@@ -60,6 +75,15 @@ def get(n, m, manifest):
60
75
manifest .update ({n : files })
61
76
else :
62
77
manifest .update ({n : [script_filename ]})
78
+ if any (path in ("all" , "pytorch" ) for path in paths ):
79
+ # (PyTorch Module) model
80
+ torch .save (m ["model" ], pytorch_filename )
81
+ if n in manifest .keys ():
82
+ files = list (manifest [n ]) if type (manifest [n ]) != list else manifest [n ]
83
+ files .append (script_filename )
84
+ manifest .update ({n : files })
85
+ else :
86
+ manifest .update ({n : [script_filename ]})
63
87
return manifest
64
88
65
89
@@ -72,15 +96,35 @@ def download_models(version_matches, manifest):
72
96
for n , m in BENCHMARK_MODELS .items ():
73
97
scripted_filename = "models/" + n + "_scripted.jit.pt"
74
98
traced_filename = "models/" + n + "_traced.jit.pt"
99
+ pytorch_filename = "models/" + n + "_pytorch.pt"
75
100
# Check if model file exists on disk
101
+
102
+ # Extract model specifications as list and ensure all desired formats exist
103
+ paths = [m ["path" ]] if isinstance (m ["path" ], str ) else m ["path" ]
76
104
if (
77
105
(
78
- m ["path" ] == "both"
106
+ any (path == "all" for path in paths )
107
+ and os .path .exists (scripted_filename )
108
+ and os .path .exists (traced_filename )
109
+ and os .path .exists (pytorch_filename )
110
+ )
111
+ or (
112
+ any (path == "torchscript" for path in paths )
79
113
and os .path .exists (scripted_filename )
80
114
and os .path .exists (traced_filename )
81
115
)
82
- or (m ["path" ] == "script" and os .path .exists (scripted_filename ))
83
- or (m ["path" ] == "trace" and os .path .exists (traced_filename ))
116
+ or (
117
+ any (path == "script" for path in paths )
118
+ and os .path .exists (scripted_filename )
119
+ )
120
+ or (
121
+ any (path == "trace" for path in paths )
122
+ and os .path .exists (traced_filename )
123
+ )
124
+ or (
125
+ any (path == "pytorch" for path in paths )
126
+ and os .path .exists (pytorch_filename )
127
+ )
84
128
):
85
129
print ("Skipping {} " .format (n ))
86
130
continue
@@ -90,7 +134,6 @@ def download_models(version_matches, manifest):
90
134
def main ():
91
135
manifest = None
92
136
version_matches = False
93
- manifest_exists = False
94
137
95
138
# Check if Manifest file exists or is empty
96
139
if not os .path .exists (MANIFEST_FILE ) or os .stat (MANIFEST_FILE ).st_size == 0 :
@@ -99,7 +142,6 @@ def main():
99
142
# Creating an empty manifest file for overwriting post setup
100
143
os .system ("touch {}" .format (MANIFEST_FILE ))
101
144
else :
102
- manifest_exists = True
103
145
104
146
# Load manifest if already exists
105
147
with open (MANIFEST_FILE , "r" ) as f :
@@ -129,4 +171,13 @@ def main():
129
171
f .truncate ()
130
172
131
173
132
- main ()
174
+ if __name__ == "__main__" :
175
+ # Ensure all specified desired model formats exist and are valid
176
+ paths = [
177
+ [m ["path" ]] if isinstance (m ["path" ], str ) else m ["path" ]
178
+ for m in BENCHMARK_MODELS .values ()
179
+ ]
180
+ assert all (
181
+ (path in VALID_PATHS ) for path_list in paths for path in path_list
182
+ ), "Not all 'path' attributes in BENCHMARK_MODELS are valid"
183
+ main ()
0 commit comments