Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Reduce usage of tf.contrib.layers #1350

Merged
merged 6 commits into from
Jan 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions tensor2tensor/layers/common_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3561,21 +3561,21 @@ def cyclegan_upsample(net, num_outputs, stride, method="conv2d_transpose"):
net = tf.image.resize_nearest_neighbor(
net, [stride[0] * height, stride[1] * width])
net = tf.pad(net, spatial_pad_1, "REFLECT")
net = tf.contrib.layers.conv2d(
net, num_outputs, kernel_size=[3, 3], padding="valid")
net = tf.layers.conv2d(
net, num_outputs, (3, 3), activation=tf.nn.relu)
elif method == "bilinear_upsample_conv":
net = tf.image.resize_bilinear(net,
[stride[0] * height, stride[1] * width])
net = tf.pad(net, spatial_pad_1, "REFLECT")
net = tf.contrib.layers.conv2d(
net, num_outputs, kernel_size=[3, 3], padding="valid")
net = tf.layers.conv2d(
net, num_outputs, (3, 3), activation=tf.nn.relu)
elif method == "conv2d_transpose":
# This corrects 1 pixel offset for images with even width and height.
# conv2d is left aligned and conv2d_transpose is right aligned for even
# sized images (while doing "SAME" padding).
# Note: This doesn't reflect the actual model in the paper.
net = tf.contrib.layers.conv2d_transpose(
net, num_outputs, kernel_size=[3, 3], stride=stride, padding="valid")
net = tf.layers.conv2d_transpose(
net, num_outputs, (3, 3), strides=stride, activation=tf.nn.relu)
Copy link
Contributor Author

@lgeiger lgeiger Jan 8, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change makes the CycleGANUpsampleConv2dTranspose test fail with AttributeError: 'tuple' object has no attribute 'ndims'.

In the unit test, cyclegan_upsample gets called with a NumPy input; TensorFlow probably tries to interpret it as a list, which makes the test fail. Using a tensor as input would fix the unit test.

# this fails
upsampled_output = common_layers.cyclegan_upsample(random_input,
    output_filters, stride, "conv2d_transpose")

# this works
upsampled_output = common_layers.cyclegan_upsample(tf.convert_to_tensor(random_input),
    output_filters, stride, "conv2d_transpose")

@afrozenator Is this a bug that should be reported to the TensorFlow team, or is this expected behavior with tf.layers?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's just fix the test?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 fixed it

I just wanted to bring this to your attention since the failure doesn't happen in the other upsampling tests.

net = net[:, 1:, 1:, :]
else:
raise ValueError("Unknown method: [%s]" % method)
Expand Down
4 changes: 2 additions & 2 deletions tensor2tensor/layers/common_layers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,8 +685,8 @@ def testCycleGANUpsampleConv2dTranspose(self):
num_channels = 3
output_filters = 10
stride = [2, 3] # we want height to be x2 and width to be x3
random_input = np.random.rand(batch, height, width, num_channels).astype(
np.float32)
random_input = tf.convert_to_tensor(
np.random.rand(batch, height, width, num_channels), dtype=tf.float32)

# conv2d_transpose is a little tricky.
# height_new = (height_old - 1) * stride + kernel - 2*padding - correction
Expand Down
50 changes: 22 additions & 28 deletions tensor2tensor/models/research/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,8 @@ def feed_forward_gaussian_fun(action_space, config, observations):
if not isinstance(action_space, gym.spaces.box.Box):
raise ValueError("Expecting continuous action space.")

mean_weights_initializer = tf.contrib.layers.variance_scaling_initializer(
factor=config.init_mean_factor)
mean_weights_initializer = tf.initializers.variance_scaling(
scale=config.init_mean_factor)
logstd_initializer = tf.random_normal_initializer(config.init_logstd, 1e-10)

