Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes to the "English-to-Spanish Translation with a Sequence-to-Sequence Transformer" Code Example #1997

Merged
merged 5 commits into from
Nov 29, 2024
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 27 additions & 16 deletions examples/nlp/neural_machine_translation_with_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,7 @@ def call(self, inputs):
return embedded_tokens + embedded_positions

def compute_mask(self, inputs, mask=None):
if mask is None:
return None
else:
return ops.not_equal(inputs, 0)
return ops.not_equal(inputs, 0)

def get_config(self):
config = super().get_config()
Expand Down Expand Up @@ -342,24 +339,30 @@ def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
self.layernorm_3 = layers.LayerNormalization()
self.supports_masking = True

def call(self, inputs, encoder_outputs, mask=None):
def call(self, inputs, mask=None):
inputs, encoder_outputs = inputs
causal_mask = self.get_causal_attention_mask(inputs)
if mask is not None:
padding_mask = ops.cast(mask[:, None, :], dtype="int32")
padding_mask = ops.minimum(padding_mask, causal_mask)

if mask is None:
inputs_padding_mask, encoder_outputs_padding_mask = None, None
else:
padding_mask = None
inputs_padding_mask, encoder_outputs_padding_mask = mask

attention_output_1 = self.attention_1(
query=inputs, value=inputs, key=inputs, attention_mask=causal_mask
query=inputs,
value=inputs,
key=inputs,
attention_mask=causal_mask,
query_mask=inputs_padding_mask,
)
out_1 = self.layernorm_1(inputs + attention_output_1)

attention_output_2 = self.attention_2(
query=out_1,
value=encoder_outputs,
key=encoder_outputs,
attention_mask=padding_mask,
query_mask=inputs_padding_mask,
key_mask=encoder_outputs_padding_mask,
)
out_2 = self.layernorm_2(out_1 + attention_output_2)

Expand Down Expand Up @@ -407,14 +410,15 @@ def get_config(self):
decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs)
x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)
x = TransformerDecoder(embed_dim, latent_dim, num_heads)([x, encoder_outputs])
x = layers.Dropout(0.5)(x)
decoder_outputs = layers.Dense(vocab_size, activation="softmax")(x)
decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)

decoder_outputs = decoder([decoder_inputs, encoder_outputs])
transformer = keras.Model(
[encoder_inputs, decoder_inputs], decoder_outputs, name="transformer"
{"encoder_inputs": encoder_inputs, "decoder_inputs": decoder_inputs},
decoder_outputs,
name="transformer",
)

"""
Expand All @@ -431,7 +435,9 @@ def get_config(self):

transformer.summary()
transformer.compile(
"rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
"rmsprop",
loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=0),
metrics=["accuracy"],
)
transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)

Expand All @@ -454,7 +460,12 @@ def decode_sequence(input_sentence):
decoded_sentence = "[start]"
for i in range(max_decoded_sentence_length):
tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]
predictions = transformer([tokenized_input_sentence, tokenized_target_sentence])
predictions = transformer(
{
"encoder_inputs": tokenized_input_sentence,
"decoder_inputs": tokenized_target_sentence,
}
)

# ops.argmax(predictions[0, i, :]) is not a concrete value for jax here
sampled_token_index = ops.convert_to_numpy(
Expand Down