Skip to content

Commit

Permalink
Remove unused EMA code
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Aug 30, 2018
1 parent bd1cf7d commit ae7e86b
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 19 deletions.
3 changes: 1 addition & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from datasets import rocstories
from analysis import rocstories as rocstories_analysis
from text_utils import TextEncoder
from utils import encode_dataset, flatten, iter_data, find_trainable_variables, get_ema_vars, convert_gradient_to_tensor, shape_list, ResultLogger, assign_to_gpu, average_grads, make_path
from utils import encode_dataset, flatten, iter_data, find_trainable_variables, convert_gradient_to_tensor, shape_list, ResultLogger, assign_to_gpu, average_grads, make_path

def gelu(x):
    """Gaussian Error Linear Unit activation, tanh approximation.

    Computes 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    """
    coeff = math.sqrt(2 / math.pi)
    inner = x + 0.044715 * tf.pow(x, 3)
    return 0.5 * x * (1 + tf.tanh(coeff * inner))
Expand Down Expand Up @@ -54,7 +54,6 @@ def norm(x, scope, axis=[-1]):
n_state = shape_list(x)[-1]
g = tf.get_variable("g", [n_state], initializer=tf.constant_initializer(1))
b = tf.get_variable("b", [n_state], initializer=tf.constant_initializer(0))
g, b = get_ema_vars(g, b)
return _norm(x, g, b, axis=axis)

def dropout(x, pdrop, train):
Expand Down
17 changes: 0 additions & 17 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,23 +114,6 @@ def iter_data(*datas, n_batch=128, truncate=False, verbose=False, max_batches=fl
yield (d[i:i+n_batch] for d in datas)
n_batches += 1

def get_ema_if_exists(v, gvs):
    """Return the exponential-moving-average shadow of variable `v`, if any.

    Scans `gvs` (a collection of variables, e.g. tf.global_variables()) for a
    variable named '<v base name>/ExponentialMovingAverage:0'. Returns the
    first match, or `v` itself when no shadow variable exists.
    """
    # Drop the output-index suffix of the tensor name: 'w:0' -> 'w'.
    base_name = v.name.split(':')[0]
    ema_name = base_name + '/ExponentialMovingAverage:0'
    # Explicit loop instead of the original list comprehension, whose loop
    # variable shadowed the parameter `v` (a scoping bug under Python 2 and
    # a readability hazard otherwise); also returns on the first hit instead
    # of materializing a throwaway list.
    for candidate in gvs:
        if candidate.name == ema_name:
            return candidate
    return v

def get_ema_vars(*vs):
    """Swap each variable for its EMA shadow when the current scope reuses vars.

    Outside a reuse scope the inputs pass through untouched. Returns a single
    variable when called with one argument, otherwise the full sequence.
    """
    if tf.get_variable_scope().reuse:
        all_vars = tf.global_variables()
        vs = [get_ema_if_exists(v, all_vars) for v in vs]
    return vs[0] if len(vs) == 1 else vs

@function.Defun(
python_grad_func=lambda x, dy: tf.convert_to_tensor(dy),
shape_func=lambda op: [op.inputs[0].get_shape()])
Expand Down

0 comments on commit ae7e86b

Please sign in to comment.