Conversation

@LysandreJik (Collaborator) commented on Aug 13, 2020

This PR aims to add the ModelTesterMixin mixin to LXMERT, to ensure that it behaves correctly. In doing so, there have been a few bug fixes, as well as a slight refactor of LXMERT's UI. This affects only the user-facing methods of the user-facing models (LxmertModel, LxmertForPretraining, LxmertForQuestionAnswering).

Some things to note:

  • The UI is extremely important. It needs to align with the rest of the library, or users will be lost and will either open issues on a regular basis to understand it, or simply not use the model at all.
  • The UI is the most important component of the model, as it cannot be easily changed over time. Once the UI is set, introducing any changes to it results in breaking changes, which are a tremendous pain for users. The internals of the models may change over time, however, as long as the resulting behavior doesn't change.

The noteworthy changes are detailed in comments below for an easier review.

Only PyTorch for now; if you agree with these changes I'll do TensorFlow as well.

@LysandreJik requested a review from eltoto1219 on August 13, 2020 at 07:14
Comment on lines +1 to +8
LXMERT
----------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~




This adds LXMERT to the documentation. I've left the Overview blank for now.

Comment on lines +179 to +183
self.num_hidden_layers = {
"vision": r_layers,
"cross_encoder": x_layers,
"language": l_layers
}

This is very different from what we usually do, but as it's the first multi-transformer architecture, we can have a bit more freedom. I think this approach makes sense, but we'll have to discuss it with other team members before the final merge into master, if we decide to merge this.
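For illustration, here's a minimal usage sketch of the per-module layer counts. The constructor argument names come from the diff above; the concrete values and the top-level LxmertConfig import are assumptions:

from transformers import LxmertConfig

# Hypothetical values; l_layers / x_layers / r_layers follow the diff above.
config = LxmertConfig(l_layers=9, x_layers=5, r_layers=5)

# Each sub-encoder's depth is looked up by name instead of a single integer:
print(config.num_hidden_layers["language"])       # 9
print(config.num_hidden_layers["cross_encoder"])  # 5
print(config.num_hidden_layers["vision"])         # 5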

Comment on lines +37 to +39
_CONFIG_FOR_DOC = "LxmertConfig"
_TOKENIZER_FOR_DOC = "LxmertTokenizer"


Our documentation format has slightly changed; I've updated it here.

return gelu(x)

@dataclass
class LxmertModelOutput(ModelOutput):

Moved the model outputs into this file, as they're model-specific. Renamed and re-ordered the parameters.
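To make the review easier, here's a minimal sketch of the pattern these outputs follow. The field names mirror the documented args below; the ModelOutput import path is an assumption about where the base class lives:

from dataclasses import dataclass
from typing import Optional, Tuple

import torch

from transformers.file_utils import ModelOutput  # import path assumed


@dataclass
class LxmertModelOutput(ModelOutput):
    # Field declaration order is the order .to_tuple() follows, which is why
    # the ordering rules discussed below matter; None fields are dropped.
    language_output: Optional[torch.FloatTensor] = None
    vision_output: Optional[torch.FloatTensor] = None
    pooled_output: Optional[torch.FloatTensor] = None
    language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    language_attentions: Optional[Tuple[torch.FloatTensor]] = None
    vision_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None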

Comment on lines +61 to +89
language_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the language encoder.
vision_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the visual encoder.
pooled_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification, CLS, token)
further processed by a Linear layer and a Tanh activation function. The Linear
language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""

Several things we have to respect with these model outputs:

  • There should be absolutely no difference in behavior between the traditional tuple output and this model output when using .to_tuple(). This means that:
    model(**inputs) == model(**inputs, return_dict=True).to_tuple()
    We're adding a test for this to the common tests (see the sketch right after this list).
  • Model outputs should therefore have the same order as the regular tuple outputs.
  • I've re-ordered them as follows: CATEGORY_1[language, vision, misc], CATEGORY_2[language, vision, misc], etc.
  • The documented args should have the same order as the args defined in the model output.
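Here's a rough sketch of that equivalence check (model and inputs are placeholders, the comparison is deliberately simplified, and return_dict is assumed to default to False):

import torch


def assert_tuple_and_return_dict_match(model, inputs):
    model.eval()
    with torch.no_grad():
        tuple_outputs = model(**inputs)                              # plain tuple output
        dict_outputs = model(**inputs, return_dict=True).to_tuple()

    assert len(tuple_outputs) == len(dict_outputs)
    for tuple_value, dict_value in zip(tuple_outputs, dict_outputs):
        if isinstance(tuple_value, (tuple, list)):
            # Nested outputs such as per-layer hidden states or attentions.
            for t, d in zip(tuple_value, dict_value):
                assert torch.allclose(t, d)
        else:
            assert torch.allclose(tuple_value, dict_value)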

Comment on lines -998 to +1122
-    total_loss = 0.0
+    total_loss = torch.tensor(0.0, device=device)

We deal only in tensors, or tuples of tensors, as outputs.
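A toy illustration (not the PR's code) of one way the difference shows up: if every optional loss term is skipped, a float-initialized accumulator would leak a plain Python float into the output tuple, while a tensor-initialized one keeps the tensor-only contract.

import torch

device = torch.device("cpu")

# Float accumulator: with no loss terms added, a bare Python float would end
# up as the first element of the output tuple.
total_loss_float = 0.0

# Tensor accumulator: the returned loss is a torch.Tensor no matter how many
# terms were added to it.
total_loss_tensor = torch.tensor(0.0, device=device)

assert not isinstance(total_loss_float, torch.Tensor)
assert isinstance(total_loss_tensor, torch.Tensor)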

Comment on lines -1072 to +1190
-        self.loss = CrossEntropyLoss(ignore_index=-100)
+        self.loss = CrossEntropyLoss()

The ignore_index is -100 by default, so passing it explicitly was redundant.
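A quick sanity check (not part of the diff) that the two constructions behave identically:

import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(4, 10)
labels = torch.tensor([1, -100, 3, -100])  # -100 positions are ignored either way

default_loss = CrossEntropyLoss()(logits, labels)
explicit_loss = CrossEntropyLoss(ignore_index=-100)(logits, labels)
assert torch.allclose(default_loss, explicit_loss)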

Comment on lines -249 to +258
-        self.parent.assertEqual(result.last_hidden_state_l.shape, (self.batch_size, self.seq_length, self.hidden_size))
+        self.parent.assertEqual(result.language_output.shape, (self.batch_size, self.seq_length, self.hidden_size))
         self.parent.assertEqual(
-            result.last_hidden_state_v.shape, (self.batch_size, self.num_visual_features, self.hidden_size)
+            result.vision_output.shape, (self.batch_size, self.num_visual_features, self.hidden_size)
         )
-        self.parent.assertEqual(result.pooled_output_x_encoder.shape, (self.batch_size, self.hidden_size))
+        self.parent.assertEqual(result.pooled_output.shape, (self.batch_size, self.hidden_size))

I hope you'll agree with me that this output is cleaner and simpler for users to understand :)

Comment on lines -521 to +516
-class LxmertModelTest(unittest.TestCase):
+class LxmertModelTest(ModelTesterMixin, unittest.TestCase):

boom!

model = LxmertModel.from_pretrained(model_name)
self.assertIsNotNone(model)

def test_attention_outputs(self):

Needed to re-implement that test here, as well as the hidden-states test, since the behavior is different from other, single-transformer models.
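For reference, a rough sketch of what the re-implemented check verifies (model, inputs and config are placeholders; the per-modality attribute names follow the model output documented above, and the config attribute names are assumptions):

import torch


def check_lxmert_attention_outputs(model, inputs, config):
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True, return_dict=True)

    # One attention tensor per layer, per modality.
    assert len(outputs.language_attentions) == config.l_layers
    assert len(outputs.vision_attentions) == config.r_layers
    assert len(outputs.cross_encoder_attentions) == config.x_layers

    # Each tensor has the usual (batch_size, num_heads, query_len, key_len) shape.
    for attn in outputs.language_attentions:
        assert attn.dim() == 4
        assert attn.size(1) == config.num_attention_heads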

@eltoto1219 merged this pull request into lxmert_model on Aug 13, 2020
eltoto1219 pushed a commit that referenced this pull request on Sep 11, 2020
* neFLOs calculation, logging, and reloading (#1)

* testing distributed consecutive batches

* fixed AttributeError from DataParallel

* removed verbosity

* rotate with use_mtime=True

* removed print

* fixed interaction with gradient accumulation

* indent formatting

* distributed neflo counting

* fixed typo

* fixed typo

* mean distributed losses

* exporting log history

* moved a few functions

* floating_point_ops clarification for transformers with parameter-reuse

* code quality

* double import

* made flo estimation more task-agnostic

* only logging flos if computed

* code quality

* unused import

* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Sylvain review

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* black

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>