|
| 1 | +import tensorflow as tf |
| 2 | +import numpy as np |
| 3 | + |
| 4 | +def save_pretrained_model(sess, outputs, feeds, out_dir, model_name="pretrained"): |
| 5 | + try: |
| 6 | + import os |
| 7 | + to_onnx_path = "{}/to_onnx".format(out_dir) |
| 8 | + if not os.path.isdir(to_onnx_path): |
| 9 | + os.makedirs(to_onnx_path) |
| 10 | + saved_model = "{}/saved_model".format(to_onnx_path) |
| 11 | + inputs_path = "{}/inputs.npy".format(to_onnx_path) |
| 12 | + pretrained_model_yaml_path = "{}/pretrained.yaml".format(to_onnx_path) |
| 13 | + |
| 14 | + print("===============Save Frozen Graph========================") |
| 15 | + print("Save model for tf2onnx: {}".format(to_onnx_path)) |
| 16 | + # save inputs |
| 17 | + inputs = {} |
| 18 | + for inp, value in feeds.items(): |
| 19 | + if isinstance(inp, str): |
| 20 | + inputs[inp] = value |
| 21 | + else: |
| 22 | + inputs[inp.name] = value |
| 23 | + np.save(inputs_path, inputs) |
| 24 | + print("Saved inputs to {}".format(inputs_path)) |
| 25 | + |
| 26 | + # save graph and weights |
| 27 | + from tensorflow.saved_model import simple_save |
| 28 | + simple_save(sess, saved_model, |
| 29 | + {n: i for n,i in zip(inputs.keys(), feeds.keys())}, |
| 30 | + {op.name: op for op in outputs}) |
| 31 | + print("Saved model to {}".format(saved_model)) |
| 32 | + |
| 33 | + # generate config |
| 34 | + pretrained_model_yaml = ''' |
| 35 | +{}: |
| 36 | + model: ./saved_model |
| 37 | + model_type: saved_model |
| 38 | + input_get: get_ramp |
| 39 | +'''.format(model_name) |
| 40 | + pretrained_model_yaml += " inputs:\n" |
| 41 | + for inp, _ in inputs.items(): |
| 42 | + pretrained_model_yaml += " \"{}\": np.array(np.load(\"./inputs.npy\")[()][\"{}\"])\n".format( |
| 43 | + inp, inp |
| 44 | + ) |
| 45 | + outputs = [op.name for op in outputs] |
| 46 | + pretrained_model_yaml += " outputs:\n" |
| 47 | + for out in outputs: |
| 48 | + pretrained_model_yaml += " - {}\n".format(out) |
| 49 | + with open(pretrained_model_yaml_path, "w") as f: |
| 50 | + f.write(pretrained_model_yaml) |
| 51 | + print("Saved pretrained model yaml to {}".format(pretrained_model_yaml_path)) |
| 52 | + print("=========================================================") |
| 53 | + except Exception as ex: |
| 54 | + print("Error: {}".format(ex)) |
| 55 | + |
| 56 | + |
| 57 | +def test(): |
| 58 | + x_val = np.random.rand(5, 20).astype(np.float32) |
| 59 | + y_val = np.random.rand(20, 10).astype(np.float32) |
| 60 | + x = tf.placeholder(tf.float32, x_val.shape, name="x") |
| 61 | + y = tf.placeholder(tf.float32, y_val.shape, name="y") |
| 62 | + z = tf.matmul(x, y) |
| 63 | + w = tf.get_variable("weight", [5, 10], dtype=tf.float32) |
| 64 | + init = tf.global_variables_initializer() |
| 65 | + outputs = [z + w] |
| 66 | + feeds = {x: x_val, y: y_val} |
| 67 | + with tf.Session() as sess: |
| 68 | + sess.run(init) |
| 69 | + out = sess.run(outputs, feeds) |
| 70 | + # NOTE: Put below snippet after the LAST testing step |
| 71 | + save_pretrained_model(sess, outputs, feeds, "./tests", model_name="test") |
| 72 | + |
| 73 | + |
| 74 | +if __name__ == "__main__": |
| 75 | + test() |
0 commit comments