Add TrOCR + VisionEncoderDecoderModel #13874

Merged Oct 13, 2021 (38 commits)

Commits:
26b79a6 First draft (NielsRogge, Sep 25, 2021)
eebfafa Update self-attention of RoBERTa as proposition (NielsRogge, Sep 29, 2021)
adf1cb3 Improve conversion script (NielsRogge, Sep 30, 2021)
be7ec13 Add TrOCR decoder-only model (NielsRogge, Sep 30, 2021)
1ec88d5 More improvements (NielsRogge, Sep 30, 2021)
7ded83b Make forward pass with pretrained weights work (NielsRogge, Sep 30, 2021)
9b4189f More improvements (NielsRogge, Sep 30, 2021)
9b6f68b Some more improvements (NielsRogge, Sep 30, 2021)
1127064 More improvements (NielsRogge, Sep 30, 2021)
ac5440d Make conversion work (NielsRogge, Oct 3, 2021)
6c5d947 Clean up print statements (NielsRogge, Oct 4, 2021)
b54e32e Add documentation, processor (NielsRogge, Oct 4, 2021)
d47b5f1 Add test files (NielsRogge, Oct 4, 2021)
b1a85a6 Small improvements (NielsRogge, Oct 4, 2021)
76f3a66 Some more improvements (NielsRogge, Oct 4, 2021)
1d8ed6b Make fix-copies, improve docs (NielsRogge, Oct 4, 2021)
2c4337e Make all vision encoder decoder model tests pass (NielsRogge, Oct 4, 2021)
cc4eb2c Make conversion script support other models (NielsRogge, Oct 5, 2021)
170f905 Update URL for OCR image (NielsRogge, Oct 5, 2021)
28bdf18 Update conversion script (NielsRogge, Oct 5, 2021)
890dd70 Fix style & quality (NielsRogge, Oct 5, 2021)
15f797d Add support for the large-printed model (NielsRogge, Oct 5, 2021)
f490e3a Fix some issues (NielsRogge, Oct 6, 2021)
2230eb0 Add print statement for debugging (NielsRogge, Oct 6, 2021)
f8ad61d Add print statements for debugging (NielsRogge, Oct 6, 2021)
e5f6983 Make possible fix for sinusoidal embedding (NielsRogge, Oct 6, 2021)
643c21d Further debugging (NielsRogge, Oct 6, 2021)
b7c5bf8 Potential fix v2 (NielsRogge, Oct 6, 2021)
6c4435d Add more print statements for debugging (NielsRogge, Oct 6, 2021)
1a6825f Add more print statements for debugging (NielsRogge, Oct 6, 2021)
667b03c Debug more (NielsRogge, Oct 6, 2021)
bf49483 Comment out print statements (NielsRogge, Oct 6, 2021)
f0c8b59 Make conversion of large printed model possible, address review comments (NielsRogge, Oct 8, 2021)
6f1d7fa Make it possible to convert the stage1 checkpoints (NielsRogge, Oct 8, 2021)
c38904b Clean up code, apply suggestions from code review (NielsRogge, Oct 8, 2021)
6e6b947 Apply suggestions from code review, use Microsoft models in tests (NielsRogge, Oct 11, 2021)
b1fedab Rename encoder_hidden_size to cross_attention_hidden_size (NielsRogge, Oct 11, 2021)
f3d9e94 Improve docs (NielsRogge, Oct 12, 2021)

Changes from 1 commit:
Clean up code, apply suggestions from code review
NielsRogge committed Oct 8, 2021
commit c38904bff8c70dcf8f1badc6d56d4d8ffef4a8c8
docs/source/model_doc/visionencoderdecoder.rst (2 changes: 1 addition & 1 deletion)
@@ -15,7 +15,7 @@ Vision Encoder Decoder Models

The :class:`~transformers.VisionEncoderDecoderModel` can be used to initialize an image-to-text-sequence model with any
pretrained vision autoencoding model as the encoder (*e.g.* :doc:`ViT <vit>`, :doc:`BEiT <beit>`, :doc:`DeiT <deit>`)
-and any pretrained language model as the decoder (*e.g.* :doc: `RoBERTa <roberta>`, :doc: `GPT2 <gpt2>`, :doc: `BERT
+and any pretrained language model as the decoder (*e.g.* :doc:`RoBERTa <roberta>`, :doc:`GPT2 <gpt2>`, :doc:`BERT
<bert>`).

The effectiveness of initializing image-to-text-sequence models with pretrained checkpoints has been shown in (for
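As a quick illustration of the docstring above, a minimal sketch of pairing a pretrained vision encoder with a pretrained text decoder (the checkpoint names are illustrative and not part of this diff):

from transformers import VisionEncoderDecoderModel

# Pair a pretrained ViT encoder with a pretrained BERT decoder; cross-attention
# layers are added to the decoder as part of the composition.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k",  # vision encoder
    "bert-base-uncased",                  # text decoder
)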
src/transformers/models/trocr/configuration_trocr.py (1 change: 0 additions & 1 deletion)
@@ -65,7 +65,6 @@ class TrOCRConfig(PretrainedConfig):
The dropout ratio for classifier.
init_std (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-https://arxiv.org/abs/1909.11556>`__ for more details.
decoder_layerdrop (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the decoder. See the `LayerDrop paper
<https://arxiv.org/abs/1909.11556>`__ for more details.
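For context, a hedged sketch of instantiating the config with the two fields documented above (the values shown are simply the documented defaults):

from transformers import TrOCRConfig

config = TrOCRConfig(
    init_std=0.02,          # std of the truncated-normal initializer for all weight matrices
    decoder_layerdrop=0.0,  # probability of skipping an entire decoder layer (LayerDrop)
)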
src/transformers/models/trocr/modeling_trocr.py (33 changes: 4 additions & 29 deletions)
@@ -103,10 +103,6 @@ class TrOCRSinusoidalPositionalEmbedding(nn.Module):
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
super().__init__()
self.offset = 2

# print("Embedding dim:", embedding_dim)
# print("Padding idx:", padding_idx)

self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.weights = self.get_embedding(num_positions, embedding_dim, padding_idx)
@@ -129,9 +125,6 @@ def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional
if padding_idx is not None:
emb[padding_idx, :] = 0

# print("Shape of emb in get_embedding:", emb.shape)
# print("First elements of emb in get_embedding:", emb[:3, :3])

return emb

@torch.no_grad()
@@ -144,20 +137,13 @@ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):

# expand embeddings if needed
max_pos = self.padding_idx + 1 + seq_len

# print("Seq_len:", seq_len)
# print("Max_pos:", max_pos)

if self.weights is None or max_pos > self.weights.size(0):
# recompute/expand embeddings if needed
self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx)
self.weights = self.weights.to(self._float_tensor)

x = self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()

# print("Shape of position embeddings:", x.shape)
# print("First elements of position embeddings:", x[0, 0, :3])

return x

def create_position_ids_from_input_ids(
@@ -199,9 +185,10 @@ def __init__(
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
-assert (
-    self.head_dim * num_heads == self.embed_dim
-), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
+if not (self.head_dim * num_heads == self.embed_dim):
+    raise ValueError(
+        f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
+    )
self.scaling = self.head_dim ** -0.5
self.is_decoder = is_decoder

@@ -651,29 +638,17 @@ def forward(
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

# print("Input_ids:", input_ids)
# print("Weights of embed tokens:", self.embed_tokens.weight[:3,:3])

if inputs_embeds is None:
-# y = self.embed_tokens(input_ids)
-# print("First elements of embeddings before embed_scale:", y[0, :3, :3])
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

# print("First elements of embeddings after embed_scale, before position embeddings:", inputs_embeds[0, :3, :3])
# print("Embed scale:", self.embed_scale)

if self.config.use_learned_position_embeddings:
embed_pos = self.embed_positions(input_shape, past_key_values_length=past_key_values_length)
else:
embed_pos = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)

hidden_states = inputs_embeds + embed_pos

# print("Shape of embeddings after position embeddings:", hidden_states.shape)
# print("First elements of embeddings after position embeddings:", hidden_states[0, :3, :3])

if self.layernorm_embedding is not None:
# print("Adding layernorm to the embeddings")
hidden_states = self.layernorm_embedding(hidden_states)

hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
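For reference, a standalone sketch of the fairseq-style sinusoidal table that get_embedding builds above (a simplification that assumes an even embedding_dim; first half sine, second half cosine):

import math
import torch

def sinusoidal_table(num_embeddings: int, embedding_dim: int) -> torch.Tensor:
    half_dim = embedding_dim // 2
    # Geometric progression of inverse frequencies, as in the module above.
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float) * -(math.log(10000) / (half_dim - 1)))
    angles = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * freqs.unsqueeze(0)
    # Each row: sin for the first half of the features, cos for the second half.
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)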
src/transformers/models/trocr/processing_trocr.py (4 changes: 2 additions & 2 deletions)
@@ -33,9 +33,9 @@ class TrOCRProcessor:
:meth:`~transformers.TrOCRProcessor.decode` for more information.

Args:
-feature_extractor (:obj:`AutoFeatureExtractor`):
+feature_extractor (:class:`~transformers.AutoFeatureExtractor`):
An instance of :class:`~transformers.AutoFeatureExtractor`. The feature extractor is a required input.
-tokenizer (:obj:`RobertaTokenizer`):
+tokenizer (:class:`~transformers.RobertaTokenizer`):
An instance of :class:`~transformers.RobertaTokenizer`. The tokenizer is a required input.
"""

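A minimal usage sketch of the processor documented above, mirroring how the conversion script later in this diff builds it (the 384 image size is an assumption matching the TrOCR checkpoints):

from transformers import TrOCRProcessor, ViTFeatureExtractor, RobertaTokenizer

# Bundle an image feature extractor and a text tokenizer into one object.
feature_extractor = ViTFeatureExtractor(size=384)
tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
processor = TrOCRProcessor(feature_extractor, tokenizer)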
src/transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py
@@ -76,7 +76,8 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
if "encoder" not in kwargs or "decoder" not in kwargs:
raise ValueError(
f"A configuraton of type {self.model_type} cannot be instantiated because not both `encoder` and `decoder` sub-configurations are passed, but only {kwargs}"
f"A configuraton of type {self.model_type} cannot be instantiated because "
f"not both `encoder` and `decoder` sub-configurations are passed, but only {kwargs}"
)

encoder_config = kwargs.pop("encoder")
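To avoid the ValueError above, both sub-configurations must be supplied; a hedged sketch using the composite-config helper (assuming the from_encoder_decoder_configs classmethod that composite configs in transformers expose):

from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig

# Both `encoder` and `decoder` sub-configurations are passed explicitly.
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
    encoder_config=ViTConfig(),
    decoder_config=BertConfig(),
)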
src/transformers/models/trocr/convert_trocr_unilm_to_pytorch.py
@@ -178,39 +178,21 @@ def convert_tr_ocr_checkpoint(checkpoint_url, pytorch_dump_folder_path):
else:
state_dict[key] = val

print("Embed tokens of model before loading state dict:")
print(model.decoder.model.decoder.embed_tokens.weight[:3, :3])

print("Shape of embed tokens in state dict:", state_dict["decoder.model.decoder.embed_tokens.weight"].shape)
print("Embed tokens in state dict:")
print(state_dict["decoder.model.decoder.embed_tokens.weight"][:3, :3])

# load state dict
model.load_state_dict(state_dict)

print("Shape of embed tokens in model:", model.decoder.model.decoder.embed_tokens.weight.shape)
print("Embed tokens of model:")
print(model.decoder.model.decoder.embed_tokens.weight[:3, :3])

# Check outputs on an image
feature_extractor = ViTFeatureExtractor(size=encoder_config.image_size)
tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
processor = TrOCRProcessor(feature_extractor, tokenizer)

pixel_values = processor(images=prepare_img(checkpoint_url), return_tensors="pt").pixel_values

print("First elements of pixel values:", pixel_values[0, 0, :3, :3])

-# generated_ids = model.generate(input_ids=pixel_values, num_beams=5)
-# print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])

# verify logits
decoder_input_ids = torch.tensor([[model.config.decoder.decoder_start_token_id]])
outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
logits = outputs.logits

print("First elements of logits:", logits[0, 0, :10])

expected_shape = torch.Size([1, 1, 50265])
if "trocr-base-handwritten" in checkpoint_url:
expected_slice = torch.tensor(
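The commented-out generate call removed above hints at the end goal of the conversion; a short sketch reconstructed from those removed lines (`model`, `processor`, and `pixel_values` are the objects built earlier in this script):

# Reconstructed from the commented-out snippet removed above: decode the
# converted model's output for a quick eyeball check.
generated_ids = model.generate(input_ids=pixel_values, num_beams=5)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])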
src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
@@ -219,8 +219,7 @@ def set_output_embeddings(self, new_embeddings):

@classmethod
def from_pretrained(cls, *args, **kwargs):
-# At the moment fast initialization is not supported
-# for composite models
+# At the moment fast initialization is not supported for composite models
if kwargs.get("_fast_init", False):
logger.warning(
"Fast initialization is currently not supported for VisionEncoderDecoderModel. Falling back to slow intialization..."
@@ -317,7 +316,8 @@ def from_encoder_decoder_pretrained(
if encoder is None:
if encoder_pretrained_model_name_or_path is None:
raise ValueError(
f"No `encoder_model` is passed to kwargs: {kwargs_encoder}. In this case make sure that `encoder_pretrained_model_name_or_path` defined"
f"No `encoder_model` is passed to kwargs: {kwargs_encoder}. "
f"In this case make sure that `encoder_pretrained_model_name_or_path` defined"
)

if "config" not in kwargs_encoder:
@@ -360,7 +360,8 @@ def from_encoder_decoder_pretrained(
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder."
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config`"
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` "
f"to `.from_encoder_decoder_pretrained(...)`"
)

decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
@@ -445,11 +446,6 @@ def forward(
):
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

-# compute correct encoder attention mask
-# if attention_mask is not None:
-#     encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
-#         encoder_hidden_states.shape[1], attention_mask
-#     )
-# else:
encoder_attention_mask = None

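The enc_to_dec_proj call above exists to bridge mismatched hidden sizes between encoder and decoder; a minimal self-contained sketch of that design choice (the sizes are illustrative):

import torch
from torch import nn

# If encoder and decoder hidden sizes differ, a linear projection maps encoder
# states into the decoder's cross-attention space.
encoder_hidden_size, decoder_hidden_size = 768, 1024  # illustrative sizes
enc_to_dec_proj = nn.Linear(encoder_hidden_size, decoder_hidden_size)

encoder_hidden_states = torch.randn(1, 197, encoder_hidden_size)  # e.g. 196 ViT patches + [CLS]
projected = enc_to_dec_proj(encoder_hidden_states)  # shape: (1, 197, 1024)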
tests/test_modeling_trocr.py (25 changes: 5 additions & 20 deletions)
@@ -108,12 +108,7 @@ def prepare_config_and_inputs(self):
max_position_embeddings=self.max_position_embeddings,
)

-return (
-    config,
-    input_ids,
-    attention_mask,
-    lm_labels,
-)
+return (config, input_ids, attention_mask, lm_labels)

def create_and_check_decoder_model_past(
self,
@@ -156,17 +151,9 @@ def create_and_check_decoder_model_past(

def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
-(
-    config,
-    input_ids,
-    attention_mask,
-    lm_labels,
-) = config_and_inputs
-
-inputs_dict = {
-    "input_ids": input_ids,
-    "attention_mask": attention_mask,
-}
+config, input_ids, attention_mask, lm_labels = config_and_inputs
+
+inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
return config, inputs_dict


@@ -176,9 +163,7 @@ class TrOCRStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_generative_model_classes = (TrOCRForCausalLM,) if is_torch_available() else ()
test_pruning = False

-def setUp(
-    self,
-):
+def setUp(self):
self.model_tester = TrOCRStandaloneDecoderModelTester(self, is_training=False)
self.config_tester = ConfigTester(self, config_class=TrOCRConfig)

tests/test_modeling_vision_encoder_decoder.py (20 changes: 4 additions & 16 deletions)
@@ -17,6 +17,8 @@
import tempfile
import unittest

+from datasets import load_dataset

from transformers.file_utils import cached_property, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision, slow, torch_device

@@ -459,11 +461,7 @@ def prepare_config_and_inputs(self):
deit_model_tester = DeiTModelTester(self)
encoder_config_and_inputs = deit_model_tester.prepare_config_and_inputs()
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
-(
-    config,
-    pixel_values,
-    _,
-) = encoder_config_and_inputs
+config, pixel_values, _ = encoder_config_and_inputs
input_mask = None # TODO add once attention_mask is supported for vision models
(
decoder_config,
@@ -580,11 +578,7 @@ def prepare_config_and_inputs(self):
)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
-(
-    config,
-    pixel_values,
-    _,
-) = encoder_config_and_inputs
+config, pixel_values, _ = encoder_config_and_inputs
input_mask = None # TODO add once attention_mask is supported for vision models
(decoder_config, decoder_input_ids, decoder_attention_mask, _) = decoder_config_and_inputs

@@ -619,10 +613,7 @@ def test_inference_handwritten(self):
# TODO update to microsoft
model = VisionEncoderDecoderModel.from_pretrained("nielsr/trocr-base-handwritten").to(torch_device)

-from datasets import load_dataset
-
ds = load_dataset("hf-internal-testing/fixtures_ocr", split="test")
image = Image.open(ds[0]["file"]).convert("RGB")

processor = self.default_processor
@@ -648,10 +639,7 @@ def test_inference_printed(self):
# TODO update to microsoft
model = VisionEncoderDecoderModel.from_pretrained("nielsr/trocr-base-printed").to(torch_device)

-from datasets import load_dataset
-
ds = load_dataset("hf-internal-testing/fixtures_ocr", split="test")

image = Image.open(ds[1]["file"]).convert("RGB")

processor = self.default_processor
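An end-to-end sketch of the inference these tests exercise (the nielsr/... checkpoint mirrors the placeholder used above, which the TODOs note will move to microsoft/...; loading the processor from the same checkpoint and passing pixel values straight to generate are assumptions, not part of this diff):

import torch
from datasets import load_dataset
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Same OCR fixture the tests load.
ds = load_dataset("hf-internal-testing/fixtures_ocr", split="test")
image = Image.open(ds[0]["file"]).convert("RGB")

processor = TrOCRProcessor.from_pretrained("nielsr/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("nielsr/trocr-base-handwritten")

pixel_values = processor(images=image, return_tensors="pt").pixel_values
with torch.no_grad():
    generated_ids = model.generate(pixel_values)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])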