Skip to content

Commit 585df7d

Browse files
Revert "Reduce usage of tf.contrib.layers (tensorflow#1350)"
This reverts commit 57a9720.
1 parent 1db6bf5 commit 585df7d

File tree

12 files changed

58 additions and 38 deletions (+58 / -38 lines changed)

12 files changed

58 additions and 38 deletions (+58 / -38 lines changed)

tensor2tensor/layers/common_layers.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3577,21 +3577,21 @@ def cyclegan_upsample(net, num_outputs, stride, method="conv2d_transpose"):
35773577
net = tf.image.resize_nearest_neighbor(
35783578
net, [stride[0] * height, stride[1] * width])
35793579
net = tf.pad(net, spatial_pad_1, "REFLECT")
3580-
net = tf.layers.conv2d(
3581-
net, num_outputs, (3, 3), activation=tf.nn.relu)
3580+
net = tf.contrib.layers.conv2d(
3581+
net, num_outputs, kernel_size=[3, 3], padding="valid")
35823582
elif method == "bilinear_upsample_conv":
35833583
net = tf.image.resize_bilinear(net,
35843584
[stride[0] * height, stride[1] * width])
35853585
net = tf.pad(net, spatial_pad_1, "REFLECT")
3586-
net = tf.layers.conv2d(
3587-
net, num_outputs, (3, 3), activation=tf.nn.relu)
3586+
net = tf.contrib.layers.conv2d(
3587+
net, num_outputs, kernel_size=[3, 3], padding="valid")
35883588
elif method == "conv2d_transpose":
35893589
# This corrects 1 pixel offset for images with even width and height.
35903590
# conv2d is left aligned and conv2d_transpose is right aligned for even
35913591
# sized images (while doing "SAME" padding).
35923592
# Note: This doesn't reflect actual model in paper.
3593-
net = tf.layers.conv2d_transpose(
3594-
net, num_outputs, (3, 3), strides=stride, activation=tf.nn.relu)
3593+
net = tf.contrib.layers.conv2d_transpose(
3594+
net, num_outputs, kernel_size=[3, 3], stride=stride, padding="valid")
35953595
net = net[:, 1:, 1:, :]
35963596
else:
35973597
raise ValueError("Unknown method: [%s]" % method)

tensor2tensor/layers/common_layers_test.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -685,8 +685,8 @@ def testCycleGANUpsampleConv2dTranspose(self):
685685
num_channels = 3
686686
output_filters = 10
687687
stride = [2, 3] # we want height to be x2 and width to be x3
688-
random_input = tf.convert_to_tensor(
689-
np.random.rand(batch, height, width, num_channels), dtype=tf.float32)
688+
random_input = np.random.rand(batch, height, width, num_channels).astype(
689+
np.float32)
690690

691691
# conv2d_transpose is a little tricky.
692692
# height_new = (height_old - 1) * stride + kernel - 2*padding - correction

tensor2tensor/models/research/rl.py

Lines changed: 28 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -451,8 +451,8 @@ def feed_forward_gaussian_fun(action_space, config, observations):
451451
if not isinstance(action_space, gym.spaces.box.Box):
452452
raise ValueError("Expecting continuous action space.")
453453

454-
mean_weights_initializer = tf.initializers.variance_scaling(
455-
scale=config.init_mean_factor)
454+
mean_weights_initializer = tf.contrib.layers.variance_scaling_initializer(
455+
factor=config.init_mean_factor)
456456
logstd_initializer = tf.random_normal_initializer(config.init_logstd, 1e-10)
457457

