Skip to content

Commit 6d3f24b

Browse files
add an utility to save pretrained model
1 parent 0de7d19 commit 6d3f24b

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,12 @@ You call it for example with:
176176
python tests/run_pretrained_models.py --backend onnxruntime --config tests/run_pretrained_models.yaml --perf perf.csv
177177
```
178178

179+
### <a name="save_pretrained_model"></a>Tool to save pre-trained model
180+
181+
We provide an [utility](tools/save_pretrained_model.py) to save pre-trained model along with its config.
182+
Put `save_pretrained_model(sess, outputs, feed_inputs, save_dir, model_name)` in your last testing step and the pre-trained model and config will be saved under `save_dir/to_onnx`.
183+
Please refer to [tools/save_pretrained_model.py](tools/save_pretrained_model.py) for more information.
184+
179185
# Using the Python API
180186
## TensorFlow to ONNX conversion
181187
In some cases it will be useful to convert the models from TensorFlow to ONNX from a python script. You can use the following API:

tools/save_pretrained_model.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import tensorflow as tf
2+
import numpy as np
3+
4+
def save_pretrained_model(sess, outputs, feeds, out_dir, model_name="pretrained"):
5+
try:
6+
import os
7+
to_onnx_path = "{}/to_onnx".format(out_dir)
8+
if not os.path.isdir(to_onnx_path):
9+
os.makedirs(to_onnx_path)
10+
saved_model = "{}/saved_model".format(to_onnx_path)
11+
inputs_path = "{}/inputs.npy".format(to_onnx_path)
12+
pretrained_model_yaml_path = "{}/pretrained.yaml".format(to_onnx_path)
13+
14+
print("===============Save Frozen Graph========================")
15+
print("Save model for tf2onnx: {}".format(to_onnx_path))
16+
# save inputs
17+
inputs = {}
18+
for inp, value in feeds.items():
19+
if isinstance(inp, str):
20+
inputs[inp] = value
21+
else:
22+
inputs[inp.name] = value
23+
np.save(inputs_path, inputs)
24+
print("Saved inputs to {}".format(inputs_path))
25+
26+
# save graph and weights
27+
from tensorflow.saved_model import simple_save
28+
simple_save(sess, saved_model,
29+
{n: i for n,i in zip(inputs.keys(), feeds.keys())},
30+
{op.name: op for op in outputs})
31+
print("Saved model to {}".format(saved_model))
32+
33+
# generate config
34+
pretrained_model_yaml = '''
35+
{}:
36+
model: ./saved_model
37+
model_type: saved_model
38+
input_get: get_ramp
39+
'''.format(model_name)
40+
pretrained_model_yaml += " inputs:\n"
41+
for inp, _ in inputs.items():
42+
pretrained_model_yaml += " \"{}\": np.array(np.load(\"./inputs.npy\")[()][\"{}\"])\n".format(
43+
inp, inp
44+
)
45+
outputs = [op.name for op in outputs]
46+
pretrained_model_yaml += " outputs:\n"
47+
for out in outputs:
48+
pretrained_model_yaml += " - {}\n".format(out)
49+
with open(pretrained_model_yaml_path, "w") as f:
50+
f.write(pretrained_model_yaml)
51+
print("Saved pretrained model yaml to {}".format(pretrained_model_yaml_path))
52+
print("=========================================================")
53+
except Exception as ex:
54+
print("Error: {}".format(ex))
55+
56+
57+
def test():
58+
x_val = np.random.rand(5, 20).astype(np.float32)
59+
y_val = np.random.rand(20, 10).astype(np.float32)
60+
x = tf.placeholder(tf.float32, x_val.shape, name="x")
61+
y = tf.placeholder(tf.float32, y_val.shape, name="y")
62+
z = tf.matmul(x, y)
63+
w = tf.get_variable("weight", [5, 10], dtype=tf.float32)
64+
init = tf.global_variables_initializer()
65+
outputs = [z + w]
66+
feeds = {x: x_val, y: y_val}
67+
with tf.Session() as sess:
68+
sess.run(init)
69+
out = sess.run(outputs, feeds)
70+
# NOTE: Put below snippet after the LAST testing step
71+
save_pretrained_model(sess, outputs, feeds, "./tests", model_name="test")
72+
73+
74+
if __name__ == "__main__":
75+
test()

0 commit comments

Comments
 (0)