This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit f1076df

lgeiger authored and Copybara-Service committed
Merge of PR #1446
PiperOrigin-RevId: 233438134
1 parent 8919e82 commit f1076df

File tree: 5 files changed (+15, -18 lines)


tensor2tensor/data_generators/cipher.py

Lines changed: 3 additions & 3 deletions
@@ -213,9 +213,9 @@ def encipher_vigenere(plaintext, plain_vocab, key):
   """
   ciphertext = []
   # generate Vigenere table
-  layers = []
-  for i in range(len(plain_vocab)):
-    layers.append(ShiftEncryptionLayer(plain_vocab, i))
+  layers = [
+      ShiftEncryptionLayer(plain_vocab, i) for i in range(len(plain_vocab))
+  ]
 
   for i, sentence in enumerate(plaintext):
     cipher_sentence = []
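
The hunk above swaps an explicit append loop for an equivalent list comprehension. A minimal sketch of the same pattern, using a hypothetical ShiftLayer stand-in rather than the real ShiftEncryptionLayer:

    # Sketch of the loop-to-comprehension refactor above.
    # ShiftLayer is a hypothetical stand-in for ShiftEncryptionLayer.
    class ShiftLayer(object):
      def __init__(self, vocab, shift):
        self.vocab = vocab
        self.shift = shift

    plain_vocab = ["a", "b", "c", "d"]

    # Before: build the list with an explicit loop.
    layers_loop = []
    for i in range(len(plain_vocab)):
      layers_loop.append(ShiftLayer(plain_vocab, i))

    # After: the equivalent list comprehension.
    layers = [ShiftLayer(plain_vocab, i) for i in range(len(plain_vocab))]

    assert len(layers) == len(layers_loop) == len(plain_vocab)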

tensor2tensor/data_generators/wiki_revision_utils.py

Lines changed: 3 additions & 5 deletions
@@ -491,10 +491,8 @@ def edit_distance_filter(source_target_input, max_equal_to_diff_ratio=0):
   if not max_equal_to_diff_ratio:
     return source_target_input, thrown_out_count
 
-  for i in range(len(source_target_input)):
-    src = source_target_input[i][0]
-    tgt = source_target_input[i][1]
-    opcodes = fast_match_sequences(src, tgt)
+  for src_tgt in source_target_input:
+    opcodes = fast_match_sequences(*src_tgt)
     diff_char_count = 0
     equal_char_count = 0
     for tag, i1, i2, j1, j2 in opcodes:
@@ -504,7 +502,7 @@ def edit_distance_filter(source_target_input, max_equal_to_diff_ratio=0):
       else:
         equal_char_count += i2 - i1
     if diff_char_count <= max_equal_to_diff_ratio * equal_char_count:
-      source_target_output.append(source_target_input[i])
+      source_target_output.append(src_tgt)
     else:
       thrown_out_count += 1
   return source_target_output, thrown_out_count
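
The change above iterates over the (source, target) pairs directly and unpacks each pair into fast_match_sequences with *, instead of indexing the list by position; the kept pair is then reused when appending. A small self-contained sketch of that pattern, with a hypothetical diff_ratio function standing in for the real opcode-based scoring:

    # Sketch of iterating over (source, target) pairs directly instead of by index.
    # diff_ratio is a hypothetical stand-in for the fast_match_sequences-based check.
    def diff_ratio(src, tgt):
      # Fraction of aligned positions that differ.
      n = min(len(src), len(tgt))
      return sum(a != b for a, b in zip(src, tgt)) / max(n, 1)

    source_target_input = [("kitten", "kitten"), ("kitten", "sitting")]
    kept, thrown_out_count = [], 0

    for src_tgt in source_target_input:
      # *src_tgt unpacks the pair into the two positional arguments.
      if diff_ratio(*src_tgt) <= 0.2:
        kept.append(src_tgt)
      else:
        thrown_out_count += 1

    print(kept, thrown_out_count)  # [('kitten', 'kitten')] 1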

tensor2tensor/layers/common_layers.py

Lines changed: 5 additions & 6 deletions
@@ -2334,10 +2334,10 @@ def ravanbakhsh_set_layer(layer_size,
 
 def fn_device_dependency_dict():
   """State container for fn_device_dependency."""
-  if not hasattr(tf.get_default_graph(), "dependency_dict"):
-    setattr(tf.get_default_graph(), "dependency_dict",
-            collections.defaultdict(list))
-  return tf.get_default_graph().dependency_dict
+  default_graph = tf.get_default_graph()
+  if not hasattr(default_graph, "dependency_dict"):
+    default_graph.dependency_dict = collections.defaultdict(list)
+  return default_graph.dependency_dict
 
 
 @contextlib.contextmanager
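
The hunk above fetches the default graph once, binds it to a local, and replaces setattr with plain attribute assignment. The same lazily-attached state-container pattern, sketched on a plain object so it runs without TensorFlow (GraphLike and get_default_graph are illustrative stand-ins, not the TensorFlow API):

    import collections

    class GraphLike(object):
      """Illustrative stand-in for the object returned by tf.get_default_graph()."""
      pass

    _GRAPH = GraphLike()

    def get_default_graph():
      return _GRAPH

    def fn_device_dependency_dict():
      # Fetch the graph once and bind it locally instead of calling
      # get_default_graph() three separate times.
      default_graph = get_default_graph()
      if not hasattr(default_graph, "dependency_dict"):
        # Plain attribute assignment is equivalent to
        # setattr(obj, "dependency_dict", value) when the name is a literal.
        default_graph.dependency_dict = collections.defaultdict(list)
      return default_graph.dependency_dict

    d = fn_device_dependency_dict()
    d["gpu:0"].append("op_a")
    assert fn_device_dependency_dict() is d  # the same container is reused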
@@ -2791,8 +2791,7 @@ def shape_list(x):
   shape = tf.shape(x)
 
   ret = []
-  for i in range(len(static)):
-    dim = static[i]
+  for i, dim in enumerate(static):
     if dim is None:
       dim = shape[i]
     ret.append(dim)
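
This second hunk replaces range(len(...)) plus manual indexing with enumerate, which yields the index and the element together. A self-contained sketch of the same merge of static and dynamic dimensions using plain lists (the values are made up):

    # enumerate() gives (index, value) pairs, replacing range(len(...)) + indexing.
    static = [32, None, 512]   # statically known dims, None where unknown
    dynamic = [32, 100, 512]   # values that would come from tf.shape(x)

    ret = []
    for i, dim in enumerate(static):
      if dim is None:
        dim = dynamic[i]       # fall back to the dynamic value for unknown dims
      ret.append(dim)

    assert ret == [32, 100, 512]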

tensor2tensor/rl/trainer_model_based.py

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ def train_supervised(problem, model_name, hparams, data_dir, output_dir,
                      schedule="continuous_train_and_eval"):
   """Train supervised."""
   if local_eval_frequency is None:
-    local_eval_frequency = getattr(FLAGS, "local_eval_frequency")
+    local_eval_frequency = FLAGS.local_eval_frequency
 
   exp_fn = trainer_lib.create_experiment_fn(
       model_name, problem, data_dir, train_steps, eval_steps,
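
getattr(FLAGS, "local_eval_frequency") with a literal attribute name is just a roundabout spelling of FLAGS.local_eval_frequency; getattr only earns its keep when the name is computed at runtime. A tiny illustration with a hypothetical Flags stand-in rather than absl's FLAGS object:

    class Flags(object):
      """Hypothetical stand-in for an absl-style FLAGS object."""
      local_eval_frequency = 1000

    FLAGS = Flags()

    # Equivalent when the attribute name is a literal:
    assert getattr(FLAGS, "local_eval_frequency") == FLAGS.local_eval_frequency

    # getattr is useful when the name is only known at runtime:
    name = "local_eval_frequency"
    assert getattr(FLAGS, name) == 1000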

tensor2tensor/serving/serving_utils.py

Lines changed: 3 additions & 3 deletions
@@ -115,9 +115,9 @@ def _make_grpc_request(examples):
     scores = tf.make_ndarray(response.outputs["scores"])
     assert len(outputs) == len(scores)
     return [{
-        "outputs": outputs[i],
-        "scores": scores[i]
-    } for i in range(len(outputs))]
+        "outputs": output,
+        "scores": score
+    } for output, score in zip(outputs, scores)]
 
   return _make_grpc_request
 
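The comprehension now pairs each output with its score via zip instead of indexing both arrays with a shared counter. A small sketch with plain lists standing in for the tf.make_ndarray results:

    # Pairing parallel sequences with zip() instead of indexing by position.
    outputs = ["translation a", "translation b"]
    scores = [-0.12, -0.57]
    assert len(outputs) == len(scores)

    predictions = [{
        "outputs": output,
        "scores": score
    } for output, score in zip(outputs, scores)]

    assert predictions[0] == {"outputs": "translation a", "scores": -0.12}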
