TF deprecation warning fix for eval.py.
XericZephyr committed Aug 20, 2019
1 parent 6006a3b commit 541bf0e
Showing 1 changed file with 10 additions and 12 deletions.
eval.py
@@ -23,9 +23,7 @@
 import losses
 import readers
 import tensorflow as tf
-from tensorflow import app
 from tensorflow import flags
-from tensorflow import gfile
 from tensorflow.python.lib.io import file_io
 import utils
 import video_level_models
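
The `app` and `gfile` top-level imports emit deprecation warnings in TF 1.14 and are gone from the public namespace in TF 2.x; their supported homes are `tf.compat.v1.app` and `tf.io.gfile`. A minimal sketch of the equivalent access (illustrative, not part of this commit):

    import tensorflow as tf

    app = tf.compat.v1.app    # replaces `from tensorflow import app`
    gfile = tf.io.gfile       # replaces `from tensorflow import gfile`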
@@ -88,7 +86,7 @@ def get_input_evaluation_tensors(reader,
   """
   logging.info("Using batch size of %d for evaluation.", batch_size)
   with tf.name_scope("eval_input"):
-    files = gfile.Glob(data_pattern)
+    files = tf.io.gfile.glob(data_pattern)
     if not files:
       raise IOError("Unable to find the evaluation files.")
     logging.info("number of evaluation files: %d", len(files))
@@ -131,22 +129,22 @@ def build_graph(reader,
   model_input_raw = input_data_dict["video_matrix"]
   labels_batch = input_data_dict["labels"]
   num_frames = input_data_dict["num_frames"]
-  tf.summary.histogram("model_input_raw", model_input_raw)
+  tf.compat.v1.summary.histogram("model_input_raw", model_input_raw)
 
   feature_dim = len(model_input_raw.get_shape()) - 1
 
   # Normalize input features.
   model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)
 
-  with tf.variable_scope("tower"):
+  with tf.compat.v1.variable_scope("tower"):
     result = model.create_model(
         model_input,
         num_frames=num_frames,
         vocab_size=reader.num_classes,
         labels=labels_batch,
         is_training=False)
     predictions = result["predictions"]
-    tf.summary.histogram("model_activations", predictions)
+    tf.compat.v1.summary.histogram("model_activations", predictions)
     if "loss" in result.keys():
       label_loss = result["loss"]
     else:
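
Graph-mode summaries and `variable_scope` are reachable only through the v1 compatibility layer in TF 2.x, which is what the new prefixes anticipate. A self-contained sketch of the scoping behavior (variable names illustrative, not from eval.py):

    import tensorflow as tf

    graph = tf.Graph()
    with graph.as_default():
      with tf.compat.v1.variable_scope("tower"):
        weights = tf.compat.v1.get_variable("weights", shape=[128, 10])
      tf.compat.v1.summary.histogram("model_activations", weights)
      print(weights.name)  # "tower/weights:0"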
@@ -162,7 +160,7 @@ def build_graph(reader,
   tf.add_to_collection("labels", tf.cast(labels_batch, tf.float32))
   if FLAGS.segment_labels:
     tf.add_to_collection("label_weights", input_data_dict["label_weights"])
-  tf.add_to_collection("summary_op", tf.summary.merge_all())
+  tf.add_to_collection("summary_op", tf.compat.v1.summary.merge_all())
 
 
 def evaluation_loop(fetches, saver, summary_writer, evl_metrics,
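
`tf.compat.v1.summary.merge_all()` bundles every summary registered so far into one op, and storing it in a named collection is how `evaluate()` later pulls it back out of the graph. A sketch of that store-and-fetch pattern (the "summary_op" key comes from the code above; the rest is illustrative):

    import tensorflow as tf

    graph = tf.Graph()
    with graph.as_default():
      v = tf.compat.v1.get_variable("v", shape=[4])
      tf.compat.v1.summary.histogram("v", v)
      tf.compat.v1.add_to_collection("summary_op",
                                     tf.compat.v1.summary.merge_all())
      # Later, e.g. inside the evaluation loop:
      summary_op = tf.compat.v1.get_collection("summary_op")[0]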
@@ -276,7 +274,7 @@ def evaluation_loop(fetches, saver, summary_writer, evl_metrics,
 
 def evaluate():
   """Starts main evaluation loop."""
-  tf.set_random_seed(0)  # for reproducibility
+  tf.compat.v1.set_random_seed(0)  # for reproducibility
 
   # Write json of flags
   model_flags_path = os.path.join(FLAGS.train_dir, "model_flags.json")
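
`tf.compat.v1.set_random_seed` sets the graph-level seed from which op-level seeds are derived, making random ops repeatable across runs. Sketch (shapes illustrative):

    import tensorflow as tf

    graph = tf.Graph()
    with graph.as_default():
      tf.compat.v1.set_random_seed(0)
      noise = tf.compat.v1.random_normal([2, 2])  # deterministic given the graph seed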
@@ -327,8 +325,8 @@ def evaluate():
   if FLAGS.segment_labels:
     fetches["label_weights"] = tf.get_collection("label_weights")[0]
 
-  saver = tf.train.Saver(tf.global_variables())
-  summary_writer = tf.summary.FileWriter(
+  saver = tf.compat.v1.train.Saver(tf.global_variables())
+  summary_writer = tf.compat.v1.summary.FileWriter(
       os.path.join(FLAGS.train_dir, "eval"), graph=tf.get_default_graph())
 
   evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, FLAGS.top_k,
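
`tf.compat.v1.train.Saver` restores checkpointed weights and `tf.compat.v1.summary.FileWriter` writes event files for TensorBoard; together they are the I/O ends of the classic evaluation loop. A sketch under an assumed log directory:

    import tensorflow as tf

    graph = tf.Graph()
    with graph.as_default():
      v = tf.compat.v1.get_variable("v", shape=[3])
      saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
      # "/tmp/train_dir/eval" is a placeholder path.
      writer = tf.compat.v1.summary.FileWriter("/tmp/train_dir/eval",
                                               graph=graph)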
@@ -343,10 +341,10 @@
 
 
 def main(unused_argv):
-  logging.set_verbosity(tf.logging.INFO)
+  logging.set_verbosity(logging.INFO)
   logging.info("tensorflow version: %s", tf.__version__)
   evaluate()
 
 
 if __name__ == "__main__":
-  app.run()
+  tf.compat.v1.app.run()
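
`tf.compat.v1.app.run()` is a thin wrapper over absl's app runner: it parses command-line flags and then invokes the module's `main(argv)`. A minimal standalone sketch:

    import tensorflow as tf

    def main(unused_argv):
      print("tensorflow version:", tf.__version__)

    if __name__ == "__main__":
      tf.compat.v1.app.run()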
