Add AMP for TF Albert #10141
Conversation
Looks good to me, thanks for fixing!
Hi! Could you split the two objectives, cleaning the source code and adding AMP support, into two PRs? It's better because:
- It will be easier to pinpoint the cause if we identify an issue with this commit down the road.
- It makes the diff easier to read; right now I can't tell which changes are related to AMP.
@@ -1009,7 +1008,8 @@ def call(
             total_loss = self.compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))

         if not inputs["return_dict"]:
-            return (prediction_scores, seq_relationship_score) + outputs[2:]
+            output = (prediction_scores, seq_relationship_score) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
It was not returning the loss before that?
This should have been caught by a test
The test on the loss computation was passing because it is only run with `return_dict=True`; it would have failed with `return_dict=False`.
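For illustration, here is a minimal, hypothetical check of the kind that would catch a dropped loss in the tuple output; the tiny config, the input values, and the length comparison are assumptions for this sketch, not the repository's actual common test:

```python
import tensorflow as tf
from transformers import AlbertConfig, TFAlbertForPreTraining

# Toy configuration, sized down purely for illustration.
config = AlbertConfig(
    vocab_size=64,
    embedding_size=16,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=37,
)
model = TFAlbertForPreTraining(config)

input_ids = tf.constant([[1, 2, 3, 4]])
labels = tf.constant([[1, 2, 3, 4]])
sentence_order_label = tf.constant([0])

outputs_with_labels = model(
    input_ids,
    labels=labels,
    sentence_order_label=sentence_order_label,
    return_dict=False,
)
outputs_without_labels = model(input_ids, return_dict=False)

# With the fix, passing labels prepends the total loss to the tuple, so the
# output is one element longer than without labels. Before the fix the loss
# was dropped, both tuples had the same length, and this assertion would
# have flagged the bug.
assert len(outputs_with_labels) == len(outputs_without_labels) + 1
```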
# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
    inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

    if return_labels:
        if model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values():
            inputs_dict["sentence_order_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)

    return inputs_dict
I see it will get caught by a test from now on :)
# scale attention_scores
dk = tf.cast(shape_list(key_layer)[-1], tf.float32)
attention_scores = attention_scores / tf.math.sqrt(dk)
First fix for AMP (removing the cast to `tf.float32`).
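As a sketch of the dtype-aware alternative (assuming the scaling factor simply follows the dtype of the attention scores, which is float16 under AMP; variable names follow the snippet above):

```python
# Keep the scaling factor in the same dtype as the attention scores instead of
# hard-coding tf.float32, so mixed precision is preserved end to end.
dk = tf.cast(shape_list(key_layer)[-1], dtype=attention_scores.dtype)
attention_scores = attention_scores / tf.math.sqrt(dk)
```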
@@ -550,9 +601,10 @@ def call(
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.

        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
Second fix for AMP: again, removing the cast to `tf.float32`.
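A sketch of the corresponding dtype-aware replacement, under the assumption that the mask is cast to the dtype of the embedding output (the variable name `embedding_output` is an assumption for this sketch):

```python
# Cast the mask to the computation dtype (float16 under AMP, float32 otherwise)
# instead of hard-coding tf.float32; the subsequent (1.0 - mask) * -10000.0
# masking then stays in the same dtype as the attention scores.
extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
```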
I can split this PR into two, but the AMP one would be very short (only two single-line updates, see the review above). Do you agree with such a tiny PR? If so, I will split this one ^^
It's ok, thanks for showing me the changes!
The changes LGTM.
@patrickvonplaten feel free to merge if it looks OK to you!
What does this PR do?
This PR adds the following features to TF Albert:
- AMP (mixed precision) support, by removing the two hard-coded casts to tf.float32 highlighted above
- A cleanup of the source code
- Returning the loss in the ForPreTraining model when return_dict=False, with the corresponding test update