Add TrOCR + VisionEncoderDecoderModel #13874

Merged Oct 13, 2021 (38 commits)

Changes from 1 commit

Commits
26b79a6
First draft
NielsRogge Sep 25, 2021
eebfafa
Update self-attention of RoBERTa as proposition
NielsRogge Sep 29, 2021
adf1cb3
Improve conversion script
NielsRogge Sep 30, 2021
be7ec13
Add TrOCR decoder-only model
NielsRogge Sep 30, 2021
1ec88d5
More improvements
NielsRogge Sep 30, 2021
7ded83b
Make forward pass with pretrained weights work
NielsRogge Sep 30, 2021
9b4189f
More improvements
NielsRogge Sep 30, 2021
9b6f68b
Some more improvements
NielsRogge Sep 30, 2021
1127064
More improvements
NielsRogge Sep 30, 2021
ac5440d
Make conversion work
NielsRogge Oct 3, 2021
6c5d947
Clean up print statements
NielsRogge Oct 4, 2021
b54e32e
Add documentation, processor
NielsRogge Oct 4, 2021
d47b5f1
Add test files
NielsRogge Oct 4, 2021
b1a85a6
Small improvements
NielsRogge Oct 4, 2021
76f3a66
Some more improvements
NielsRogge Oct 4, 2021
1d8ed6b
Make fix-copies, improve docs
NielsRogge Oct 4, 2021
2c4337e
Make all vision encoder decoder model tests pass
NielsRogge Oct 4, 2021
cc4eb2c
Make conversion script support other models
NielsRogge Oct 5, 2021
170f905
Update URL for OCR image
NielsRogge Oct 5, 2021
28bdf18
Update conversion script
NielsRogge Oct 5, 2021
890dd70
Fix style & quality
NielsRogge Oct 5, 2021
15f797d
Add support for the large-printed model
NielsRogge Oct 5, 2021
f490e3a
Fix some issues
NielsRogge Oct 6, 2021
2230eb0
Add print statement for debugging
NielsRogge Oct 6, 2021
f8ad61d
Add print statements for debugging
NielsRogge Oct 6, 2021
e5f6983
Make possible fix for sinusoidal embedding
NielsRogge Oct 6, 2021
643c21d
Further debugging
NielsRogge Oct 6, 2021
b7c5bf8
Potential fix v2
NielsRogge Oct 6, 2021
6c4435d
Add more print statements for debugging
NielsRogge Oct 6, 2021
1a6825f
Add more print statements for debugging
NielsRogge Oct 6, 2021
667b03c
Debug more
NielsRogge Oct 6, 2021
bf49483
Comment out print statements
NielsRogge Oct 6, 2021
f0c8b59
Make conversion of large printed model possible, address review comments
NielsRogge Oct 8, 2021
6f1d7fa
Make it possible to convert the stage1 checkpoints
NielsRogge Oct 8, 2021
c38904b
Clean up code, apply suggestions from code review
NielsRogge Oct 8, 2021
6e6b947
Apply suggestions from code review, use Microsoft models in tests
NielsRogge Oct 11, 2021
b1fedab
Rename encoder_hidden_size to cross_attention_hidden_size
NielsRogge Oct 11, 2021
f3d9e94
Improve docs
NielsRogge Oct 12, 2021
Make all vision encoder decoder model tests pass
NielsRogge committed Oct 4, 2021
commit 2c4337eb14bada93dd6c57aab274c41c95bfce13
13 changes: 10 additions & 3 deletions src/transformers/__init__.py
@@ -1176,7 +1176,9 @@
"load_tf_weights_in_transfo_xl",
]
)
-_import_structure["models.trocr"].extend(["TrOCRForCausalLM", "TrOCRPreTrainedModel"])
+_import_structure["models.trocr"].extend(
+    ["TROCR_PRETRAINED_MODEL_ARCHIVE_LIST", "TrOCRForCausalLM", "TrOCRPreTrainedModel"]
+)
_import_structure["models.vision_encoder_decoder"].extend(["VisionEncoderDecoderModel"])
_import_structure["models.visual_bert"].extend(
[
@@ -2096,7 +2098,7 @@
TransfoXLCorpus,
TransfoXLTokenizer,
)
-from .models.trocr import TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP, TrOCRConfig
+from .models.trocr import TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP, TrOCRConfig, TrOCRProcessor
from .models.vision_encoder_decoder import VisionEncoderDecoderConfig
from .models.visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
from .models.vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
@@ -2854,7 +2856,12 @@
TransfoXLPreTrainedModel,
load_tf_weights_in_transfo_xl,
)
-from .models.trocr import TROCR_PRETRAINED_MODEL_ARCHIVE_LIST, TrOCRModel, TrOCRPreTrainedModel, TrOCRProcessor
+from .models.trocr import (
+    TROCR_PRETRAINED_MODEL_ARCHIVE_LIST,
+    TrOCRForCausalLM,
+    TrOCRPreTrainedModel,
+    TrOCRProcessor,
+)
from .models.vision_encoder_decoder import VisionEncoderDecoderModel
from .models.visual_bert import (
VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
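With these exports in place, the new classes become importable from the top-level transformers namespace. A minimal inference sketch, assuming a converted checkpoint published as microsoft/trocr-base-handwritten and the processor/model APIs added in this PR:

import requests
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Assumption: a converted TrOCR checkpoint is available under this name.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

# One text-line image from the IAM database, the same kind of input as prepare_img() below.
url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02-00.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

# The processor pairs a feature extractor (image -> pixel_values) with a tokenizer (ids -> text).
pixel_values = processor(images=image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])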
1 change: 1 addition & 0 deletions src/transformers/models/deit/modeling_deit.py
@@ -476,6 +476,7 @@
def forward(
self,
pixel_values=None,
+attention_mask=None,
A contributor commented on the added attention_mask argument:
This isn't used anywhere, no? Is it just here since attention_mask is often passed in generate()?

head_mask=None,
output_attentions=None,
output_hidden_states=None,
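For context on that question: generate() passes model kwargs such as attention_mask straight through to the encoder of an encoder-decoder model, so a vision encoder whose forward() rejects the argument would fail. A hypothetical illustration of the failure mode (toy functions, not the actual modeling code):

def forward_without_mask(pixel_values=None, head_mask=None):
    return pixel_values

def forward_with_mask(pixel_values=None, attention_mask=None, head_mask=None):
    # the mask is accepted but simply ignored, matching the change above
    return pixel_values

encoder_kwargs = {"pixel_values": [[0.0]], "attention_mask": None}
forward_with_mask(**encoder_kwargs)       # fine: the extra kwarg is absorbed
# forward_without_mask(**encoder_kwargs)  # TypeError: unexpected keyword argument 'attention_mask'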
@@ -110,7 +110,7 @@ def prepare_img():
url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02-00.jpg" # industry
# url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02-12.jpg" # have
# url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02-10.jpg" # let
url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg" #
# url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg" #
# url = "https://fki.tic.heia-fr.ch/static/img/a01-122.jpg"
im = Image.open(requests.get(url, stream=True).raw).convert("RGB")
return im
@@ -144,19 +144,20 @@ def convert_tr_ocr_checkpoint(checkpoint_url, pytorch_dump_folder_path):
model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
model.eval()

-# load state_dict of original model, remove and rename some keys
+# load state_dict of original model, rename some keys
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)["model"]

rename_keys = create_rename_keys(encoder_config, decoder_config)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
read_in_q_k_v(state_dict, encoder_config)

+# remove parameters we don't need
del state_dict["encoder.deit.head.weight"]
del state_dict["encoder.deit.head.bias"]
del state_dict["decoder.version"]

-# add prefix to decoder keys (v2)
+# add prefix to decoder keys
for key, val in state_dict.copy().items():
val = state_dict.pop(key)
if key.startswith("decoder") and "output_projection" not in key:
@@ -176,13 +177,17 @@ def convert_tr_ocr_checkpoint(checkpoint_url, pytorch_dump_folder_path):
generated_ids = model.generate(input_ids=pixel_values, num_beams=5)
print(tokenizer.decode(generated_ids[0]))

-# forward pass
-# decoder_input_ids = torch.tensor([[model.config.decoder.decoder_start_token_id]])
-# outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
-# logits = outputs.logits
-# TODO verify logits
-# print("Shape of logits:", logits.shape)
-# print("First elements of logits:", logits[0,0,:10])
+# verify logits
+decoder_input_ids = torch.tensor([[model.config.decoder.decoder_start_token_id]])
+outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
+logits = outputs.logits
+expected_shape = torch.Size([1, 1, 50265])
+expected_slice = torch.tensor(
+    [-1.4502, -4.6683, -0.5347, -2.9291, 9.1435, -3.0571, 8.9764, 1.7560, 8.7358, -1.5311]
+)
+
+assert logits.shape == expected_shape, "Shape of logits not as expected"
+assert torch.allclose(logits[0, 0, :10], expected_slice, atol=1e-3), "First elements of logits not as expected"

Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
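The block above follows a pattern common to conversion scripts: run a single teacher-forced decoder step and compare a slice of the logits against values presumably recorded from the original fairseq implementation. A reusable sketch of the same idea (the helper name and default tolerance are my own, not from the PR):

import torch

def verify_logits(model, pixel_values, expected_shape, expected_slice, atol=1e-3):
    # run one decoder step from the start token and check shape plus a logits slice
    decoder_input_ids = torch.tensor([[model.config.decoder.decoder_start_token_id]])
    outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
    logits = outputs.logits
    assert logits.shape == expected_shape, "Shape of logits not as expected"
    assert torch.allclose(logits[0, 0, : expected_slice.numel()], expected_slice, atol=atol)
    return logits

Here expected_shape is torch.Size([1, 1, 50265]), one step over a RoBERTa-sized vocabulary, which matches TrOCR's use of a RoBERTa-style decoder.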
12 changes: 0 additions & 12 deletions src/transformers/models/vision_encoder_decoder/test.py

This file was deleted.

120 changes: 89 additions & 31 deletions tests/test_modeling_vision_encoder_decoder.py
@@ -22,6 +22,7 @@

from .test_modeling_bert import BertModelTester
from .test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
+from .test_modeling_deit import DeiTModelTester
from .test_modeling_trocr import TrOCRStandaloneDecoderModelTester
from .test_modeling_vit import ViTModelTester

@@ -32,6 +33,7 @@

from transformers import (
BertLMHeadModel,
+DeiTModel,
TrOCRForCausalLM,
VisionEncoderDecoderConfig,
VisionEncoderDecoderModel,
@@ -256,8 +258,6 @@ def check_encoder_decoder_model_output_attentions(
output_attentions=True,
)

-inputs = pixel_values

encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)

@@ -358,12 +358,12 @@ def test_real_model_save_load_from_pretrained(self):


@require_torch
-class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
+class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model_and_inputs(self):
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
"google/vit-base-patch16-224-in21k", "bert-base-cased"
"hf-internal-testing/tiny-random-deit", "hf-internal-testing/tiny-random-roberta"
)
-batch_size = 1
+batch_size = 13
pixel_values = floats_tensor(
[
batch_size,
@@ -372,7 +372,9 @@ def get_pretrained_model_and_inputs(self):
model.encoder.config.image_size,
]
)
-attention_mask = random_attention_mask([batch_size, 512])
+# for DEiT, the sequence length is equal to the number of patches + 2 (for the [CLS] and distillation tokens)
+seq_len = (model.encoder.config.image_size // model.encoder.config.patch_size) ** 2 + 2
+attention_mask = random_attention_mask([batch_size, seq_len])
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
inputs = {
@@ -384,21 +386,79 @@ def get_pretrained_model_and_inputs(self):

return model, inputs

+def check_encoder_decoder_model_output_attentions(
+    self,
+    config,
+    attention_mask,
+    decoder_config,
+    decoder_input_ids,
+    decoder_attention_mask,
+    labels=None,
+    pixel_values=None,
+    **kwargs
+):
+    # make the decoder inputs a different shape from the encoder inputs to harden the test
+    decoder_input_ids = decoder_input_ids[:, :-1]
+    decoder_attention_mask = decoder_attention_mask[:, :-1]
+    encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+    enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+    enc_dec_model.to(torch_device)
+    outputs_encoder_decoder = enc_dec_model(
+        pixel_values=pixel_values,
+        attention_mask=attention_mask,
+        decoder_input_ids=decoder_input_ids,
+        decoder_attention_mask=decoder_attention_mask,
+        output_attentions=True,
+    )
+
+    encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
+    self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
+
+    # in DEiT, the seq_len equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
+    image_size = to_2tuple(encoder_model.config.image_size)
+    patch_size = to_2tuple(encoder_model.config.patch_size)
+    num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+    seq_len = num_patches + 2
+    self.assertEqual(encoder_attentions[0].shape[-3:], (config.num_attention_heads, seq_len, seq_len))
+
+    decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
+    num_decoder_layers = (
+        decoder_config.num_decoder_layers
+        if hasattr(decoder_config, "num_decoder_layers")
+        else decoder_config.num_hidden_layers
+    )
+    self.assertEqual(len(decoder_attentions), num_decoder_layers)
+
+    self.assertEqual(
+        decoder_attentions[0].shape[-3:],
+        (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
+    )
+
+    cross_attentions = outputs_encoder_decoder["cross_attentions"]
+    self.assertEqual(len(cross_attentions), num_decoder_layers)
+
+    cross_attention_input_seq_len = decoder_input_ids.shape[-1]
+    self.assertEqual(
+        cross_attentions[0].shape[-3:],
+        (decoder_config.num_attention_heads, cross_attention_input_seq_len, seq_len),
+    )
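# A quick sanity check of the DeiT sequence-length arithmetic above, using
# hypothetical values from the standard deit-base configuration:
#     image_size, patch_size = 224, 16
#     num_patches = (224 // 16) * (224 // 16)  # 196
#     seq_len = num_patches + 2                # 198, for the [CLS] and distillation tokens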

def get_encoder_decoder_model(self, config, decoder_config):
-encoder_model = ViTModel(config).eval()
+encoder_model = DeiTModel(config).eval()
decoder_model = BertLMHeadModel(decoder_config).eval()
return encoder_model, decoder_model

def prepare_config_and_inputs(self):
bert_model_tester = BertModelTester(self)
-vit_model_tester = ViTModelTester(self)
-encoder_config_and_inputs = vit_model_tester.prepare_config_and_inputs()
+deit_model_tester = DeiTModelTester(self)
+encoder_config_and_inputs = deit_model_tester.prepare_config_and_inputs()
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
(
config,
pixel_values,
-input_mask,
+_,
) = encoder_config_and_inputs
+input_mask = None  # TODO add once attention_mask is supported for vision models
(
decoder_config,
decoder_input_ids,
@@ -429,13 +489,23 @@ def prepare_config_and_inputs(self):


@require_torch
-class Vision2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase):
+class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model_and_inputs(self):
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
"google/vit-base-patch16-224-in21k", "bert-base-cased"
"hf-internal-testing/tiny-random-vit", "hf-internal-testing/tiny-bert"
)
batch_size = 13
-attention_mask = random_attention_mask([batch_size, 7])
+pixel_values = floats_tensor(
+    [
+        batch_size,
+        model.encoder.config.num_channels,
+        model.encoder.config.image_size,
+        model.encoder.config.image_size,
+    ]
+)
+# for ViT, the sequence length is equal to the number of patches + 1 (for the [CLS] token)
+seq_len = (model.encoder.config.image_size // model.encoder.config.patch_size) ** 2 + 1
+attention_mask = random_attention_mask([batch_size, seq_len])
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
inputs = {
@@ -448,7 +518,7 @@ def get_pretrained_model_and_inputs(self):
return model, inputs
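# The ViT variant above prepends only the [CLS] token, so with the same
# hypothetical 224px/16px configuration: (224 // 16) ** 2 + 1 == 197.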

def get_encoder_decoder_model(self, config, decoder_config):
-encoder_model = Vision2TextEncoder(config).eval()
+encoder_model = ViTModel(config).eval()
decoder_model = BertLMHeadModel(decoder_config).eval()
return encoder_model, decoder_model

@@ -458,9 +528,8 @@ def prepare_config_and_inputs(self):
encoder_config_and_inputs = vit_model_tester.prepare_config_and_inputs()
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()

-config, inputs = encoder_config_and_inputs
-pixel_values = inputs["pixel_values"]
-input_mask = inputs["attention_mask"]
+config, pixel_values, _ = encoder_config_and_inputs
+input_mask = None  # TODO add once attention_mask is supported for vision models

(
decoder_config,
@@ -490,18 +559,6 @@ def prepare_config_and_inputs(self):
"labels": decoder_token_labels,
}

-# can't save full model for now because Speech2TextModel != Speech2TextEncoder
-def test_encoder_decoder_model_from_pretrained_configs(self):
-    pass
-
-# can't save full model for now because Speech2TextModel != Speech2TextEncoder
-def test_save_and_load_from_pretrained(self):
-    pass
-
-# all published pretrained models are Speech2TextModel != Speech2TextEncoder
-def test_real_model_save_load_from_pretrained(self):
-    pass


@require_torch
class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
@@ -520,8 +577,9 @@ def prepare_config_and_inputs(self):
(
config,
pixel_values,
-input_mask,
+_,
) = encoder_config_and_inputs
+input_mask = None  # TODO add once attention_mask is supported for vision models
(decoder_config, decoder_input_ids, decoder_attention_mask, _) = decoder_config_and_inputs

# make sure that cross attention layers are added
Expand All @@ -537,6 +595,6 @@ def prepare_config_and_inputs(self):
"decoder_attention_mask": decoder_attention_mask,
}

-# there are no published pretrained Speech2Text2ForCausalLM for now
+# there are no published pretrained TrOCR checkpoints for now
def test_real_model_save_load_from_pretrained(self):
pass
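For reference, the ViT2TrOCR pairing exercised by this test can also be built by hand from freshly initialized submodels, mirroring what the tester classes do internally. A minimal sketch with illustrative tiny config values (these are assumptions, not taken from the PR):

from transformers import (
    TrOCRConfig,
    TrOCRForCausalLM,
    ViTConfig,
    ViTModel,
    VisionEncoderDecoderModel,
)

# tiny illustrative configs; the real tests use the dedicated model testers
encoder_config = ViTConfig(image_size=32, patch_size=4, hidden_size=32, num_hidden_layers=2, num_attention_heads=2)
decoder_config = TrOCRConfig(d_model=32, decoder_layers=2, decoder_attention_heads=2, is_decoder=True, add_cross_attention=True)

encoder = ViTModel(encoder_config)
decoder = TrOCRForCausalLM(decoder_config)
model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)

The resulting model then behaves like any encoder-decoder model, accepting pixel_values on the encoder side and decoder_input_ids on the decoder side.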