@@ -872,18 +872,17 @@ def _centered_bias_step(centered_bias, logits_dimension, labels, loss_fn):
872
872
"""Creates and returns training op for centered bias."""
873
873
if (logits_dimension is None ) or (logits_dimension < 1 ):
874
874
raise ValueError ("Invalid logits_dimension %s." % logits_dimension )
875
- with ops .name_scope (None , "centered_bias_step" , (labels ,)) as name :
876
- batch_size = array_ops .shape (labels )[0 ]
877
- logits = array_ops .reshape (
878
- array_ops .tile (centered_bias , (batch_size ,)),
879
- (batch_size , logits_dimension ))
880
- with ops .name_scope (None , "centered_bias" , (labels , logits )):
881
- centered_bias_loss = math_ops .reduce_mean (
882
- loss_fn (logits , labels ), name = "training_loss" )
883
- # Learn central bias by an optimizer. 0.1 is a convervative lr for a
884
- # single variable.
885
- return training .AdagradOptimizer (0.1 ).minimize (
886
- centered_bias_loss , var_list = (centered_bias ,), name = name )
875
+ batch_size = array_ops .shape (labels )[0 ]
876
+ logits = array_ops .reshape (
877
+ array_ops .tile (centered_bias , (batch_size ,)),
878
+ (batch_size , logits_dimension ))
879
+ with ops .name_scope (None , "centered_bias" , (labels , logits )):
880
+ centered_bias_loss = math_ops .reduce_mean (
881
+ loss_fn (logits , labels ), name = "training_loss" )
882
+ # Learn central bias by an optimizer. 0.1 is a convervative lr for a
883
+ # single variable.
884
+ return training .AdagradOptimizer (0.1 ).minimize (
885
+ centered_bias_loss , var_list = (centered_bias ,), name = "centered_bias_step" )
887
886
888
887
889
888
def _head_prefixed (head_name , val ):
@@ -930,11 +929,14 @@ def _train_op(
930
929
loss , labels , train_op_fn , centered_bias = None , logits_dimension = None ,
931
930
loss_fn = None ):
932
931
"""Returns op for the training step."""
932
+ if centered_bias is not None :
933
+ centered_bias_step = _centered_bias_step (
934
+ centered_bias , logits_dimension , labels , loss_fn )
935
+ else :
936
+ centered_bias_step = None
933
937
with ops .name_scope (None , "train_op" , (loss , labels )):
934
938
train_op = train_op_fn (loss )
935
- if centered_bias is not None :
936
- centered_bias_step = _centered_bias_step (
937
- centered_bias , logits_dimension , labels , loss_fn )
939
+ if centered_bias_step is not None :
938
940
train_op = control_flow_ops .group (train_op , centered_bias_step )
939
941
return train_op
940
942
0 commit comments