Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 6d572c3

Browse files
Seppo Enarviafrozenator
Seppo Enarvi
authored andcommitted
Create an integer problem_0_steps variable. (#1273)
1 parent edd7fb2 commit 6d572c3

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tensor2tensor/bin/t2t_avg_all.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def main(_):
6363
var_list = tf.contrib.framework.list_variables(model.filename)
6464
avg_values = {}
6565
for (name, shape) in var_list:
66-
if not name.startswith("global_step"):
66+
if not (name.startswith("global_step") or
67+
name.startswith("train_stats/")):
6768
avg_values[name] = np.zeros(shape)
6869
models_processed += 1
6970

@@ -88,6 +89,8 @@ def main(_):
8889
"global_step",
8990
initializer=tf.constant(model.steps, dtype=tf.int64),
9091
trainable=False)
92+
with tf.variable_scope("train_stats"):
93+
tf.get_variable("problem_0_steps", initializer=0, trainable=False)
9194
saver = tf.train.Saver(tf.global_variables())
9295

9396
tf.logging.info("Running session for %s" % (out_file))

0 commit comments

Comments
 (0)