Implement savedmodel in sparse classifier
tobegit3hub committed Apr 27, 2017
1 parent 0ca00b4 commit 27d7a9f
Showing 2 changed files with 59 additions and 7 deletions.
12 changes: 6 additions & 6 deletions dense_classifier.py
@@ -528,17 +528,17 @@ def inference(inputs, is_train=True):
       accuracy = float(correct_label_number) / label_number
 
       # Compute auc
-      expected_labels = np.array(inference_data_labels)
-      predict_labels = prediction_softmax[:, 0]
-      fpr, tpr, thresholds = metrics.roc_curve(expected_labels,
-                                               predict_labels,
-                                               pos_label=0)
+      y_true = np.array(inference_data_labels)
+      y_score = prediction_softmax[:, 1]
+      fpr, tpr, thresholds = metrics.roc_curve(y_true,
+                                               y_score,
+                                               pos_label=1)
       auc = metrics.auc(fpr, tpr)
       logging.info("[{}] Inference accuracy: {}, auc: {}".format(
           end_time - start_time, accuracy, auc))
 
       # Save result into the file
-      np.savetxt(inference_result_file_name, prediction, delimiter=",")
+      np.savetxt(inference_result_file_name, prediction_softmax, delimiter=",")
       logging.info("Save result to file: {}".format(
           inference_result_file_name))
 
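Note on the AUC change above: it pairs the positive-class probability (column 1 of the two-column softmax output) with pos_label=1, which is the conventional way to call sklearn's metrics.roc_curve, and it saves prediction_softmax rather than prediction to the result file. A minimal, self-contained sketch of the corrected ROC/AUC call on toy data (the numbers below are illustrative only, not produced by the model):

import numpy as np
from sklearn import metrics

# Toy inputs: binary labels and an (n, 2) softmax output from some classifier.
y_true = np.array([0, 1, 1, 0, 1])
softmax = np.array([[0.9, 0.1],
                    [0.2, 0.8],
                    [0.4, 0.6],
                    [0.7, 0.3],
                    [0.1, 0.9]])

# Score the positive class (label 1) with its own probability column.
y_score = softmax[:, 1]
fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score, pos_label=1)
print(metrics.auc(fpr, tpr))  # 1.0 here, since every positive outscores every negative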
54 changes: 53 additions & 1 deletion sparse_classifier.py
@@ -11,6 +11,12 @@
 from sklearn import metrics
 import tensorflow as tf
 from tensorflow.contrib.session_bundle import exporter
+from tensorflow.python.saved_model import builder as saved_model_builder
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import signature_def_utils
+from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.saved_model import utils
+from tensorflow.python.util import compat
 
 # Define hyperparameters
 flags = tf.app.flags
@@ -49,6 +55,8 @@
flags.DEFINE_integer("steps_to_validate", 10,
"Steps to validate and print state")
flags.DEFINE_string("mode", "train", "Support train, export, inference")
flags.DEFINE_string("saved_model_path", "./sparse_saved_model/",
"The path of the saved model")
flags.DEFINE_string("model_path", "./sparse_model/", "The path of the model")
flags.DEFINE_integer("model_version", 1, "The version of the model")
flags.DEFINE_string("inference_test_file", "./data/a8a_test.libsvm",
@@ -391,6 +399,50 @@ def inference(sparse_ids, sparse_values, is_train=True):
       export_model(sess, saver, model_signature, FLAGS.model_path,
                    FLAGS.model_version)
 
+    elif MODE == "savedmodel":
+      if not restore_session_from_checkpoint(sess, saver, LATEST_CHECKPOINT):
+        logging.error("No checkpoint found, exit now")
+        exit(1)
+
+      logging.info("Export the saved model to {}".format(
+          FLAGS.saved_model_path))
+      export_path_base = FLAGS.saved_model_path
+      export_path = os.path.join(
+          compat.as_bytes(export_path_base),
+          compat.as_bytes(str(FLAGS.model_version)))
+
+      model_signature = signature_def_utils.build_signature_def(
+          inputs={
+              "keys": utils.build_tensor_info(keys_placeholder),
+              "indexs": utils.build_tensor_info(sparse_index),
+              "ids": utils.build_tensor_info(sparse_ids),
+              "values": utils.build_tensor_info(sparse_values),
+              "shape": utils.build_tensor_info(sparse_shape)
+          },
+          outputs={
+              "keys": utils.build_tensor_info(keys),
+              "softmax": utils.build_tensor_info(inference_softmax),
+              "prediction": utils.build_tensor_info(inference_op)
+          },
+          method_name=signature_constants.PREDICT_METHOD_NAME)
+
+      try:
+        builder = saved_model_builder.SavedModelBuilder(export_path)
+        builder.add_meta_graph_and_variables(
+            sess,
+            [tag_constants.SERVING],
+            signature_def_map={
+                signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+                model_signature,
+            },
+            #legacy_init_op=legacy_init_op)
+            legacy_init_op=tf.group(tf.initialize_all_tables(),
+                                    name="legacy_init_op"))
+
+        builder.save()
+      except Exception as e:
+        logging.error("Fail to export saved model, exception: {}".format(e))
+
     elif MODE == "inference":
       if not restore_session_from_checkpoint(sess, saver, LATEST_CHECKPOINT):
         logging.error("No checkpoint found, exit now")
@@ -446,7 +498,7 @@ def inference(sparse_ids, sparse_values, is_train=True):
           end_time - start_time, accuracy, auc))
 
       # Save result into the file
-      np.savetxt(inference_result_file_name, prediction, delimiter=",")
+      np.savetxt(inference_result_file_name, prediction_softmax, delimiter=",")
       logging.info("Save result to file: {}".format(
           inference_result_file_name))
 
