Merge pull request #7 from dplarson/tensorflow_v1.0.0
Revise for TensorFlow 1.0.x compatibility
mdeff authored Mar 11, 2017
2 parents 959acb5 + 82253b1 commit 44624ab
Showing 1 changed file with 22 additions and 22 deletions.
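In short, the commit replaces pre-1.0 TensorFlow calls with their 1.0 equivalents. The renames applied throughout the diff below are:

tf.train.SummaryWriter            ->  tf.summary.FileWriter
tf.initialize_all_variables       ->  tf.global_variables_initializer
tf.merge_all_summaries            ->  tf.summary.merge_all
tf.scalar_summary                 ->  tf.summary.scalar
tf.histogram_summary              ->  tf.summary.histogram
tf.argmax(..., dimension=1)       ->  tf.argmax(..., axis=1)
tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels)  ->  the same function, called with named arguments
tf.concat(0, values)              ->  tf.concat(values=values, axis=0)
tf.batch_matmul                   ->  tf.matmul (which now batches over leading dimensions)
tf.batch_ifft2d                   ->  tf.ifft2d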
lib/models.py (44 changes: 22 additions & 22 deletions)
@@ -84,7 +84,7 @@ def fit(self, train_data, train_labels, val_data, val_labels):
t_process, t_wall = time.process_time(), time.time()
sess = tf.Session(graph=self.graph)
shutil.rmtree(self._get_path('summaries'), ignore_errors=True)
- writer = tf.train.SummaryWriter(self._get_path('summaries'), self.graph)
+ writer = tf.summary.FileWriter(self._get_path('summaries'), self.graph)
shutil.rmtree(self._get_path('checkpoints'), ignore_errors=True)
os.makedirs(self._get_path('checkpoints'))
path = os.path.join(self._get_path('checkpoints'), 'model')
@@ -165,10 +165,10 @@ def build_graph(self, M_0):
self.op_prediction = self.prediction(op_logits)

# Initialize variables, i.e. weights and biases.
- self.op_init = tf.initialize_all_variables()
+ self.op_init = tf.global_variables_initializer()

# Summaries for TensorBoard and Save for model parameters.
- self.op_summary = tf.merge_all_summaries()
+ self.op_summary = tf.summary.merge_all()
self.op_saver = tf.train.Saver(max_to_keep=5)

self.graph.finalize()
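The merged summary op defined here is consumed by the tf.summary.FileWriter created in fit() (first hunk above). A minimal, self-contained sketch of that TF 1.0 summary workflow, with an illustrative log directory and toy tensor rather than the repository's own:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None], name='x')
    tf.summary.scalar('x_mean', tf.reduce_mean(x))   # summaries are collected by merge_all
    op_summary = tf.summary.merge_all()
    op_init = tf.global_variables_initializer()

writer = tf.summary.FileWriter('/tmp/summaries_demo', graph)  # illustrative path
with tf.Session(graph=graph) as sess:
    sess.run(op_init)
    for step in range(3):
        summary_str = sess.run(op_summary, feed_dict={x: [float(step), 2.0]})
        writer.add_summary(summary_str, global_step=step)  # tagged with the training step
writer.close()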
@@ -199,30 +199,30 @@ def probabilities(self, logits):
def prediction(self, logits):
"""Return the predicted classes."""
with tf.name_scope('prediction'):
- prediction = tf.argmax(logits, dimension=1)
+ prediction = tf.argmax(logits, axis=1)
return prediction

def loss(self, logits, labels, regularization):
"""Adds to the inference model the layers required to generate loss."""
with tf.name_scope('loss'):
with tf.name_scope('cross_entropy'):
labels = tf.to_int64(labels)
- cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels)
+ cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
cross_entropy = tf.reduce_mean(cross_entropy)
with tf.name_scope('regularization'):
regularization *= tf.add_n(self.regularizers)
loss = cross_entropy + regularization

# Summaries for TensorBoard.
- tf.scalar_summary('loss/cross_entropy', cross_entropy)
- tf.scalar_summary('loss/regularization', regularization)
- tf.scalar_summary('loss/total', loss)
+ tf.summary.scalar('loss/cross_entropy', cross_entropy)
+ tf.summary.scalar('loss/regularization', regularization)
+ tf.summary.scalar('loss/total', loss)
with tf.name_scope('averages'):
averages = tf.train.ExponentialMovingAverage(0.9)
op_averages = averages.apply([cross_entropy, regularization, loss])
- tf.scalar_summary('loss/avg/cross_entropy', averages.average(cross_entropy))
- tf.scalar_summary('loss/avg/regularization', averages.average(regularization))
- tf.scalar_summary('loss/avg/total', averages.average(loss))
+ tf.summary.scalar('loss/avg/cross_entropy', averages.average(cross_entropy))
+ tf.summary.scalar('loss/avg/regularization', averages.average(regularization))
+ tf.summary.scalar('loss/avg/total', averages.average(loss))
with tf.control_dependencies([op_averages]):
loss_average = tf.identity(averages.average(loss), name='control')
return loss, loss_average
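TensorFlow 1.0 requires sparse_softmax_cross_entropy_with_logits to be called with named arguments, which is why the positional call above was rewritten. A small, self-contained illustration (the tensors are placeholders for the model's real logits and labels):

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])   # N x C unnormalized scores
labels = tf.constant([0, 1], dtype=tf.int64)                # N integer class labels

# Named arguments are mandatory as of TF 1.0.
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
loss = tf.reduce_mean(cross_entropy)

with tf.Session() as sess:
    print(sess.run(loss))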
@@ -235,7 +235,7 @@ def training(self, loss, learning_rate, decay_steps, decay_rate=0.95, momentum=0
if decay_rate != 1:
learning_rate = tf.train.exponential_decay(
learning_rate, global_step, decay_steps, decay_rate, staircase=True)
- tf.scalar_summary('learning_rate', learning_rate)
+ tf.summary.scalar('learning_rate', learning_rate)
# Optimizer.
if momentum == 0:
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
@@ -249,7 +249,7 @@ def training(self, loss, learning_rate, decay_steps, decay_rate=0.95, momentum=0
if grad is None:
print('warning: {} has no gradient'.format(var.op.name))
else:
- tf.histogram_summary(var.op.name + '/gradients', grad)
+ tf.summary.histogram(var.op.name + '/gradients', grad)
# The op return the learning rate.
with tf.control_dependencies([op_gradients]):
op_train = tf.identity(learning_rate, name='control')
@@ -274,15 +274,15 @@ def _weight_variable(self, shape, regularization=True):
var = tf.get_variable('weights', shape, tf.float32, initializer=initial)
if regularization:
self.regularizers.append(tf.nn.l2_loss(var))
- tf.histogram_summary(var.op.name, var)
+ tf.summary.histogram(var.op.name, var)
return var

def _bias_variable(self, shape, regularization=True):
initial = tf.constant_initializer(0.1)
var = tf.get_variable('bias', shape, tf.float32, initializer=initial)
if regularization:
self.regularizers.append(tf.nn.l2_loss(var))
- tf.histogram_summary(var.op.name, var)
+ tf.summary.histogram(var.op.name, var)
return var

def _conv2d(self, x, W):
@@ -360,12 +360,12 @@ def _inference(self, x, dropout):
Wimg = self._weight_variable([int(NFEATURES/2), self.F, 1])
W = tf.complex(Wreal, Wimg)
xf = xf[:int(NFEATURES/2), :, :]
- yf = tf.batch_matmul(W, xf) # for each feature
- yf = tf.concat(0, [yf, tf.conj(yf)])
+ yf = tf.matmul(W, xf) # for each feature
+ yf = tf.concat(values=[yf, tf.conj(yf)], axis=0)
yf = tf.transpose(yf) # NSAMPLES x NFILTERS x NFEATURES
yf_2d = tf.reshape(yf, [-1, 28, 28])
# Transform back to spatial domain
- y_2d = tf.batch_ifft2d(yf_2d)
+ y_2d = tf.ifft2d(yf_2d)
y_2d = tf.real(y_2d)
y = tf.reshape(y_2d, [-1, self.F, NFEATURES])
# Bias and non-linearity
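tf.batch_matmul was removed in TensorFlow 1.0; tf.matmul now performs batched matrix multiplication when its inputs have rank greater than two, treating the leading dimensions as batch dimensions. A toy sketch with made-up shapes:

import tensorflow as tf

a = tf.random_normal([5, 3, 4])   # a batch of five 3x4 matrices
b = tf.random_normal([5, 4, 2])   # a batch of five 4x2 matrices
c = tf.matmul(a, b)               # five independent products, shape [5, 3, 2]

with tf.Session() as sess:
    print(sess.run(c).shape)      # (5, 3, 2)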
@@ -401,7 +401,7 @@ def _inference(self, x, dropout):
xf = tf.transpose(xf) # NFEATURES x 1 x NSAMPLES
# Filter
W = self._weight_variable([NFEATURES, self.F, 1])
- yf = tf.batch_matmul(W, xf) # for each feature
+ yf = tf.matmul(W, xf) # for each feature
yf = tf.transpose(yf) # NSAMPLES x NFILTERS x NFEATURES
yf = tf.reshape(yf, [-1, NFEATURES])
# Transform back to graph domain
@@ -632,7 +632,7 @@ def _inference(self, x, dropout):
xt = tf.expand_dims(xt0, 0) # 1 x M x N
def concat(xt, x):
x = tf.expand_dims(x, 0) # 1 x M x N
- return tf.concat(0, [xt, x]) # K x M x N
+ return tf.concat(values=[xt, x], axis=0) # K x M x N
if self.K > 1:
xt1 = tf.sparse_tensor_dense_matmul(self.L, xt0)
xt = concat(xt, xt1)
@@ -813,7 +813,7 @@ def filter_in_fourier(self, x, L, Fout, K, U, W):
x = tf.matmul(U, x) # M x Fin*N
x = tf.reshape(x, [M, Fin, N]) # M x Fin x N
# Filter
- x = tf.batch_matmul(W, x) # for each feature
+ x = tf.matmul(W, x) # for each feature
x = tf.transpose(x) # N x Fout x M
x = tf.reshape(x, [N*Fout, M]) # N*Fout x M
# Transform back to graph domain
@@ -893,7 +893,7 @@ def chebyshev5(self, x, L, Fout, K):
x = tf.expand_dims(x0, 0) # 1 x M x Fin*N
def concat(x, x_):
x_ = tf.expand_dims(x_, 0) # 1 x M x Fin*N
- return tf.concat(0, [x, x_]) # K x M x Fin*N
+ return tf.concat(values=[x, x_], axis=0) # K x M x Fin*N
if K > 1:
x1 = tf.sparse_tensor_dense_matmul(L, x0)
x = concat(x, x1)
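For reference, the concat helper in this last hunk stacks successive Chebyshev terms along a new leading axis. A rough, standalone sketch of how chebyshev5 builds that K x M x Fin*N stack under the TF 1.0 concat signature, assuming a rescaled sparse Laplacian L and the standard recurrence T_k(L)x = 2 L T_{k-1}(L)x - T_{k-2}(L)x (details such as the Laplacian rescaling are omitted here):

import tensorflow as tf

def chebyshev_stack(L, x0, K):
    # L  : M x M rescaled graph Laplacian as a tf.SparseTensor
    # x0 : M x (Fin*N) dense signal, i.e. T_0(L) x
    # Returns a K x M x (Fin*N) tensor of stacked Chebyshev terms.
    x = tf.expand_dims(x0, 0)                         # 1 x M x Fin*N
    def concat(x, x_):
        x_ = tf.expand_dims(x_, 0)                    # 1 x M x Fin*N
        return tf.concat(values=[x, x_], axis=0)      # grow along the leading axis
    if K > 1:
        x1 = tf.sparse_tensor_dense_matmul(L, x0)     # T_1(L) x = L x
        x = concat(x, x1)
    for _ in range(2, K):
        x2 = 2 * tf.sparse_tensor_dense_matmul(L, x1) - x0   # Chebyshev recurrence
        x = concat(x, x2)
        x0, x1 = x1, x2
    return x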
