This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 01d76fa

aeloyq authored and Copybara-Service committed

internal merge of PR #1303

PiperOrigin-RevId: 227927931
1 parent 98ec1ee commit 01d76fa

2 files changed, 21 insertions(+), 17 deletions(-)

tensor2tensor/layers/common_attention.py
Lines changed: 2 additions & 1 deletion

@@ -1584,7 +1584,8 @@ def dot_product_attention_relative(q,
     raise ValueError("Max relative position (%s) should be > 0 when using "
                      "relative self attention." % (max_relative_position))
   with tf.variable_scope(
-      name, default_name="dot_product_attention_relative", values=[q, k, v]) as scope:
+      name, default_name="dot_product_attention_relative",
+      values=[q, k, v]) as scope:
 
     # This calculation only works for self attention.
     # q, k and v must therefore have the same shape.
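For orientation, a minimal sketch of calling the function whose tf.variable_scope call was rewrapped above. The tensor shapes, bias=None, and max_relative_position=4 are illustrative assumptions, not values from this commit; the argument names beyond q, k, v, max_relative_position and name follow the library's usual signature.

import tensorflow as tf
from tensor2tensor.layers import common_attention

# Illustrative self-attention inputs: [batch, num_heads, length, depth_per_head].
# q, k and v must share the same shape, as the comment in the diff notes.
q = tf.random_normal([1, 8, 10, 64])
k = tf.random_normal([1, 8, 10, 64])
v = tf.random_normal([1, 8, 10, 64])

attention_output = common_attention.dot_product_attention_relative(
    q, k, v,
    bias=None,                # no padding or causal bias in this sketch
    max_relative_position=4,  # must be > 0, per the ValueError above
    name="relative_attention_example")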

tensor2tensor/visualization/visualization.py
Lines changed: 19 additions & 16 deletions

@@ -51,19 +51,19 @@ def __init__(
 
   def encode(self, input_str):
     """Input str to features dict, ready for inference."""
-    inputs = self.encoders['inputs'].encode(input_str) + [EOS_ID]
+    inputs = self.encoders["inputs"].encode(input_str) + [EOS_ID]
     batch_inputs = np.reshape(inputs, [1, -1, 1, 1])  # Make it 3D.
     return batch_inputs
 
   def decode(self, integers):
     """List of ints to str."""
     integers = list(np.squeeze(integers))
-    return self.encoders['inputs'].decode(integers)
+    return self.encoders["inputs"].decode(integers)
 
   def decode_list(self, integers):
     """List of ints to list of str."""
     integers = list(np.squeeze(integers))
-    return self.encoders['inputs'].decode_list(integers)
+    return self.encoders["inputs"].decode_list(integers)
 
   def get_vis_data_from_string(self, sess, input_string):
     """Constructs the data needed for visualizing attentions.
@@ -135,11 +135,11 @@ def build_model(hparams_set, model_name, data_dir, problem_name, beam_size=1):
   translate_model = registry.model(model_name)(
       hparams, tf.estimator.ModeKeys.EVAL)
 
-  inputs = tf.placeholder(tf.int32, shape=(1, None, 1, 1), name='inputs')
-  targets = tf.placeholder(tf.int32, shape=(1, None, 1, 1), name='targets')
+  inputs = tf.placeholder(tf.int32, shape=(1, None, 1, 1), name="inputs")
+  targets = tf.placeholder(tf.int32, shape=(1, None, 1, 1), name="targets")
   translate_model({
-      'inputs': inputs,
-      'targets': targets,
+      "inputs": inputs,
+      "targets": targets,
   })
 
   # Must be called after building the training graph, so that the dict will
@@ -150,8 +150,8 @@ def build_model(hparams_set, model_name, data_dir, problem_name, beam_size=1):
 
   with tf.variable_scope(tf.get_variable_scope(), reuse=True):
     samples = translate_model.infer({
-        'inputs': inputs,
-    }, beam_size=beam_size)['outputs']
+        "inputs": inputs,
+    }, beam_size=beam_size)["outputs"]
 
   return inputs, targets, samples, att_mats
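To show how the pieces above fit together, a hedged sketch of driving build_model end to end. The hparams set, model, data directory and problem names are placeholder assumptions, and the zero-filled batch only demonstrates the expected (1, length, 1, 1) placeholder shape.

import numpy as np
import tensorflow as tf
from tensor2tensor.visualization import visualization

# Placeholder configuration; substitute a real data_dir/problem and a trained checkpoint.
inputs, targets, samples, att_mats = visualization.build_model(
    "transformer_base", "transformer", "/tmp/t2t_data", "translate_ende_wmt32k",
    beam_size=1)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())  # or restore trained variables instead
  fake_batch = np.zeros((1, 5, 1, 1), dtype=np.int32)  # already-encoded token ids
  decoded_ids = sess.run(samples, {inputs: fake_batch})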

@@ -182,19 +182,22 @@ def get_att_mats(translate_model):
   dec_atts = []
   encdec_atts = []
 
-  prefix = 'transformer/body/'
-  postfix_self_attention = '/multihead_attention/dot_product_attention'
+  prefix = "transformer/body/"
+  postfix_self_attention = "/multihead_attention/dot_product_attention"
   if translate_model.hparams.self_attention_type == "dot_product_relative":
-    postfix_self_attention = '/multihead_attention/dot_product_attention_relative'
-  postfix_encdec = '/multihead_attention/dot_product_attention'
+    postfix_self_attention = ("/multihead_attention/"
+                              "dot_product_attention_relative")
+  postfix_encdec = "/multihead_attention/dot_product_attention"
 
   for i in range(translate_model.hparams.num_hidden_layers):
     enc_att = translate_model.attention_weights[
-        '%sencoder/layer_%i/self_attention%s' % (prefix, i, postfix_self_attention)]
+        "%sencoder/layer_%i/self_attention%s"
+        % (prefix, i, postfix_self_attention)]
     dec_att = translate_model.attention_weights[
-        '%sdecoder/layer_%i/self_attention%s' % (prefix, i, postfix_self_attention)]
+        "%sdecoder/layer_%i/self_attention%s"
+        % (prefix, i, postfix_self_attention)]
     encdec_att = translate_model.attention_weights[
-        '%sdecoder/layer_%i/encdec_attention%s' % (prefix, i, postfix_encdec)]
+        "%sdecoder/layer_%i/encdec_attention%s" % (prefix, i, postfix_encdec)]
     enc_atts.append(enc_att)
     dec_atts.append(dec_att)
     encdec_atts.append(encdec_att)
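Finally, a hedged sketch of consuming the per-layer attention tensors collected here, assuming get_att_mats returns the three lists built above and reusing the sess, inputs and fake_batch names from the previous sketch. The [batch, num_heads, query_length, key_length] shape in the comment is how multihead attention weights are conventionally laid out, not something asserted by this commit.

# Assumes att_mats = (enc_atts, dec_atts, encdec_atts) from build_model above.
enc_atts, dec_atts, encdec_atts = att_mats
enc_att_values = sess.run(enc_atts, {inputs: fake_batch})
# One array per hidden layer, each expected to be
# [batch, num_heads, query_length, key_length] of attention weights.
print(len(enc_att_values), enc_att_values[0].shape)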
