
Add AMP for TF Albert #10141

Merged: 1 commit into huggingface:master on Feb 15, 2021

Conversation

@jplu (Contributor) commented on Feb 11, 2021:

What does this PR do?

This PR adds the following features to TF Albert:

  • AMP compliance (a usage sketch follows below)
  • Loss computation for TFAlbertForPreTraining
  • Source code cleanup
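
For context (not part of this PR), a minimal sketch of how a user might enable AMP once the model is compliant; it assumes TF 2.4+ and a transformers install with TF support, and the model id and policy call are standard Keras/transformers usage rather than code from this PR:

import tensorflow as tf
from transformers import TFAlbertModel

# Run computations in float16 while keeping variables in float32 (Keras AMP).
tf.keras.mixed_precision.set_global_policy("mixed_float16")

# An AMP-compliant model must avoid hard-coded tf.float32 casts internally,
# otherwise activations are silently upcast and mixed precision brings no benefit.
model = TFAlbertModel.from_pretrained("albert-base-v2")
input_ids = tf.constant([[2, 45, 78, 3]])
outputs = model(input_ids)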

@sgugger (Collaborator) left a comment:

Looks good to me, thanks for fixing!

@LysandreJik (Member) left a comment:

Hi! Could you split the two objectives, cleaning the source code and AMP, into two PRs? That would be better because:

  • It will be easier to pinpoint the cause if we track an issue down to this commit later on.
  • It makes the diff easier to read; as it stands, I can't tell which changes here are related to AMP.

@@ -1009,7 +1008,8 @@ def call(
         total_loss = self.compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))

         if not inputs["return_dict"]:
-            return (prediction_scores, seq_relationship_score) + outputs[2:]
+            output = (prediction_scores, seq_relationship_score) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
Member:

It was not returning the loss before that?

Member:

This should have been caught by a test

@jplu (Contributor, Author):

The test on the loss computation was passing because it only runs with return_dict=True; it would have failed with return_dict=False.
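
For illustration, a hypothetical sketch (not the actual test in the repo) of a check that exercises return_dict=False; the tiny config values and the label argument names are assumptions based on the snippets in this thread:

import tensorflow as tf
from transformers import AlbertConfig, TFAlbertForPreTraining

config = AlbertConfig(vocab_size=99, embedding_size=16, hidden_size=32,
                      num_hidden_layers=2, num_attention_heads=2, intermediate_size=37)
model = TFAlbertForPreTraining(config)

input_ids = tf.ones((2, 7), dtype=tf.int32)
mlm_labels = tf.zeros((2, 7), dtype=tf.int32)
sop_labels = tf.zeros((2,), dtype=tf.int32)

# With return_dict=False and labels passed, the tuple should now start with the loss.
outputs = model(input_ids, labels=mlm_labels, sentence_order_label=sop_labels, return_dict=False)
loss = outputs[0]
print(loss.shape, len(outputs))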

Comment on lines +246 to +255
# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

if return_labels:
if model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values():
inputs_dict["sentence_order_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)

return inputs_dict

Member:

I see it will get caught by a test from now on :)

Comment on lines -221 to -223
# scale attention_scores
dk = tf.cast(shape_list(key_layer)[-1], tf.float32)
attention_scores = attention_scores / tf.math.sqrt(dk)
@jplu (Contributor, Author):

First fix for AMP: removing the hard-coded cast to tf.float32.
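
A minimal sketch of what the dtype-agnostic scaling might look like, reusing the names from the removed snippet above (attention_scores, key_layer, shape_list); this is an illustration, not necessarily the exact replacement in the PR:

        # Scale the attention scores in the scores' own dtype so float16 is
        # preserved under AMP instead of forcing an upcast to float32.
        dk = tf.cast(shape_list(key_layer)[-1], dtype=attention_scores.dtype)
        attention_scores = attention_scores / tf.math.sqrt(dk)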

@@ -550,9 +601,10 @@ def call(
         # positions we want to attend and -10000.0 for masked positions.
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.

-        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
@jplu (Contributor, Author):

Second fix for AMP: again, removing the hard-coded cast to tf.float32.
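
As above, a sketch of the dtype-aware alternative; embedding_output is an assumed name for the embeddings tensor available at this point in call, so treat this as an illustration rather than the PR's exact code:

        # Cast the mask to the compute dtype instead of tf.float32 so the
        # -10000.0 bias stays in float16 under AMP.
        extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0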

@jplu (Contributor, Author) commented on Feb 14, 2021:

I can split this PR into two different ones, but the one on AMP will be very short (only two single-line updates, see the review above). Are you okay with such a tiny PR? If so, I will split this one ^^

@LysandreJik (Member):

It's ok, thanks for showing me the changes!

@LysandreJik (Member) left a comment:

The changes LGTM

@jplu (Contributor, Author) commented on Feb 15, 2021:

@patrickvonplaten feel free to merge if it looks ok for you!

@jplu merged commit 31b0560 into huggingface:master on Feb 15, 2021.
@jplu deleted the tf-albert-amp branch on February 15, 2021 at 16:37.