Feed forward chunking #6024

Merged: 14 commits, Aug 11, 2020
Feed forward chunking in BertLayer class.
Pradhy729 committed Aug 7, 2020
commit f0e3826eef61c1f1705eec046918fb1c3424df95
2 changes: 1 addition & 1 deletion src/transformers/configuration_utils.py
@@ -160,7 +160,7 @@ def __init__(self, **kwargs):
         self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
         self.bad_words_ids = kwargs.pop("bad_words_ids", None)
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
-        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 1)
+        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)

         # Fine-tuning task arguments
         self.architectures = kwargs.pop("architectures", None)
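For context on the default change above: with the chunking helper used in this PR, a chunk_size_feed_forward of 0 runs the feed-forward pass on the whole sequence at once (no chunking), whereas the old default of 1 would have split the sequence into single-token chunks. A minimal sketch of that behavior, using a hypothetical chunked_forward function as a simplified stand-in rather than the library's actual apply_chunking_to_forward implementation:

import torch


def chunked_forward(forward_fn, chunk_size, chunk_dim, input_tensor):
    # Simplified, illustrative stand-in for the chunking helper.
    if chunk_size == 0:
        # chunk_size 0: no chunking, apply the forward pass to the full tensor.
        return forward_fn(input_tensor)
    # Otherwise split along chunk_dim (the sequence dimension in BertLayer),
    # run the forward pass chunk by chunk, and concatenate the outputs.
    chunks = input_tensor.split(chunk_size, dim=chunk_dim)
    return torch.cat([forward_fn(chunk) for chunk in chunks], dim=chunk_dim)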
32 changes: 11 additions & 21 deletions src/transformers/modeling_bert.py
@@ -378,34 +378,17 @@ def forward(self, hidden_states, input_tensor):
         return hidden_states


-class ChunkFeedForward(nn.Module):
+class BertLayer(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
-
-        self.dense = BertIntermediate(config)
-        self.output = BertOutput(config)
-
-    def forward(self, attention_output):
-        return apply_chunking_to_forward(
-            self.chunk_size_feed_forward, self.seq_len_dim, self.forward_chunk, attention_output,
-        )
-
-    def forward_chunk(self, attention_output):
-        intermediate_output = self.dense(attention_output)
-        return self.output(intermediate_output, attention_output)
-
-
-class BertLayer(nn.Module):
-    def __init__(self, config):
-        super().__init__()
         self.attention = BertAttention(config)
         self.is_decoder = config.is_decoder
         if self.is_decoder:
             self.crossattention = BertAttention(config)
-
-        self.feed_forward = ChunkFeedForward(config)
+        self.intermediate = BertIntermediate(config)
+        self.output = BertOutput(config)

     def forward(
         self,
@@ -434,10 +417,17 @@ def forward(
             attention_output = cross_attention_outputs[0]
             outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

-        layer_output = self.feed_forward(attention_output)
+        layer_output = apply_chunking_to_forward(
+            self.chunk_size_feed_forward, self.seq_len_dim, self.feed_forward_chunk, attention_output
+        )
Comment on lines +421 to +423
Member:
Very much a nitpick here, for future PRs probably, but this looks a lot like the gradient checkpointing method from PyTorch. That method takes the callable (the forward method) as its first positional argument, and I think it makes sense to have it that way here as well.

Contributor Author:
I can do this globally in the new PR where I add the chunking for other models. Let me know if you have concerns with that.
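For reference, the convention the reviewer points to is PyTorch's gradient checkpointing API, where the callable comes first and its inputs follow. A small illustration of that calling shape; the reordered apply_chunking_to_forward call shown in the comment is only the suggested ordering, not code from this PR:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

# PyTorch gradient checkpointing: checkpoint(function, *args) -- the callable comes first.
out = checkpoint(layer, x)

# By analogy, the suggested ordering for the chunking helper would be
#     apply_chunking_to_forward(self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output)
# instead of the (chunk_size, chunk_dim, forward_fn, *inputs) order used in this commit.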

         outputs = (layer_output,) + outputs
         return outputs
+
+    def feed_forward_chunk(self, attention_output):
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output


 class BertEncoder(nn.Module):
     def __init__(self, config):
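Chunking is safe here because the BERT feed-forward block (intermediate plus output dense layers) is position-wise, so applying it to sequence chunks and concatenating gives the same result as applying it to the full sequence, while reducing peak activation memory. A rough, self-contained demonstration with a hypothetical stand-in FFN, not the actual BertIntermediate/BertOutput modules, and illustrative sizes:

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the position-wise feed-forward block
# (dense -> activation -> dense), dropout omitted for determinism.
hidden_size, intermediate_size = 64, 256
ffn = nn.Sequential(
    nn.Linear(hidden_size, intermediate_size),
    nn.GELU(),
    nn.Linear(intermediate_size, hidden_size),
)

x = torch.randn(2, 128, hidden_size)  # (batch, seq_len, hidden)

full = ffn(x)  # chunk_size_feed_forward == 0: one pass over the whole sequence
chunked = torch.cat([ffn(c) for c in x.split(32, dim=1)], dim=1)  # 32-token chunks

print(torch.allclose(full, chunked, atol=1e-6))  # True: the outputs match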