458458
flat_observations = tf.reshape(observations, [
@@ -463,10 +463,10 @@ def feed_forward_gaussian_fun(action_space, config, observations):
463463
with tf.variable_scope("policy"):
464464
x = flat_observations
465465
for size in config.policy_layers:
466-
x = tf.layers.dense(x, size, activation=tf.nn.relu)
467-
mean = tf.layers.dense(
468-
x, action_space.shape[0], activation=tf.tanh,
469-
kernel_initializer=mean_weights_initializer)
466+
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
467+
mean = tf.contrib.layers.fully_connected(
468+
x, action_space.shape[0], tf.tanh,
469+
weights_initializer=mean_weights_initializer)
470470
logstd = tf.get_variable(
471471
"logstd", mean.shape[2:], tf.float32, logstd_initializer)
472472
logstd = tf.tile(
@@ -475,8 +475,8 @@ def feed_forward_gaussian_fun(action_space, config, observations):
475475
with tf.variable_scope("value"):
476476
x = flat_observations
477477
for size in config.value_layers:
478-
x = tf.layers.dense(x, size, activation=tf.nn.relu)
479-
value = tf.layers.dense(x, 1)[..., 0]
478+
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
479+
value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
480480
mean = tf.check_numerics(mean, "mean")
481481
logstd = tf.check_numerics(logstd, "logstd")
482482
value = tf.check_numerics(value, "value")
@@ -505,14 +505,16 @@ def body(self, features):
505505
with tf.variable_scope("policy"):
506506
x = flat_observations
507507
for size in self.hparams.policy_layers:
508-
x = tf.layers.dense(x, size, activation=tf.nn.relu)
509-
logits = tf.layers.dense(x, self.hparams.problem.num_actions)
508+
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
509+
logits = tf.contrib.layers.fully_connected(
510+
x, self.hparams.problem.num_actions, activation_fn=None
511+
)
510512
logits = tf.expand_dims(logits, axis=1)
511513
with tf.variable_scope("value"):
512514
x = flat_observations
513515
for size in self.hparams.value_layers:
514-
x = tf.layers.dense(x, size, activation=tf.nn.relu)
515-
value = tf.layers.dense(x, 1)
516+
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
517+
value = tf.contrib.layers.fully_connected(x, 1, None)
516518
logits = clip_logits(logits, self.hparams)
517519
return {"target_policy": logits, "target_value": value}
518520

@@ -529,22 +531,23 @@ def body(self, features):
529531
dropout = getattr(self.hparams, "dropout_ppo", 0.0)
530532
with tf.variable_scope("feed_forward_cnn_small"):
531533
x = tf.cast(x, tf.float32) / 255.0
532-
x = tf.layers.conv2d(x, 32, (5, 5), strides=(2, 2),
533-
activation=tf.nn.relu, padding="same")
534-
x = tf.layers.conv2d(x, 32, (5, 5), strides=(2, 2),
535-
activation=tf.nn.relu, padding="same")
534+
x = tf.contrib.layers.conv2d(x, 32, [5, 5], [2, 2],
535+
activation_fn=tf.nn.relu, padding="SAME")
536+
x = tf.contrib.layers.conv2d(x, 32, [5, 5], [2, 2],
537+
activation_fn=tf.nn.relu, padding="SAME")
536538

537539
flat_x = tf.layers.flatten(x)
538540
flat_x = tf.layers.dropout(flat_x, rate=dropout)
539-
x = tf.layers.dense(flat_x, 128, activation=tf.nn.relu)
541+
x = tf.contrib.layers.fully_connected(flat_x, 128, tf.nn.relu)
540542

541543
logits = tf.layers.dense(
542544
x, self.hparams.problem.num_actions, name="dense2"
543545
)
544546
logits = clip_logits(logits, self.hparams)
545547
logits = tf.expand_dims(logits, axis=1)
546548

547-
value = tf.layers.dense(x, 1)
549+
value = tf.contrib.layers.fully_connected(
550+
x, 1, activation_fn=None)
548551
return {"target_policy": logits, "target_value": value}
549552

550553

@@ -597,12 +600,15 @@ def body(self, features):
597600
with tf.variable_scope("dense_bitwise"):
598601
flat_x = discretization.int_to_bit_embed(flat_x, 8, 32)
599602

600-
x = tf.layers.dense(flat_x, 256, activation=tf.nn.relu)
601-
x = tf.layers.dense(flat_x, 128, activation=tf.nn.relu)
603+
x = tf.contrib.layers.fully_connected(flat_x, 256, tf.nn.relu)
604+
x = tf.contrib.layers.fully_connected(flat_x, 128, tf.nn.relu)
602605

603-
logits = tf.layers.dense(x, self.hparams.problem.num_actions)
606+
logits = tf.contrib.layers.fully_connected(
607+
x, self.hparams.problem.num_actions, activation_fn=None
608+
)
604609

605-
value = tf.layers.dense(x, 1)[..., 0]
610+
value = tf.contrib.layers.fully_connected(
611+
x, 1, activation_fn=None)[..., 0]
606612

607613
return {"target_policy": logits, "target_value": value}
608614

tensor2tensor/models/research/transformer_vae.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -587,7 +587,7 @@ def __init__(self, *args, **kwargs):
587587
self._hparams.num_residuals, self._hparams.num_blocks,
588588
self._hparams.hidden_size, block_dim
589589
],
590-
initializer=tf.initializers.glorot_uniform(),
590+
initializer=tf.contrib.layers.xavier_initializer(),
591591
trainable=self._hparams.trainable_projections)
592592

593593
self._hparams.bottleneck = functools.partial(

tensor2tensor/models/revnet.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -49,8 +49,8 @@ def wrapped_partial(fn, *args, **kwargs):
4949
return wrapped
5050

5151

52-
conv_initializer = tf.initializers.variance_scaling(
53-
scale=2.0, mode='fan_out')
52+
conv_initializer = tf.contrib.layers.variance_scaling_initializer(
53+
factor=2.0, mode='FAN_OUT')
5454

5555
CONFIG = {'2d': {'conv': wrapped_partial(
5656
tf.layers.conv2d, kernel_initializer=conv_initializer),

tensor2tensor/models/video/base.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,10 @@
3232
import tensorflow as tf
3333

3434

35+
tfl = tf.layers
36+
tfcl = tf.contrib.layers
37+
38+
3539
def flat_lists(list_of_lists):
3640
return [x for l in list_of_lists for x in l]
3741

tensor2tensor/models/video/basic_deterministic.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,10 @@
2929
import tensorflow as tf
3030

3131

32+
tfl = tf.layers
33+
tfcl = tf.contrib.layers
34+
35+
3236
@registry.register_model
3337
class NextFrameBasicDeterministic(base.NextFrameBase):
3438
"""Basic next-frame model, may take actions and predict rewards too."""

tensor2tensor/models/video/basic_recurrent.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,12 @@
2323
from tensor2tensor.models.video import basic_stochastic
2424
from tensor2tensor.utils import registry
2525

26+
import tensorflow as tf
27+
28+
29+
tfl = tf.layers
30+
tfcl = tf.contrib.layers
31+
2632

2733
@registry.register_model
2834
class NextFrameBasicRecurrent(

tensor2tensor/models/video/emily.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -276,7 +276,7 @@ def construct_model(self, images, actions, rewards):
276276
for i, image in enumerate(images):
277277
with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
278278
enc, skips = self.encoder(image, rnn_size, has_batchnorm=has_batchnorm)
279-
enc = tfl.flatten(enc)
279+
enc = tfcl.flatten(enc)
280280
enc_images.append(enc)
281281
enc_skips.append(skips)
282282

tensor2tensor/models/video/sv2p.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def construct_predictive_tower(
314314

315315
if self.hparams.model_options == "CDNA":
316316
# cdna_input = tf.reshape(hidden5, [int(batch_size), -1])
317-
cdna_input = tfl.flatten(hidden5)
317+
cdna_input = tfcl.flatten(hidden5)
318318
transformed += common_video.cdna_transformation(
319319
input_image, cdna_input, num_masks, int(color_channels),
320320
self.hparams.dna_kernel_size, self.hparams.relu_shift)

tensor2tensor/models/video/svg_lp.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -180,7 +180,7 @@ def construct_model(self, images, actions, rewards):
180180
for i, image in enumerate(images):
181181
with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
182182
enc, skips = self.encoder(image, g_dim, has_batchnorm=has_batchnorm)
183-
enc = tfl.flatten(enc)
183+
enc = tfcl.flatten(enc)
184184
enc_images.append(enc)
185185
enc_skips.append(skips)
186186

@@ -199,7 +199,7 @@ def construct_model(self, images, actions, rewards):
199199
h_current = enc_images[i-1]
200200
else:
201201
h_current, _ = self.encoder(gen_images[-1], g_dim)
202-
h_current = tfl.flatten(h_current)
202+
h_current = tfcl.flatten(h_current)
203203

204204
# target encoding
205205
h_target = enc_images[i]

tensor2tensor/utils/optimize.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -328,6 +328,6 @@ def get_variable_initializer(hparams):
328328
return tf.variance_scaling_initializer(
329329
hparams.initializer_gain, mode="fan_avg", distribution="uniform")
330330
elif hparams.initializer == "xavier":
331-
return tf.initializers.glorot_uniform()
331+
return tf.contrib.layers.xavier_initializer()
332332
else:
333333
raise ValueError("Unrecognized initializer: %s" % hparams.initializer)

0 commit comments

Comments
 (0)