This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Reduce usage of tf.contrib #1345

Merged · 9 commits · Jan 7, 2019

This PR migrates deprecated tf.contrib calls to their core TensorFlow 1.x equivalents: checkpoint utilities to tf.train, dataset transforms to tf.data.experimental, mode constants from tf.contrib.learn.ModeKeys to tf.estimator.ModeKeys, RNN cells from tf.contrib.rnn to tf.nn.rnn_cell, and eager-mode checks to tf.executing_eagerly().
tensor2tensor/bin/t2t_attack.py (1 addition, 1 deletion)

@@ -200,7 +200,7 @@ def main(argv):
   sur_ch_model.get_probs(inputs)

   checkpoint_path = os.path.expanduser(FLAGS.surrogate_output_dir)
-  tf.contrib.framework.init_from_checkpoint(
+  tf.train.init_from_checkpoint(
       tf.train.latest_checkpoint(checkpoint_path), {"/": "surrogate/"})
   sess.run(tf.global_variables_initializer())

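Note: tf.train.init_from_checkpoint is a drop-in rename for the tf.contrib.framework helper: same (checkpoint, assignment_map) signature, and the {"/": "surrogate/"} map still means "load every checkpoint variable into the surrogate/ scope". A minimal sketch of the pattern, assuming a hypothetical checkpoint directory:

    import tensorflow as tf

    ckpt = tf.train.latest_checkpoint("/tmp/surrogate_ckpt")  # hypothetical path
    # Overrides the initializers of matching variables in the current graph;
    # takes effect when global_variables_initializer runs.
    tf.train.init_from_checkpoint(ckpt, {"/": "surrogate/"})
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
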
tensor2tensor/bin/t2t_avg_all.py (3 additions, 3 deletions)

@@ -60,7 +60,7 @@ def main(_):
   for model in bleu_hook.stepfiles_iterator(model_dir, FLAGS.wait_minutes,
                                             FLAGS.min_steps):
     if models_processed == 0:
-      var_list = tf.contrib.framework.list_variables(model.filename)
+      var_list = tf.train.list_variables(model.filename)
       avg_values = {}
       for (name, shape) in var_list:
         if not (name.startswith("global_step") or
@@ -69,7 +69,7 @@ def main(_):
       models_processed += 1

     tf.logging.info("Loading [%d]: %s" % (models_processed, model.filename))
-    reader = tf.contrib.framework.load_checkpoint(model.filename)
+    reader = tf.train.load_checkpoint(model.filename)
     for name in avg_values:
       avg_values[name] += reader.get_tensor(name) / FLAGS.n
     queue.append(model)
@@ -106,7 +106,7 @@ def main(_):
     tf.reset_default_graph()
     first_model = queue.popleft()

-    reader = tf.contrib.framework.load_checkpoint(first_model.filename)
+    reader = tf.train.load_checkpoint(first_model.filename)
     for name in avg_values:
       avg_values[name] -= reader.get_tensor(name) / FLAGS.n

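tf.train.list_variables and tf.train.load_checkpoint are likewise one-to-one renames: the former yields (name, shape) pairs, the latter returns a CheckpointReader. A minimal sketch of the running-average loop this script implements, with hypothetical checkpoint names:

    import tensorflow as tf

    filenames = ["model.ckpt-1000", "model.ckpt-2000"]  # hypothetical
    n = len(filenames)
    avg_values = {name: 0.0
                  for name, _ in tf.train.list_variables(filenames[0])
                  if not name.startswith("global_step")}
    for fname in filenames:
      reader = tf.train.load_checkpoint(fname)  # CheckpointReader
      for name in avg_values:
        avg_values[name] += reader.get_tensor(name) / n
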
tensor2tensor/data_generators/image_utils.py (1 addition, 1 deletion)

@@ -262,7 +262,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):

 def encode_images_as_png(images):
   """Yield images encoded as pngs."""
-  if tf.contrib.eager.in_eager_mode():
+  if tf.executing_eagerly():
     for image in images:
       yield tf.image.encode_png(image).numpy()
   else:

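tf.executing_eagerly() is the core replacement for tf.contrib.eager.in_eager_mode(); the same guard recurs in common_layers.py and modalities.py below, choosing between eager-only calls (.numpy()) and graph-only ones (set_shape, summaries, gradient conversion). A minimal sketch, assuming TF 1.x where eager must be enabled explicitly:

    import tensorflow as tf

    tf.enable_eager_execution()  # opt-in under TF 1.x; the default in TF 2.x

    if tf.executing_eagerly():
      image = tf.zeros([4, 4, 3], dtype=tf.uint8)
      png_bytes = tf.image.encode_png(image).numpy()  # .numpy() needs eager
      print(len(png_bytes))
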
tensor2tensor/data_generators/problem.py (4 additions, 4 deletions)

@@ -423,7 +423,7 @@ def _preprocess(example):

     if interleave:
       dataset = dataset.apply(
-          tf.contrib.data.parallel_interleave(
+          tf.data.experimental.parallel_interleave(
               _preprocess, sloppy=True, cycle_length=8))
     else:
       dataset = dataset.flat_map(_preprocess)
@@ -674,7 +674,7 @@ def _load_records_and_preprocess(filenames):
     # Create data-set from files by parsing, pre-processing and interleaving.
     if shuffle_files:
       dataset = dataset.apply(
-          tf.contrib.data.parallel_interleave(
+          tf.data.experimental.parallel_interleave(
              _load_records_and_preprocess, sloppy=True, cycle_length=8))
     else:
       dataset = _load_records_and_preprocess(dataset)
@@ -963,7 +963,7 @@ def define_shapes(example):
       batching_scheme["batch_sizes"] = [hparams.batch_size]
       batching_scheme["boundaries"] = []
     dataset = dataset.apply(
-        tf.contrib.data.bucket_by_sequence_length(
+        tf.data.experimental.bucket_by_sequence_length(
            data_reader.example_length, batching_scheme["boundaries"],
            batching_scheme["batch_sizes"]))

@@ -1040,7 +1040,7 @@ def serving_input_fn(self, hparams):
         tf.shape(serialized_example, out_type=tf.int64)[0],
         dataset.output_shapes)
     dataset = dataset.map(standardize_shapes)
-    features = tf.contrib.data.get_single_element(dataset)
+    features = tf.data.experimental.get_single_element(dataset)

     if self.has_inputs:
       features.pop("targets", None)

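The tf.contrib.data transforms used here moved to tf.data.experimental with unchanged signatures, so each call site only swaps the prefix. A minimal sketch of the two main patterns, with a hypothetical file list and a toy dataset of variable-length sequences:

    import tensorflow as tf

    # 1. Sloppy parallel interleave over input files (output order not guaranteed).
    files = tf.data.Dataset.from_tensor_slices(["a.tfrecord", "b.tfrecord"])  # hypothetical
    records = files.apply(
        tf.data.experimental.parallel_interleave(
            tf.data.TFRecordDataset, sloppy=True, cycle_length=8))

    # 2. Bucket variable-length examples into padded batches of similar length.
    seqs = [[1, 2], [1, 2, 3, 4, 5], [7], [3, 1, 4]]
    dataset = tf.data.Dataset.from_generator(
        lambda: iter(seqs), output_types=tf.int32, output_shapes=[None])
    dataset = dataset.apply(
        tf.data.experimental.bucket_by_sequence_length(
            element_length_func=lambda seq: tf.shape(seq)[0],
            bucket_boundaries=[3, 5],       # 3 buckets: <3, 3-4, >=5
            bucket_batch_sizes=[2, 2, 2]))  # one batch size per bucket
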
tensor2tensor/layers/common_image_attention.py (2 additions, 2 deletions)

@@ -501,7 +501,7 @@ def postprocess_image(x, rows, cols, hparams):
       use_bias=True,
       activation=None,
       name="output_conv")
-  if (hparams.mode == tf.contrib.learn.ModeKeys.INFER and
+  if (hparams.mode == tf.estimator.ModeKeys.PREDICT and
       hparams.block_raster_scan):
     y = targets
     yshape = common_layers.shape_list(y)
@@ -547,7 +547,7 @@ def prepare_decoder(targets, hparams):

   # during training, images are [batch, IMG_LEN, IMG_LEN, 3].
   # At inference, they are [batch, curr_infer_length, 1, 1]
-  if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
+  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
     curr_infer_length = targets_shape[1]
     if hparams.block_raster_scan:
       assert hparams.img_len*channels % hparams.query_shape[1] == 0

tensor2tensor/layers/common_image_attention_test.py (1 addition, 1 deletion)

@@ -61,7 +61,7 @@ def testPostProcessImageInferMode(self, likelihood, num_mixtures, depth):
         block_raster_scan=True,
         hidden_size=2,
         likelihood=likelihood,
-        mode=tf.contrib.learn.ModeKeys.INFER,
+        mode=tf.estimator.ModeKeys.PREDICT,
         num_mixtures=num_mixtures,
         query_shape=[block_length, block_width],
     )

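The mode-key rename is tf.contrib.learn.ModeKeys.INFER -> tf.estimator.ModeKeys.PREDICT; TRAIN and EVAL keep their names. Both classes hold plain strings, and PREDICT retains the old "infer" value, so equality checks against modes already stored in hparams keep working across the swap. A minimal sketch:

    import tensorflow as tf

    Modes = tf.estimator.ModeKeys
    print(Modes.TRAIN, Modes.EVAL, Modes.PREDICT)  # -> train eval infer

    def is_predicting(mode):
      # Same comparison idiom as postprocess_image() above.
      return mode == Modes.PREDICT
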
tensor2tensor/layers/common_layers.py (4 additions, 4 deletions)

@@ -349,7 +349,7 @@ def embedding(x,
     # On the backwards pass, we want to convert the gradient from
     # an indexed-slices to a regular tensor before sending it back to the
     # parameter server. This avoids excess computation on the parameter server.
-    if not tf.contrib.eager.in_eager_mode():
+    if not tf.executing_eagerly():
       embedding_var = convert_gradient_to_tensor(embedding_var)
     x = dropout_no_scaling(x, 1.0 - symbol_dropout_rate)
     emb_x = gather(embedding_var, x, dtype)
@@ -2868,7 +2868,7 @@ def ones_matrix_band_part(rows, cols, num_lower, num_upper, out_shape=None):
 def reshape_like_all_dims(a, b):
   """Reshapes a to match the shape of b."""
   ret = tf.reshape(a, tf.shape(b))
-  if not tf.contrib.eager.in_eager_mode():
+  if not tf.executing_eagerly():
     ret.set_shape(b.get_shape())
   return ret

@@ -3193,7 +3193,7 @@ def should_generate_summaries():
 def reshape_like(a, b):
   """Reshapes a to match the shape of b in all but the last dimension."""
   ret = tf.reshape(a, tf.concat([tf.shape(b)[:-1], tf.shape(a)[-1:]], 0))
-  if not tf.contrib.eager.in_eager_mode():
+  if not tf.executing_eagerly():
     ret.set_shape(b.get_shape().as_list()[:-1] + a.get_shape().as_list()[-1:])
   return ret

@@ -3205,7 +3205,7 @@ def summarize_video(video, prefix, max_outputs=1):
     raise ValueError("Assuming videos given as tensors in the format "
                      "[batch, time, height, width, channels] but got one "
                      "of shape: %s" % str(video_shape))
-  if tf.contrib.eager.in_eager_mode():
+  if tf.executing_eagerly():
     return
   if video.get_shape().as_list()[1] is None:
     tf.summary.image(

tensor2tensor/layers/common_video.py (1 addition, 1 deletion)

@@ -82,7 +82,7 @@ def lstm_cell(inputs,
               name=None):
   """Full LSTM cell."""
   input_shape = common_layers.shape_list(inputs)
-  cell = tf.contrib.rnn.LSTMCell(num_units,
+  cell = tf.nn.rnn_cell.LSTMCell(num_units,
                                  use_peepholes=use_peepholes,
                                  cell_clip=cell_clip,
                                  initializer=initializer,

tensor2tensor/layers/discretization.py (2 additions, 2 deletions)

@@ -473,7 +473,7 @@ def gumbel_softmax(x,
     d_dev = -tf.reduce_mean(d_variance)
     ret = s

-    if mode != tf.contrib.learn.ModeKeys.TRAIN:
+    if mode != tf.estimator.ModeKeys.TRAIN:
       ret = tf.reshape(maxvhot, common_layers.shape_list(s))  # Just hot @eval.
     return m, ret, d_dev * 5.0 + tf.reduce_mean(kl) * 0.002

@@ -822,7 +822,7 @@ def predict_bits_with_lstm(prediction_source, state_size, total_num_bits,

   with tf.variable_scope("predict_bits_with_lstm"):
     # Layers and cell state creation.
-    lstm_cell = tf.contrib.rnn.LSTMCell(state_size)
+    lstm_cell = tf.nn.rnn_cell.LSTMCell(state_size)
     discrete_predict = tf.layers.Dense(2**bits_at_once, name="discrete_predict")
     discrete_embed = tf.layers.Dense(state_size, name="discrete_embed")
     batch_size = common_layers.shape_list(prediction_source)[0]

tensor2tensor/layers/modalities.py (3 additions, 3 deletions)

@@ -93,7 +93,7 @@ def _get_weights(self, hidden_dim=None):
     else:
       ret = tf.concat(shards, 0)
     # Convert ret to tensor.
-    if not tf.contrib.eager.in_eager_mode():
+    if not tf.executing_eagerly():
       ret = common_layers.convert_gradient_to_tensor(ret)
     return ret

@@ -226,15 +226,15 @@ class ImageModality(modality.Modality):

   def bottom(self, x):
     with tf.variable_scope(self.name):
-      if not tf.contrib.eager.in_eager_mode():
+      if not tf.executing_eagerly():
         tf.summary.image(
             "inputs", common_layers.tpu_safe_image_summary(x), max_outputs=2)
       return tf.to_float(x)

   def targets_bottom(self, x):
     inputs = x
     with tf.variable_scope(self.name):
-      if not tf.contrib.eager.in_eager_mode():
+      if not tf.executing_eagerly():
         tf.summary.image(
             "targets_bottom",
             common_layers.tpu_safe_image_summary(inputs),

tensor2tensor/models/image_transformer.py (1 addition, 1 deletion)

@@ -55,7 +55,7 @@ def body(self, features):
                        "must be ImageChannelBottomIdentityModality and "
                        "num_channels must be 1.")
     if (not tf.get_variable_scope().reuse and
-        hparams.mode != tf.contrib.learn.ModeKeys.INFER and
+        hparams.mode != tf.estimator.ModeKeys.PREDICT and
         hparams.modality["targets"] !=
         modalities.ImageChannelBottomIdentityModality):
       tf.summary.image("targets", tf.to_float(targets), max_outputs=1)

tensor2tensor/models/image_transformer_2d.py (3 additions, 3 deletions)

@@ -46,7 +46,7 @@ def body(self, features):
     targets = features["targets"]
     targets_shape = common_layers.shape_list(targets)
     if not (tf.get_variable_scope().reuse or
-            hparams.mode == tf.contrib.learn.ModeKeys.INFER):
+            hparams.mode == tf.estimator.ModeKeys.PREDICT):
       tf.summary.image("targets", targets, max_outputs=1)

     decoder_input, rows, cols = cia.prepare_decoder(
@@ -76,7 +76,7 @@ def body(self, features):
     targets = features["targets"]
     inputs = features["inputs"]
     if not (tf.get_variable_scope().reuse or
-            hparams.mode == tf.contrib.learn.ModeKeys.INFER):
+            hparams.mode == tf.estimator.ModeKeys.PREDICT):
       tf.summary.image("inputs", inputs, max_outputs=1)
       tf.summary.image("targets", targets, max_outputs=1)

@@ -112,7 +112,7 @@ def body(self, features):
     targets = features["targets"]
     inputs = features["inputs"]
     if not (tf.get_variable_scope().reuse or
-            hparams.mode == tf.contrib.learn.ModeKeys.INFER):
+            hparams.mode == tf.estimator.ModeKeys.PREDICT):
       tf.summary.image("inputs", inputs, max_outputs=1)
       tf.summary.image("targets", targets, max_outputs=1)

tensor2tensor/models/lstm.py (7 additions, 7 deletions)

@@ -29,8 +29,8 @@


 def _dropout_lstm_cell(hparams, train):
-  return tf.contrib.rnn.DropoutWrapper(
-      tf.contrib.rnn.LSTMCell(hparams.hidden_size),
+  return tf.nn.rnn_cell.DropoutWrapper(
+      tf.nn.rnn_cell.LSTMCell(hparams.hidden_size),
       input_keep_prob=1.0 - hparams.dropout * tf.to_float(train))


@@ -58,7 +58,7 @@ def lstm(inputs, sequence_length, hparams, train, name, initial_state=None):
             for _ in range(hparams.num_hidden_layers)]
   with tf.variable_scope(name):
     return tf.nn.dynamic_rnn(
-        tf.contrib.rnn.MultiRNNCell(layers),
+        tf.nn.rnn_cell.MultiRNNCell(layers),
         inputs,
         sequence_length,
         initial_state=initial_state,
@@ -192,11 +192,11 @@ def lstm_bid_encoder(inputs, sequence_length, hparams, train, name):
   """Bidirectional LSTM for encoding inputs that are [batch x time x size]."""

   with tf.variable_scope(name):
-    cell_fw = tf.contrib.rnn.MultiRNNCell(
+    cell_fw = tf.nn.rnn_cell.MultiRNNCell(
         [_dropout_lstm_cell(hparams, train)
          for _ in range(hparams.num_hidden_layers)])

-    cell_bw = tf.contrib.rnn.MultiRNNCell(
+    cell_bw = tf.nn.rnn_cell.MultiRNNCell(
        [_dropout_lstm_cell(hparams, train)
         for _ in range(hparams.num_hidden_layers)])

@@ -213,7 +213,7 @@ def lstm_bid_encoder(inputs, sequence_length, hparams, train, name):
     encoder_states = []

     for i in range(hparams.num_hidden_layers):
-      if isinstance(encoder_fw_state[i], tf.contrib.rnn.LSTMStateTuple):
+      if isinstance(encoder_fw_state[i], tf.nn.rnn_cell.LSTMStateTuple):
         encoder_state_c = tf.concat(
             values=(encoder_fw_state[i].c, encoder_bw_state[i].c),
             axis=1,
@@ -222,7 +222,7 @@ def lstm_bid_encoder(inputs, sequence_length, hparams, train, name):
             values=(encoder_fw_state[i].h, encoder_bw_state[i].h),
             axis=1,
             name="encoder_fw_state_h")
-        encoder_state = tf.contrib.rnn.LSTMStateTuple(
+        encoder_state = tf.nn.rnn_cell.LSTMStateTuple(
             c=encoder_state_c, h=encoder_state_h)
       elif isinstance(encoder_fw_state[i], tf.Tensor):
         encoder_state = tf.concat(

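Every tf.contrib.rnn symbol touched in this file (LSTMCell, MultiRNNCell, DropoutWrapper, LSTMStateTuple) has an identically-constructed twin under tf.nn.rnn_cell, so the edits are mechanical renames. A minimal sketch of the dropout-wrapped stack driving tf.nn.dynamic_rnn, with hypothetical dimensions:

    import tensorflow as tf

    hidden_size, num_layers = 128, 2  # hypothetical sizes

    def dropout_lstm_cell(keep_prob=0.9):  # hypothetical keep_prob
      return tf.nn.rnn_cell.DropoutWrapper(
          tf.nn.rnn_cell.LSTMCell(hidden_size),
          input_keep_prob=keep_prob)

    cell = tf.nn.rnn_cell.MultiRNNCell(
        [dropout_lstm_cell() for _ in range(num_layers)])
    inputs = tf.zeros([4, 10, hidden_size])  # [batch, time, depth]
    outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
    # `state` is one tf.nn.rnn_cell.LSTMStateTuple(c=..., h=...) per layer.
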
tensor2tensor/models/research/glow_ops.py (2 additions, 2 deletions)

@@ -977,7 +977,7 @@ def compute_prior(name, z, latent, hparams, condition=False, state=None,
       The first-three dimensions of the latent should be the same as z.
     hparams: next_frame_glow_hparams.
     condition: Whether or not to condition the distribution on latent.
-    state: tf.contrib.rnn.LSTMStateTuple.
+    state: tf.nn.rnn_cell.LSTMStateTuple.
       the current state of a LSTM used to model the distribution. Used
       only if hparams.latent_dist_encoder = "conv_lstm".
     temperature: float, temperature with which to sample from the Gaussian.
@@ -1025,7 +1025,7 @@ def split(name, x, reverse=False, eps=None, eps_std=None, cond_latents=None,
     eps_std: Sample x2 with the provided eps_std.
     cond_latents: optionally condition x2 on cond_latents.
     hparams: next_frame_glow hparams.
-    state: tf.contrib.rnn.LSTMStateTuple. Current state of the LSTM over z_2.
+    state: tf.nn.rnn_cell.LSTMStateTuple. Current state of the LSTM over z_2.
       Used only when hparams.latent_dist_encoder == "conv_lstm"
     condition: bool, Whether or not to condition the distribution on
       cond_latents.

tensor2tensor/models/research/glow_ops_test.py (1 addition, 1 deletion)

@@ -349,7 +349,7 @@ def test_latent_dist_encoder(self, encoder="conv_lstm", skip=True,
     state_t = tf.convert_to_tensor(state_rand)
     if encoder in ["conv_net", "conv3d_net"]:
       latent_t = [latent_t, latent_t]
-    init_state = tf.contrib.rnn.LSTMStateTuple(state_t, state_t)
+    init_state = tf.nn.rnn_cell.LSTMStateTuple(state_t, state_t)
     hparams = self.get_glow_hparams()
     hparams.latent_dist_encoder = encoder
     hparams.latent_skip = skip

tensor2tensor/models/research/transformer_vae.py (1 addition, 1 deletion)

@@ -103,7 +103,7 @@ def top_k_softmax(x, k):
 def top_k_experts(x, k, hparams):
   x_shape = common_layers.shape_list(x)
   x_flat = tf.reshape(x, [-1, common_layers.shape_list(x)[-1]])
-  is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
+  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
   gates, load = expert_utils.noisy_top_k_gating(
       x_flat, 2 ** hparams.z_size, is_training, k)
   gates_shape = [x_shape[0], x_shape[1], x_shape[2], 2 ** hparams.z_size]

tensor2tensor/models/research/vqa_attention.py (3 additions, 3 deletions)

@@ -234,10 +234,10 @@ def image_encoder(image_feat,

 def _get_rnn_cell(hparams):
   if hparams.rnn_type == "lstm":
-    rnn_cell = tf.contrib.rnn.BasicLSTMCell
+    rnn_cell = tf.nn.rnn_cell.BasicLSTMCell
   elif hparams.rnn_type == "lstm_layernorm":
     rnn_cell = tf.contrib.rnn.LayerNormBasicLSTMCell
-  return tf.contrib.rnn.DropoutWrapper(
+  return tf.nn.rnn_cell.DropoutWrapper(
       rnn_cell(hparams.hidden_size),
       output_keep_prob=1.0-hparams.dropout)

@@ -269,7 +269,7 @@ def question_encoder(question, hparams, name="encoder"):

   # rnn_layers = [_get_rnn_cell(hparams)
   #               for _ in range(hparams.num_rnn_layers)]
-  # rnn_multi_cell = tf.contrib.rnn.MultiRNNCell(rnn_layers)
+  # rnn_multi_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
   rnn_cell = _get_rnn_cell(hparams)
   # outputs, _ = tf.nn.dynamic_rnn(
   #     rnn_cell, question, length, dtype=tf.float32)

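This file is one reason the PR title says "reduce" rather than "remove": BasicLSTMCell moved to tf.nn.rnn_cell, but LayerNormBasicLSTMCell appears to have no core equivalent in TF 1.x, so the "lstm_layernorm" branch stays on tf.contrib.rnn. A sketch of the resulting selector, with a hypothetical keep probability:

    import tensorflow as tf

    def get_rnn_cell(rnn_type, hidden_size, keep_prob=0.9):  # hypothetical keep_prob
      if rnn_type == "lstm":
        cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size)
      else:  # "lstm_layernorm" is still contrib-only
        cell = tf.contrib.rnn.LayerNormBasicLSTMCell(hidden_size)
      return tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=keep_prob)
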
tensor2tensor/models/shake_shake.py (3 additions, 3 deletions)

@@ -54,7 +54,7 @@ def shake_shake_skip_connection(x, output_filters, stride, is_training):
 def shake_shake_branch(x, output_filters, stride, rand_forward, rand_backward,
                        hparams):
   """Building a 2 branching convnet."""
-  is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
+  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
   x = tf.nn.relu(x)
   x = tf.layers.conv2d(
       x,
@@ -76,7 +76,7 @@ def shake_shake_branch(x, output_filters, stride, rand_forward, rand_backward,

 def shake_shake_block(x, output_filters, stride, hparams):
   """Builds a full shake-shake sub layer."""
-  is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
+  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
   batch_size = common_layers.shape_list(x)[0]

   # Generate random numbers for scaling the branches.
@@ -138,7 +138,7 @@ class ShakeShake(t2t_model.T2TModel):

   def body(self, features):
     hparams = self._hparams
-    is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
+    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
     inputs = features["inputs"]
     assert (hparams.num_hidden_layers - 2) % 6 == 0
     assert hparams.hidden_size % 16 == 0

tensor2tensor/models/video/svg_lp.py (4 additions, 4 deletions)

@@ -57,13 +57,13 @@ def rnn_model(self, hidden_size, nlayers, rnn_type, name):
     """
     layers_units = [hidden_size] * nlayers
     if rnn_type == "lstm":
-      rnn_cell = tf.contrib.rnn.LSTMCell
+      rnn_cell = tf.nn.rnn_cell.LSTMCell
     elif rnn_type == "gru":
-      rnn_cell = tf.contrib.rnn.GRUCell
+      rnn_cell = tf.nn.rnn_cell.GRUCell
     else:
-      rnn_cell = tf.contrib.rnn.RNNCell
+      rnn_cell = tf.nn.rnn_cell.RNNCell
     cells = [rnn_cell(units, name=name) for units in layers_units]
-    stacked_rnn = tf.contrib.rnn.MultiRNNCell(cells)
+    stacked_rnn = tf.nn.rnn_cell.MultiRNNCell(cells)
     return stacked_rnn

   def deterministic_rnn(self, cell, inputs, states, output_size, scope):

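One caveat in this file: tf.nn.rnn_cell.RNNCell, like tf.contrib.rnn.RNNCell before it, is the abstract base class, so the final else branch would not construct a usable cell if it were ever hit; the concrete vanilla cell is BasicRNNCell. A sketch of the selector with that substitution, under hypothetical layer sizes:

    import tensorflow as tf

    def make_stacked_rnn(rnn_type, layers_units):
      cell_cls = {"lstm": tf.nn.rnn_cell.LSTMCell,
                  "gru": tf.nn.rnn_cell.GRUCell}.get(
                      rnn_type, tf.nn.rnn_cell.BasicRNNCell)  # concrete fallback
      cells = [cell_cls(units) for units in layers_units]
      return tf.nn.rnn_cell.MultiRNNCell(cells)

    stacked = make_stacked_rnn("gru", [64, 64])  # hypothetical sizes
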
tensor2tensor/notebooks/asr_transformer.ipynb (2 additions, 2 deletions)

@@ -93,8 +93,8 @@
 "from tensor2tensor.utils import metrics\n",
 "\n",
 "# Enable TF Eager execution\n",
-"from tensorflow.contrib.eager.python import tfe\n",
-"tfe.enable_eager_execution()\n",
+"tfe = tf.contrib.eager\n",
+"tf.enable_eager_execution()\n",
 "\n",
 "# Other setup\n",
 "Modes = tf.estimator.ModeKeys\n",

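The notebook switches to the core tf.enable_eager_execution entry point while keeping the tfe alias, presumably because some eager helpers it uses still live only under tf.contrib.eager. A minimal sketch of the setup cell, assuming TF 1.x:

    import tensorflow as tf

    tf.enable_eager_execution()  # must run before any ops build a graph
    tfe = tf.contrib.eager       # alias for utilities not yet in core
    assert tf.executing_eagerly()
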