Skip to content

Commit 4b5bbcb

Browse files
author
Dan
committed
Attempt to fix head_test
1 parent d8e9166 commit 4b5bbcb

File tree

2 files changed

+21
-19
lines changed

2 files changed

+21
-19
lines changed

tensorflow/contrib/learn/python/learn/estimators/head.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -872,18 +872,17 @@ def _centered_bias_step(centered_bias, logits_dimension, labels, loss_fn):
872872
"""Creates and returns training op for centered bias."""
873873
if (logits_dimension is None) or (logits_dimension < 1):
874874
raise ValueError("Invalid logits_dimension %s." % logits_dimension)
875-
with ops.name_scope(None, "centered_bias_step", (labels,)) as name:
876-
batch_size = array_ops.shape(labels)[0]
877-
logits = array_ops.reshape(
878-
array_ops.tile(centered_bias, (batch_size,)),
879-
(batch_size, logits_dimension))
880-
with ops.name_scope(None, "centered_bias", (labels, logits)):
881-
centered_bias_loss = math_ops.reduce_mean(
882-
loss_fn(logits, labels), name="training_loss")
883-
# Learn central bias by an optimizer. 0.1 is a convervative lr for a
884-
# single variable.
885-
return training.AdagradOptimizer(0.1).minimize(
886-
centered_bias_loss, var_list=(centered_bias,), name=name)
875+
batch_size = array_ops.shape(labels)[0]
876+
logits = array_ops.reshape(
877+
array_ops.tile(centered_bias, (batch_size,)),
878+
(batch_size, logits_dimension))
879+
with ops.name_scope(None, "centered_bias", (labels, logits)):
880+
centered_bias_loss = math_ops.reduce_mean(
881+
loss_fn(logits, labels), name="training_loss")
882+
# Learn central bias by an optimizer. 0.1 is a convervative lr for a
883+
# single variable.
884+
return training.AdagradOptimizer(0.1).minimize(
885+
centered_bias_loss, var_list=(centered_bias,), name="centered_bias_step")
887886

888887

889888
def _head_prefixed(head_name, val):
@@ -930,11 +929,14 @@ def _train_op(
930929
loss, labels, train_op_fn, centered_bias=None, logits_dimension=None,
931930
loss_fn=None):
932931
"""Returns op for the training step."""
932+
if centered_bias is not None:
933+
centered_bias_step = _centered_bias_step(
934+
centered_bias, logits_dimension, labels, loss_fn)
935+
else:
936+
centered_bias_step = None
933937
with ops.name_scope(None, "train_op", (loss, labels)):
934938
train_op = train_op_fn(loss)
935-
if centered_bias is not None:
936-
centered_bias_step = _centered_bias_step(
937-
centered_bias, logits_dimension, labels, loss_fn)
939+
if centered_bias_step is not None:
938940
train_op = control_flow_ops.group(train_op, centered_bias_step)
939941
return train_op
940942

tensorflow/contrib/learn/python/learn/estimators/head_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def testRegressionWithCenteredBias(self):
9595
self._assert_metrics(model_fn_ops)
9696
_assert_variables(self, expected_global=(
9797
"centered_bias_weight:0",
98-
"train_op/centered_bias_step/centered_bias_weight/Adagrad:0",
98+
"centered_bias_weight/Adagrad:0",
9999
), expected_trainable=(
100100
"centered_bias_weight:0",
101101
))
@@ -166,7 +166,7 @@ def testMultiLabelWithCenteredBias(self):
166166
self._assert_metrics(model_fn_ops)
167167
_assert_variables(self, expected_global=(
168168
"centered_bias_weight:0",
169-
"train_op/centered_bias_step/centered_bias_weight/Adagrad:0",
169+
"centered_bias_weight/Adagrad:0",
170170
), expected_trainable=(
171171
"centered_bias_weight:0",
172172
))
@@ -251,7 +251,7 @@ def testBinaryClassificationWithCenteredBias(self):
251251
self._assert_binary_metrics(model_fn_ops)
252252
_assert_variables(self, expected_global=(
253253
"centered_bias_weight:0",
254-
"train_op/centered_bias_step/centered_bias_weight/Adagrad:0",
254+
"centered_bias_weight/Adagrad:0",
255255
), expected_trainable=(
256256
"centered_bias_weight:0",
257257
))
@@ -369,7 +369,7 @@ def testBinarySVMWithCenteredBias(self):
369369
self._assert_metrics(model_fn_ops)
370370
_assert_variables(self, expected_global=(
371371
"centered_bias_weight:0",
372-
"train_op/centered_bias_step/centered_bias_weight/Adagrad:0",
372+
"centered_bias_weight/Adagrad:0",
373373
), expected_trainable=(
374374
"centered_bias_weight:0",
375375
))

0 commit comments

Comments
 (0)