Feed forward chunking #6024
Changes from 13 commits
```diff
@@ -48,7 +48,12 @@
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from .modeling_utils import (
+    PreTrainedModel,
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)


 logger = logging.getLogger(__name__)
```
```diff
@@ -88,6 +93,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
     """
     try:
         import re
+
         import numpy as np
         import tensorflow as tf
     except ImportError:
```
```diff
@@ -376,6 +382,8 @@ def forward(self, hidden_states, input_tensor):
 class BertLayer(nn.Module):
     def __init__(self, config):
         super().__init__()
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
         self.attention = BertAttention(config)
         self.is_decoder = config.is_decoder
         if self.is_decoder:
```
```diff
@@ -410,11 +418,17 @@ def forward(
         attention_output = cross_attention_outputs[0]
         outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

-        intermediate_output = self.intermediate(attention_output)
-        layer_output = self.output(intermediate_output, attention_output)
+        layer_output = apply_chunking_to_forward(
+            self.chunk_size_feed_forward, self.seq_len_dim, self.feed_forward_chunk, attention_output
+        )
```
Comment on lines +421 to +423:

Reviewer: Very much a nitpick here, for future PRs probably, but this looks a lot like the gradient checkpointing method from PyTorch. That method takes the callable (the forward method) as its first positional argument, and I think it makes sense to have it this way here as well.

Author: I can do this globally in the new PR where I add chunking for the other models. Let me know if you have concerns with that.
```diff
         outputs = (layer_output,) + outputs
         return outputs

+    def feed_forward_chunk(self, attention_output):
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output
+

 class BertEncoder(nn.Module):
     def __init__(self, config):
```
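For context, the chunking applied above splits the attention output along the sequence dimension, runs the intermediate/output projections on each slice, and concatenates the results, trading a little extra overhead for a smaller peak memory footprint (the large intermediate activations only ever exist for one slice at a time). Below is a minimal, self-contained sketch of that idea; it is an illustration only, not the actual `apply_chunking_to_forward` helper from `modeling_utils`, and the function name `chunked_feed_forward` is hypothetical.

```python
import torch


def chunked_feed_forward(forward_fn, chunk_size, chunk_dim, input_tensor):
    """Sketch of feed-forward chunking: apply `forward_fn` to slices of
    `input_tensor` along `chunk_dim` and concatenate the results.
    Illustration only, not the library's `apply_chunking_to_forward`."""
    if chunk_size == 0:
        # Chunking disabled: run the feed forward on the full tensor at once.
        return forward_fn(input_tensor)

    assert input_tensor.shape[chunk_dim] % chunk_size == 0, "sequence length must be divisible by chunk_size"

    # Process chunk_size positions at a time, so only one slice's intermediate
    # activations are alive at any point.
    chunks = input_tensor.split(chunk_size, dim=chunk_dim)
    output_chunks = [forward_fn(chunk) for chunk in chunks]
    return torch.cat(output_chunks, dim=chunk_dim)


# Hypothetical usage with a toy "intermediate -> output" feed forward:
feed_forward = torch.nn.Sequential(torch.nn.Linear(16, 64), torch.nn.GELU(), torch.nn.Linear(64, 16))
hidden_states = torch.randn(2, 8, 16)  # (batch, seq_len, hidden)

chunked = chunked_feed_forward(feed_forward, chunk_size=2, chunk_dim=1, input_tensor=hidden_states)
full = feed_forward(hidden_states)
print(torch.allclose(chunked, full, atol=1e-6))  # same result; only peak memory differs
```

Because the feed forward is applied position-wise, the chunked and unchunked computations are numerically equivalent; only memory usage changes.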
Reviewer: To solve the problem, my suggestion would be to wrap these two calls in a function `forward_chunk` that is part of this class (`def forward_chunk(self, ...)`) and to call `apply_chunking_to_forward(self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output)` here.

Author: I don't think I quite follow what you mean here. Which two calls do you want to wrap? Did you mean to have a `forward_chunk` function in the `BertLayer` class?

Author: Ok, I fixed it based on your input - looks ok to me now.

Reviewer: Great, that's exactly what I meant :-)
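As a usage note, this PR exposes the chunk size through the model config as `chunk_size_feed_forward`, with 0 meaning no chunking. Assuming that attribute is accepted as a config kwarg (as it is for `PretrainedConfig`-derived configs), enabling it for a BERT model could look roughly like the sketch below; the value 64 is an arbitrary example.

```python
import torch
from transformers import BertConfig, BertModel

# chunk_size_feed_forward defaults to 0 (no chunking); a non-zero value makes
# BertLayer run its feed forward on slices of that many positions along the
# sequence dimension.
config = BertConfig(chunk_size_feed_forward=64)
model = BertModel(config)

input_ids = torch.randint(0, config.vocab_size, (1, 128))
outputs = model(input_ids)
# The outputs match the unchunked computation; only peak memory during the
# feed forward changes.
```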