|
| 1 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +# Licensed under the MIT license. |
| 3 | + |
| 4 | +"""Make simple test model in all tensorflow formats.""" |
| 5 | + |
| 6 | +from __future__ import division |
| 7 | +from __future__ import print_function |
| 8 | +from __future__ import unicode_literals |
| 9 | + |
| 10 | +import os |
| 11 | +import unittest |
| 12 | +from collections import namedtuple |
| 13 | + |
| 14 | +import graphviz as gv |
| 15 | +from onnx import TensorProto |
| 16 | +from onnx import helper |
| 17 | + |
| 18 | +import tensorflow as tf |
| 19 | +from tensorflow.python.framework.graph_util import convert_variables_to_constants |
| 20 | +import numpy as np |
| 21 | + |
| 22 | +import os |
| 23 | + |
| 24 | + |
| 25 | +# pylint: disable=missing-docstring |
| 26 | + |
| 27 | +# Parameters |
| 28 | +learning_rate = 0.02 |
| 29 | +training_epochs = 100 |
| 30 | + |
| 31 | +# Training Data |
| 32 | +train_X = np.array( |
| 33 | + [3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167, 7.042, 10.791, 5.313, 7.997, 5.654, 9.27, 3.1]) |
| 34 | +train_Y = np.array( |
| 35 | + [1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221, 2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3]) |
| 36 | +test_X = np.array([6.83, 4.668, 8.9, 7.91, 5.7, 8.7, 3.1, 2.1]) |
| 37 | +test_Y = np.array([1.84, 2.273, 3.2, 2.831, 2.92, 3.24, 1.35, 1.03]) |
| 38 | + |
| 39 | +def freeze_session(sess, keep_var_names=None, output_names=None, clear_devices=True): |
| 40 | + """Freezes the state of a session into a pruned computation graph.""" |
| 41 | + output_names = [i.replace(":0", "") for i in output_names] |
| 42 | + graph = sess.graph |
| 43 | + with graph.as_default(): |
| 44 | + freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or [])) |
| 45 | + output_names = output_names or [] |
| 46 | + output_names += [v.op.name for v in tf.global_variables()] |
| 47 | + input_graph_def = graph.as_graph_def() |
| 48 | + if clear_devices: |
| 49 | + for node in input_graph_def.node: |
| 50 | + node.device = "" |
| 51 | + frozen_graph = convert_variables_to_constants(sess, input_graph_def, |
| 52 | + output_names, freeze_var_names) |
| 53 | + return frozen_graph |
| 54 | + |
| 55 | +def train(model_path): |
| 56 | + n_samples = train_X.shape[0] |
| 57 | + |
| 58 | + # tf Graph Input |
| 59 | + X = tf.placeholder(tf.float32, name="X") |
| 60 | + Y = tf.placeholder(tf.float32, name="Y") |
| 61 | + |
| 62 | + # Set model weights |
| 63 | + W = tf.Variable(np.random.randn(), name="W") |
| 64 | + b = tf.Variable(np.random.randn(), name="b") |
| 65 | + |
| 66 | + pred = tf.add(tf.multiply(X, W), b) |
| 67 | + pred = tf.identity(pred, name="pred") |
| 68 | + cost = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples) |
| 69 | + |
| 70 | + optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) |
| 71 | + saver = tf.train.Saver() |
| 72 | + |
| 73 | + # Launch the graph |
| 74 | + with tf.Session() as sess: |
| 75 | + sess.run(tf.global_variables_initializer()) |
| 76 | + |
| 77 | + # Fit all training data |
| 78 | + for epoch in range(training_epochs): |
| 79 | + for (x, y) in zip(train_X, train_Y): |
| 80 | + sess.run(optimizer, feed_dict={X: x, Y: y}) |
| 81 | + training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y}) |
| 82 | + testing_cost = sess.run(cost, feed_dict={X: test_X, Y: test_Y}) |
| 83 | + print("train_cost={}, test_cost={}, diff={}" |
| 84 | + .format(training_cost, testing_cost, abs(training_cost - testing_cost))) |
| 85 | + |
| 86 | + p = os.path.abspath(os.path.join(model_path, "checkpoint")) |
| 87 | + os.makedirs(p, exist_ok=True) |
| 88 | + p = saver.save(sess, os.path.join(p, "model")) |
| 89 | + |
| 90 | + frozen_graph = freeze_session(sess, output_names=["pred:0"]) |
| 91 | + p = os.path.abspath(os.path.join(model_path, "graphdef")) |
| 92 | + tf.train.write_graph(frozen_graph, p, "frozen.pb", as_text=False) |
| 93 | + |
| 94 | + p = os.path.abspath(os.path.join(model_path, "saved_model")) |
| 95 | + tf.saved_model.simple_save(sess, p, inputs={"X": X}, outputs={"pred": pred}) |
| 96 | + |
| 97 | + |
| 98 | +train("models/regression") |
| 99 | + |
0 commit comments