Skip to content

Commit

Permalink
Fix more deprecation warning in eval.py.
Browse files Browse the repository at this point in the history
  • Loading branch information
XericZephyr committed Aug 20, 2019
1 parent 4d36299 commit f51dc7d
Showing 1 changed file with 21 additions and 19 deletions.
40 changes: 21 additions & 19 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,17 +150,18 @@ def build_graph(reader,
else:
label_loss = label_loss_fn.calculate_loss(predictions, labels_batch)

tf.add_to_collection("global_step", global_step)
tf.add_to_collection("loss", label_loss)
tf.add_to_collection("predictions", predictions)
tf.add_to_collection("input_batch", model_input)
tf.add_to_collection("input_batch_raw", model_input_raw)
tf.add_to_collection("video_id_batch", video_id_batch)
tf.add_to_collection("num_frames", num_frames)
tf.add_to_collection("labels", tf.cast(labels_batch, tf.float32))
tf.compat.v1.add_to_collection("global_step", global_step)
tf.compat.v1.add_to_collection("loss", label_loss)
tf.compat.v1.add_to_collection("predictions", predictions)
tf.compat.v1.add_to_collection("input_batch", model_input)
tf.compat.v1.add_to_collection("input_batch_raw", model_input_raw)
tf.compat.v1.add_to_collection("video_id_batch", video_id_batch)
tf.compat.v1.add_to_collection("num_frames", num_frames)
tf.compat.v1.add_to_collection("labels", tf.cast(labels_batch, tf.float32))
if FLAGS.segment_labels:
tf.add_to_collection("label_weights", input_data_dict["label_weights"])
tf.add_to_collection("summary_op", tf.compat.v1.summary.merge_all())
tf.compat.v1.add_to_collection("label_weights",
input_data_dict["label_weights"])
tf.compat.v1.add_to_collection("summary_op", tf.compat.v1.summary.merge_all())


def evaluation_loop(fetches, saver, summary_writer, evl_metrics,
Expand Down Expand Up @@ -211,7 +212,7 @@ def evaluation_loop(fetches, saver, summary_writer, evl_metrics,
coord = tf.train.Coordinator()
try:
threads = []
for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
for qr in tf.compat.v1.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
threads.extend(
qr.create_threads(sess, coord=coord, daemon=True, start=True))
logging.info("enter eval_once loop global_step_val = %s. ",
Expand Down Expand Up @@ -316,18 +317,19 @@ def evaluate():

# A dict of tensors to be run in Session.
fetches = {
"video_id": tf.get_collection("video_id_batch")[0],
"predictions": tf.get_collection("predictions")[0],
"labels": tf.get_collection("labels")[0],
"loss": tf.get_collection("loss")[0],
"summary": tf.get_collection("summary_op")[0]
"video_id": tf.compat.v1.get_collection("video_id_batch")[0],
"predictions": tf.compat.v1.get_collection("predictions")[0],
"labels": tf.compat.v1.get_collection("labels")[0],
"loss": tf.compat.v1.get_collection("loss")[0],
"summary": tf.compat.v1.get_collection("summary_op")[0]
}
if FLAGS.segment_labels:
fetches["label_weights"] = tf.get_collection("label_weights")[0]
fetches["label_weights"] = tf.compat.v1.get_collection("label_weights")[0]

saver = tf.compat.v1.train.Saver(tf.global_variables())
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
summary_writer = tf.compat.v1.summary.FileWriter(
os.path.join(FLAGS.train_dir, "eval"), graph=tf.get_default_graph())
os.path.join(FLAGS.train_dir, "eval"),
graph=tf.compat.v1.get_default_graph())

evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, FLAGS.top_k,
None)
Expand Down

0 comments on commit f51dc7d

Please sign in to comment.