Feed forward chunking #6024

Merged
merged 14 commits on Aug 11, 2020
fix docs
patrickvonplaten committed Aug 8, 2020
commit 37a0963637bed7fd011784bc6bb93cd10e761d31
src/transformers/configuration_reformer.py (7 changes: 0 additions & 7 deletions)

@@ -64,11 +64,6 @@ class ReformerConfig(PretrainedConfig):
             A chunk size of 0 means that the feed forward layer is not chunked.
             A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
             For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
-        chunk_size_feed_forward (:obj:`int`, optional, defaults to 0):
-            The chunk size of all feed forward layers in the residual attention blocks.
-            A chunk size of 0 means that the feed forward layer is not chunked.
-            A chunk size of n means that the feed forward layer processes n < sequence_length embeddings at a time.
-            For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .
         eos_token_id (:obj:`int`, optional, defaults to 2):
             The token id for the <EOS> token.
         feed_forward_size (:obj:`int`, optional, defaults to 512):
@@ -147,7 +142,6 @@ def __init__(
         axial_pos_shape=[64, 64],
         axial_pos_embds_dim=[64, 192],
         chunk_size_lm_head=0,
-        chunk_size_feed_forward=0,
         eos_token_id=2,
         feed_forward_size=512,
         hash_seed=None,
@@ -202,5 +196,4 @@ def __init__(
         self.axial_pos_embds_dim = tuple(axial_pos_embds_dim)
         self.axial_norm_std = axial_norm_std
         self.chunk_size_lm_head = chunk_size_lm_head
-        self.chunk_size_feed_forward = chunk_size_feed_forward
         self.attn_layers = attn_layers
src/transformers/configuration_utils.py (5 changes: 5 additions & 0 deletions)

@@ -64,6 +64,11 @@ class PretrainedConfig(object):
             2.
         xla_device (:obj:`bool`, `optional`):
             A flag to indicate if TPU are available or not.
+        chunk_size_feed_forward (:obj:`int`, `optional`, defaults to :obj:`0`):
+            The chunk size of all feed forward layers in the residual attention blocks.
+            A chunk size of :obj:`0` means that the feed forward layer is not chunked.
+            A chunk size of n means that the feed forward layer processes :obj:`n` < sequence_length embeddings at a time.
+            For more information on feed forward chunking, see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .

         Parameters for sequence generation
             - **max_length** (:obj:`int`, `optional`, defaults to 20) -- Maximum length that will be used by
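The docstring added above is the whole user-facing contract, so a minimal sketch may help make it concrete: the feed forward block is applied to chunk_size positions of the sequence at a time and the pieces are concatenated, trading a lower peak memory footprint for more, smaller matrix multiplications. The helper below is illustrative only (chunked_feed_forward is a made-up name, not the transformers implementation):

import torch
import torch.nn as nn

def chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # chunk_size == 0 means "do not chunk", matching the docstring above
    if chunk_size == 0:
        return ff(hidden_states)
    # split along the sequence dimension of a [batch, seq_len, hidden] tensor
    chunks = hidden_states.split(chunk_size, dim=1)
    # a position-wise feed forward gives the same result chunk by chunk
    return torch.cat([ff(chunk) for chunk in chunks], dim=1)

ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(2, 16, 64)
assert torch.allclose(chunked_feed_forward(ff, x, chunk_size=4), ff(x), atol=1e-5)

Because the feed forward block acts on each position independently, the chunked and unchunked outputs match, which is exactly what the common test below checks.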
tests/test_modeling_common.py (8 changes: 2 additions & 6 deletions)

@@ -530,19 +530,15 @@ def test_feed_forward_chunking(self):
             config = copy.deepcopy(original_config)
             model = model_class(config)
             model.to(torch_device)
-
-            if self.model_tester.is_training is False:
-                model.eval()
+            model.eval()

             hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]

             torch.manual_seed(0)
             config.chunk_size_feed_forward = 1
             model = model_class(config)
             model.to(torch_device)
-
-            if self.model_tester.is_training is False:
-                model.eval()
+            model.eval()

             hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
             self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
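As a closing usage note, here is a small sketch under the assumption that the rest of this PR moves the chunk_size_feed_forward handling into the shared PretrainedConfig, which is what the docstring move above implies:

from transformers import ReformerConfig

# Even with the duplicated parameter removed from ReformerConfig in this commit,
# the option can still be passed, since it is picked up by the base config class.
config = ReformerConfig(chunk_size_feed_forward=4)
assert config.chunk_size_feed_forward == 4

config.chunk_size_feed_forward = 0  # 0 (the default) disables chunking again

Model-specific configs such as ReformerConfig no longer need to redeclare the argument; they inherit it.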