
Add TrOCR + VisionEncoderDecoderModel #13874

Merged (38 commits) on Oct 13, 2021

Conversation

@NielsRogge (Contributor) commented on Oct 5, 2021

What does this PR do?

This PR adds the TrOCR models by Microsoft, together with a new VisionEncoderDecoderModel class, which is required to run TrOCR, as the model consists of an image encoder and an autoregressive text decoder. This PR is very similar to #13186; it's just the vision counterpart.

Here's how to use this model:

from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import requests
from PIL import Image

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

# load an example image of a single handwritten text line (IAM database)
url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
pixel_values = processor(image, return_tensors="pt").pixel_values

# autoregressively generate token ids, then decode them into text
generated_ids = model.generate(pixel_values)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])

There's also this Colab notebook for quick inference: https://colab.research.google.com/drive/1qCuqlqc4V9LZhPkxIi_XqCCTQrDhkmHi?usp=sharing

! A big disclaimer: the TrOCR models do not work directly on entire images (full pages, PDFs, etc.): they are trained on single-text-line images. TrOCR is a text recognition model, not a text detection model, so a text detector needs to be applied first; one typically needs both, in sequence, to extract all text from a given image.
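
For illustration, a full pipeline could look like the following minimal sketch, reusing processor, model, and image from the snippet above. The bounding boxes here are made up; in practice they would come from a separate text detection model:

# hypothetical line boxes from an external text detector, as (left, upper, right, lower)
line_boxes = [(10, 20, 400, 60), (10, 70, 380, 110)]

for box in line_boxes:
    line_image = image.crop(box)  # crop out a single text line, as TrOCR expects
    pixel_values = processor(line_image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])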

Important note:

The current design of the existing EncoderDecoderModel/FlaxEncoderDecoderModel is that, if the hidden sizes of the encoder and decoder don't match, a single projection layer is created to project the encoder_hidden_states to the same number of channels as the decoder. However, this is not how TrOCR does it. Instead, the encoder_hidden_states are projected to the decoder's dimension inside each decoder layer, when projecting to keys and values. Therefore, my proposal is to add an attribute called encoder_hidden_size to the config of the decoder which, if specified, tells the VisionEncoderDecoderModel class not to project the encoder hidden states; the attribute is instead used when instantiating the key and value projection layers.

For consistency, we could also add this to the existing EncoderDecoderModel/FlaxEncoderDecoderModel. Also relevant for the FlaxVisionEncoderDecoderModel PR, see #13359.
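
As a minimal sketch of the difference (illustrative only; the names and sizes below are made up, not the actual modeling code):

import torch
from torch import nn

encoder_hidden_size, decoder_hidden_size = 768, 1024  # e.g. a ViT encoder and a larger decoder

# Existing EncoderDecoderModel design: project the encoder output once, up front.
enc_to_dec_proj = nn.Linear(encoder_hidden_size, decoder_hidden_size)

# TrOCR design: keep the encoder output as-is; each decoder layer's
# cross-attention projects keys and values from the encoder width directly.
k_proj = nn.Linear(encoder_hidden_size, decoder_hidden_size)
v_proj = nn.Linear(encoder_hidden_size, decoder_hidden_size)

encoder_hidden_states = torch.randn(1, 197, encoder_hidden_size)  # 196 patches + CLS token
keys = k_proj(encoder_hidden_states)    # (1, 197, 1024)
values = v_proj(encoder_hidden_states)  # (1, 197, 1024)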

To do:

@LysandreJik LysandreJik requested a review from sgugger October 6, 2021 02:44
@@ -476,6 +476,7 @@ class PreTrainedModel
     def forward(
         self,
         pixel_values=None,
+        attention_mask=None,
Contributor (commenting on the diff above):
This isn't used anywhere, no? Is it just here because attention_mask is often passed in generate()?

@patrickvonplaten (Contributor) left a comment:

PR looks great to me! Really clean implementation that fits well with the current design of Transformers IMO - it should enable lots of image captioning tasks from pretrained ViT + BERT :-)
Also very nice that you didn't have to adapt generate() to make it work!

Just have the following points:

  • Do we need this attention_mask here: Add TrOCR + VisionEncoderDecoderModel #13874 (comment)? I didn't dive too deep into the code, but if we need it just because generate(...) passes it, then it's a bit hacky and we should try to avoid it. Happy to see how we can change generate for this, or at least add a **kwargs argument that logs that the input is not used.
  • I feel quite strongly about not calling it encoder_hidden_size, but rather cross_attention_hidden_size. To a user who just looks at configuration_utils.py, the name encoder_hidden_size is not at all related to encoder-decoder architectures. Can we maybe change that?
  • Can we add one slow integration test with a real model (maybe the one in your notebook)? See the sketch after this comment.

Overall amazing work! Think we can merge this in a couple of days :-)
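
Regarding the last point, such a slow integration test could take roughly the following shape. This is only a sketch: it reuses the checkpoint from the example above, and the final assertion is deliberately loose rather than pinning exact outputs, as the real test would do:

import unittest

import requests
from PIL import Image

from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from transformers.testing_utils import require_torch, require_vision, slow


@require_torch
@require_vision
class TrOCRModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference_handwritten(self):
        processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
        model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

        # single text-line image from the IAM handwriting database
        url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
        image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
        pixel_values = processor(image, return_tensors="pt").pixel_values

        generated_ids = model.generate(pixel_values)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # loose check: generation ran and produced non-empty text
        self.assertGreater(len(generated_text), 0)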

@sgugger (Collaborator) left a comment:

Thanks a lot for adding this!

emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
if embedding_dim % 2 == 1:
# zero pad
Collaborator:

This comment should be expanded or removed.

@LysandreJik (Member) left a comment:

This looks great. Thanks for working on this, @NielsRogge!

Comment on lines +114 to +115
Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
description in Section 3.5 of "Attention Is All You Need".
Member:
How does it differ?

@NielsRogge (Contributor, Author) replied on Oct 12, 2021:

@patrickvonplaten knows, I just copied his implementation

@NielsRogge (Contributor, Author):

The same docstring is used in Speech2Text2
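
For context, the difference is only a channel permutation: Section 3.5 of the paper interleaves sin and cos per position (PE(pos, 2i) = sin(...), PE(pos, 2i+1) = cos(...)), whereas the tensor2tensor-style code above concatenates all sines first, then all cosines. A quick sketch with two example frequencies:

import torch

pos = torch.arange(4, dtype=torch.float).unsqueeze(1)  # positions 0..3
freq = torch.tensor([1.0, 0.01])                       # two example frequencies
angles = pos * freq

# paper: sin/cos interleaved per frequency
interleaved = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1).flatten(1)
# tensor2tensor / the code above: all sines, then all cosines
concatenated = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

# same values, different channel order
assert sorted(interleaved[0].tolist()) == sorted(concatenated[0].tolist())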