flat_observations = tf.reshape(observations, [
Expand All @@ -410,10 +410,10 @@ def feed_forward_gaussian_fun(action_space, config, observations):
with tf.variable_scope("policy"):
x = flat_observations
for size in config.policy_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
mean = tf.contrib.layers.fully_connected(
x, action_space.shape[0], tf.tanh,
weights_initializer=mean_weights_initializer)
x = tf.layers.dense(x, size, activation=tf.nn.relu)
mean = tf.layers.dense(
x, action_space.shape[0], activation=tf.tanh,
kernel_initializer=mean_weights_initializer)
logstd = tf.get_variable(
"logstd", mean.shape[2:], tf.float32, logstd_initializer)
logstd = tf.tile(
Expand All @@ -422,8 +422,8 @@ def feed_forward_gaussian_fun(action_space, config, observations):
with tf.variable_scope("value"):
x = flat_observations
for size in config.value_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
x = tf.layers.dense(x, size, activation=tf.nn.relu)
value = tf.layers.dense(x, 1)[..., 0]
mean = tf.check_numerics(mean, "mean")
logstd = tf.check_numerics(logstd, "logstd")
value = tf.check_numerics(value, "value")
Expand Down Expand Up @@ -452,16 +452,14 @@ def body(self, features):
with tf.variable_scope("policy"):
x = flat_observations
for size in self.hparams.policy_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
logits = tf.contrib.layers.fully_connected(
x, self.hparams.problem.num_actions, activation_fn=None
)
x = tf.layers.dense(x, size, activation=tf.nn.relu)
logits = tf.layers.dense(x, self.hparams.problem.num_actions)
logits = tf.expand_dims(logits, axis=1)
with tf.variable_scope("value"):
x = flat_observations
for size in self.hparams.value_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
value = tf.contrib.layers.fully_connected(x, 1, None)
x = tf.layers.dense(x, size, activation=tf.nn.relu)
value = tf.layers.dense(x, 1)
logits = clip_logits(logits, self.hparams)
return {"target_policy": logits, "target_value": value}

Expand All @@ -478,23 +476,22 @@ def body(self, features):
dropout = getattr(self.hparams, "dropout_ppo", 0.0)
with tf.variable_scope("feed_forward_cnn_small"):
x = tf.cast(x, tf.float32) / 255.0
x = tf.contrib.layers.conv2d(x, 32, [5, 5], [2, 2],
activation_fn=tf.nn.relu, padding="SAME")
x = tf.contrib.layers.conv2d(x, 32, [5, 5], [2, 2],
activation_fn=tf.nn.relu, padding="SAME")
x = tf.layers.conv2d(x, 32, (5, 5), strides=(2, 2),
activation=tf.nn.relu, padding="same")
x = tf.layers.conv2d(x, 32, (5, 5), strides=(2, 2),
activation=tf.nn.relu, padding="same")

flat_x = tf.layers.flatten(x)
flat_x = tf.layers.dropout(flat_x, rate=dropout)
x = tf.contrib.layers.fully_connected(flat_x, 128, tf.nn.relu)
x = tf.layers.dense(flat_x, 128, activation=tf.nn.relu)

logits = tf.layers.dense(
x, self.hparams.problem.num_actions, name="dense2"
)
logits = clip_logits(logits, self.hparams)
logits = tf.expand_dims(logits, axis=1)

value = tf.contrib.layers.fully_connected(
x, 1, activation_fn=None)
value = tf.layers.dense(x, 1)
return {"target_policy": logits, "target_value": value}


Expand Down Expand Up @@ -547,15 +544,12 @@ def body(self, features):
with tf.variable_scope("dense_bitwise"):
flat_x = discretization.int_to_bit_embed(flat_x, 8, 32)

x = tf.contrib.layers.fully_connected(flat_x, 256, tf.nn.relu)
x = tf.contrib.layers.fully_connected(flat_x, 128, tf.nn.relu)
x = tf.layers.dense(flat_x, 256, activation=tf.nn.relu)
x = tf.layers.dense(flat_x, 128, activation=tf.nn.relu)

logits = tf.contrib.layers.fully_connected(
x, self.hparams.problem.num_actions, activation_fn=None
)
logits = tf.layers.dense(x, self.hparams.problem.num_actions)

value = tf.contrib.layers.fully_connected(
x, 1, activation_fn=None)[..., 0]
value = tf.layers.dense(x, 1)[..., 0]

return {"target_policy": logits, "target_value": value}

Expand Down
2 changes: 1 addition & 1 deletion tensor2tensor/models/research/transformer_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ def __init__(self, *args, **kwargs):
self._hparams.num_residuals, self._hparams.num_blocks,
self._hparams.hidden_size, block_dim
],
initializer=tf.contrib.layers.xavier_initializer(),
initializer=tf.initializers.glorot_uniform(),
trainable=self._hparams.trainable_projections)

self._hparams.bottleneck = functools.partial(
Expand Down
4 changes: 2 additions & 2 deletions tensor2tensor/models/revnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def wrapped_partial(fn, *args, **kwargs):
return wrapped


conv_initializer = tf.contrib.layers.variance_scaling_initializer(
factor=2.0, mode='FAN_OUT')
conv_initializer = tf.initializers.variance_scaling(
scale=2.0, mode='fan_out')

CONFIG = {'2d': {'conv': wrapped_partial(
tf.layers.conv2d, kernel_initializer=conv_initializer),
Expand Down
4 changes: 0 additions & 4 deletions tensor2tensor/models/video/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@
import tensorflow as tf


tfl = tf.layers
tfcl = tf.contrib.layers


def flat_lists(list_of_lists):
  """Flattens one level of nesting into a single new list.

  Args:
    list_of_lists: iterable whose elements are themselves iterable.

  Returns:
    A list holding the elements of every inner iterable, in their original
    order. Only the outermost level of nesting is removed.
  """
  # Descriptive names replace the single-letter `l`, which is easily confused
  # with the digit 1 and is discouraged by PEP 8.
  return [item for sublist in list_of_lists for item in sublist]

Expand Down
4 changes: 0 additions & 4 deletions tensor2tensor/models/video/basic_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@
import tensorflow as tf


tfl = tf.layers
tfcl = tf.contrib.layers


@registry.register_model
class NextFrameBasicDeterministic(base.NextFrameBase):
"""Basic next-frame model, may take actions and predict rewards too."""
Expand Down
6 changes: 0 additions & 6 deletions tensor2tensor/models/video/basic_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,6 @@
from tensor2tensor.models.video import basic_stochastic
from tensor2tensor.utils import registry

import tensorflow as tf


tfl = tf.layers
tfcl = tf.contrib.layers


@registry.register_model
class NextFrameBasicRecurrent(
Expand Down
2 changes: 1 addition & 1 deletion tensor2tensor/models/video/emily.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def construct_model(self, images, actions, rewards):
for i, image in enumerate(images):
with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
enc, skips = self.encoder(image, rnn_size, has_batchnorm=has_batchnorm)
enc = tfcl.flatten(enc)
enc = tfl.flatten(enc)
enc_images.append(enc)
enc_skips.append(skips)

Expand Down
2 changes: 1 addition & 1 deletion tensor2tensor/models/video/sv2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def construct_predictive_tower(

if self.hparams.model_options == "CDNA":
# cdna_input = tf.reshape(hidden5, [int(batch_size), -1])
cdna_input = tfcl.flatten(hidden5)
cdna_input = tfl.flatten(hidden5)
transformed += common_video.cdna_transformation(
input_image, cdna_input, num_masks, int(color_channels),
self.hparams.dna_kernel_size, self.hparams.relu_shift)
Expand Down
4 changes: 2 additions & 2 deletions tensor2tensor/models/video/svg_lp.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def construct_model(self, images, actions, rewards):
for i, image in enumerate(images):
with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
enc, skips = self.encoder(image, g_dim, has_batchnorm=has_batchnorm)
enc = tfcl.flatten(enc)
enc = tfl.flatten(enc)
enc_images.append(enc)
enc_skips.append(skips)

Expand All @@ -199,7 +199,7 @@ def construct_model(self, images, actions, rewards):
h_current = enc_images[i-1]
else:
h_current, _ = self.encoder(gen_images[-1], g_dim)
h_current = tfcl.flatten(h_current)
h_current = tfl.flatten(h_current)

# target encoding
h_target = enc_images[i]
Expand Down
2 changes: 1 addition & 1 deletion tensor2tensor/utils/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,6 @@ def get_variable_initializer(hparams):
return tf.variance_scaling_initializer(
hparams.initializer_gain, mode="fan_avg", distribution="uniform")
elif hparams.initializer == "xavier":
return tf.contrib.layers.xavier_initializer()
return tf.initializers.glorot_uniform()
else:
raise ValueError("Unrecognized initializer: %s" % hparams.initializer)