Feed forward chunking #6024
Conversation
Update from source (×5)
@patrickvonplaten - here's an initial implementation. My first step was to get the model working with a chunked feed forward - and it works! I still need to run the benchmark tests to quantify the memory savings. However, I see a problem: the new architecture changes some of the nn.Module weight and bias parameter names, which breaks loading existing pretrained weights from checkpoints. See the failing tests for more details. Any thoughts/ideas on how to get around this?
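(For readers following along, a minimal, hypothetical sketch of the idea, not the PR's actual code: the feed-forward layers act on each sequence position independently, so the input can be split along the sequence-length dimension, processed chunk by chunk, and concatenated; peak memory for the intermediate activations then scales with the chunk size rather than the full sequence length.)

```python
import torch

def chunked_feed_forward(forward_fn, hidden_states, chunk_size, seq_len_dim=1):
    # Split the sequence dimension into chunks, apply the position-wise
    # feed forward to each chunk, and stitch the results back together.
    # Equivalent to forward_fn(hidden_states) because each position is
    # processed independently; only peak activation memory changes.
    if chunk_size == 0:  # chunking disabled
        return forward_fn(hidden_states)
    chunks = hidden_states.split(chunk_size, dim=seq_len_dim)
    return torch.cat([forward_fn(c) for c in chunks], dim=seq_len_dim)

# Example with a toy position-wise feed forward over (batch, seq_len, hidden):
ff = torch.nn.Sequential(
    torch.nn.Linear(16, 64), torch.nn.GELU(), torch.nn.Linear(64, 16)
)
x = torch.randn(2, 128, 16)
assert torch.allclose(chunked_feed_forward(ff, x, chunk_size=32), ff(x), atol=1e-6)
```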
@@ -159,6 +159,7 @@ def __init__(self, **kwargs):
         self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
         self.bad_words_ids = kwargs.pop("bad_words_ids", None)
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
+        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 1)
This should be disabled by default, so it would be nice to set it to 0.
And to make sure we don't break backward compatibility...
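(As a usage sketch, assuming the eventual default of 0: the attribute name comes from the diff above; everything else is standard transformers usage.)

```python
from transformers import BertConfig, BertModel

config = BertConfig()
# Default 0 keeps the old, un-chunked behavior for existing configs
# and checkpoints, preserving backward compatibility.
config.chunk_size_feed_forward = 0
# Opting in: run the feed forward 64 sequence positions at a time.
# config.chunk_size_feed_forward = 64
model = BertModel(config)
```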
In modeling_bert.py, refer to #6972. Thanks.
@@ -408,8 +432,7 @@ def forward(
         attention_output = cross_attention_outputs[0]
         outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

         intermediate_output = self.intermediate(attention_output)
To solve the problem, my suggestion would be to wrap these two calls in a function `forward_chunk` which is part of this class (`def forward_chunk(self, ...)`) and call `apply_chunking_to_forward(self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output)` here.
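(A hedged sketch of that suggestion, with the feed-forward part of `BertLayer` abridged to two plain linear layers; the import path and argument order match this PR's version of the helper and were changed in later releases.)

```python
import torch
import torch.nn as nn
from transformers.modeling_utils import apply_chunking_to_forward

class BertLayerSketch(nn.Module):
    def __init__(self, hidden_size=16, intermediate_size=64, chunk_size_feed_forward=0):
        super().__init__()
        self.chunk_size_feed_forward = chunk_size_feed_forward
        self.seq_len_dim = 1
        self.intermediate = nn.Linear(hidden_size, intermediate_size)
        self.output = nn.Linear(intermediate_size, hidden_size)

    def forward(self, attention_output):
        # The helper calls self.forward_chunk on chunk_size-sized slices
        # of attention_output along self.seq_len_dim (or directly, when
        # chunking is disabled with chunk_size 0).
        return apply_chunking_to_forward(
            self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output
        )

    def forward_chunk(self, attention_output):
        # The "two calls" being wrapped: the intermediate and output
        # projections, both position-wise and therefore safe to chunk.
        intermediate_output = self.intermediate(attention_output)
        return self.output(intermediate_output)

layer = BertLayerSketch(chunk_size_feed_forward=32)
out = layer(torch.randn(2, 128, 16))  # 128 positions, processed 32 at a time
```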
I don't think I quite follow what you mean here. Which two calls do you want to wrap? Did you mean to have a `forward_chunk` function in the `BertLayer` class?
Ok I fixed it based on your input - looks ok to me now.
Great, that's exactly what I meant :-)
Update from source (force-pushed from ab11e31 to 5621e1f)
fix the shuffle argument usage and the default (huggingface#6307)
This is an initial implementation to test applying feed forward chunking for BERT. Will need additional modifications based on output and benchmark results.
(force-pushed from 44aea51 to 5655ddf)
Codecov Report
@@            Coverage Diff             @@
##           master    #6024      +/-   ##
==========================================
- Coverage   79.44%   79.12%   -0.32%
==========================================
  Files         148      148
  Lines       27193    27198       +5
==========================================
- Hits        21604    21521      -83
- Misses       5589     5677      +88
Continue to review full report at Codecov.
Hey @Pradhy729, thanks a lot for continuing the PR. I made a couple of changes: fixed the docs and added tests for all models, though only the Reformer and Bert tests are enabled for now. Would be great if @LysandreJik @sgugger @thomwolf @sshleifer can review. This PR shows how feed forward chunking can be applied. To-do after the review is positive:
Great! Thanks @patrickvonplaten
This is very useful work, thanks for tackling this!
Very nice, love how it only requires a few lines of code now that `apply_chunking_to_forward` is created.
layer_output = apply_chunking_to_forward(
    self.chunk_size_feed_forward, self.seq_len_dim, self.feed_forward_chunk, attention_output
)
Very much a nitpick here, probably for future PRs, but this looks a lot like the gradient checkpointing method from PyTorch. That method takes the callable (the forward method) as its first positional argument, and I think it makes sense to have it this way here too.
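(For reference, a small illustration of the ordering being compared: torch's checkpointing API puts the callable first, whereas the diff above passes the chunk size first. The `use_reentrant` flag is available in recent PyTorch versions.)

```python
import torch
from torch.utils.checkpoint import checkpoint

def feed_forward(x):
    # Stand-in for a position-wise feed forward.
    return torch.relu(x) * 2

x = torch.randn(2, 8, 4, requires_grad=True)

# PyTorch's gradient checkpointing: callable first, then its inputs.
out = checkpoint(feed_forward, x, use_reentrant=False)
```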
I can do this globally in the new PR where I add the chunking for other models. Let me know if you have concerns with that.
This reverts commit 261b765.
Official PR for #5928