Add TrOCR + VisionEncoderDecoderModel #13874

Merged: 38 commits, merged Oct 13, 2021
Changes from 1 commit

Commits (38):
26b79a6
First draft
NielsRogge Sep 25, 2021
eebfafa
Update self-attention of RoBERTa as proposition
NielsRogge Sep 29, 2021
adf1cb3
Improve conversion script
NielsRogge Sep 30, 2021
be7ec13
Add TrOCR decoder-only model
NielsRogge Sep 30, 2021
1ec88d5
More improvements
NielsRogge Sep 30, 2021
7ded83b
Make forward pass with pretrained weights work
NielsRogge Sep 30, 2021
9b4189f
More improvements
NielsRogge Sep 30, 2021
9b6f68b
Some more improvements
NielsRogge Sep 30, 2021
1127064
More improvements
NielsRogge Sep 30, 2021
ac5440d
Make conversion work
NielsRogge Oct 3, 2021
6c5d947
Clean up print statements
NielsRogge Oct 4, 2021
b54e32e
Add documentation, processor
NielsRogge Oct 4, 2021
d47b5f1
Add test files
NielsRogge Oct 4, 2021
b1a85a6
Small improvements
NielsRogge Oct 4, 2021
76f3a66
Some more improvements
NielsRogge Oct 4, 2021
1d8ed6b
Make fix-copies, improve docs
NielsRogge Oct 4, 2021
2c4337e
Make all vision encoder decoder model tests pass
NielsRogge Oct 4, 2021
cc4eb2c
Make conversion script support other models
NielsRogge Oct 5, 2021
170f905
Update URL for OCR image
NielsRogge Oct 5, 2021
28bdf18
Update conversion script
NielsRogge Oct 5, 2021
890dd70
Fix style & quality
NielsRogge Oct 5, 2021
15f797d
Add support for the large-printed model
NielsRogge Oct 5, 2021
f490e3a
Fix some issues
NielsRogge Oct 6, 2021
2230eb0
Add print statement for debugging
NielsRogge Oct 6, 2021
f8ad61d
Add print statements for debugging
NielsRogge Oct 6, 2021
e5f6983
Make possible fix for sinusoidal embedding
NielsRogge Oct 6, 2021
643c21d
Further debugging
NielsRogge Oct 6, 2021
b7c5bf8
Potential fix v2
NielsRogge Oct 6, 2021
6c4435d
Add more print statements for debugging
NielsRogge Oct 6, 2021
1a6825f
Add more print statements for debugging
NielsRogge Oct 6, 2021
667b03c
Debug more
NielsRogge Oct 6, 2021
bf49483
Comment out print statements
NielsRogge Oct 6, 2021
f0c8b59
Make conversion of large printed model possible, address review comments
NielsRogge Oct 8, 2021
6f1d7fa
Make it possible to convert the stage1 checkpoints
NielsRogge Oct 8, 2021
c38904b
Clean up code, apply suggestions from code review
NielsRogge Oct 8, 2021
6e6b947
Apply suggestions from code review, use Microsoft models in tests
NielsRogge Oct 11, 2021
b1fedab
Rename encoder_hidden_size to cross_attention_hidden_size
NielsRogge Oct 11, 2021
f3d9e94
Improve docs
NielsRogge Oct 12, 2021
First draft
NielsRogge committed Oct 3, 2021
commit 26b79a6f941a71c832a818f957dcd962ddc0b1d5
6 changes: 5 additions & 1 deletion src/transformers/__init__.py
@@ -271,6 +271,7 @@
"TransfoXLCorpus",
"TransfoXLTokenizer",
],
"models.vision_encoder_decoder": ["VisionEncoderDecoderConfig"],
"models.visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"],
"models.vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
"models.wav2vec2": [
@@ -1170,6 +1171,7 @@
"load_tf_weights_in_transfo_xl",
]
)
_import_structure["models.vision_encoder_decoder"].extend(["VisionEncoderDecoderModel"])
_import_structure["models.visual_bert"].extend(
[
"VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2088,6 +2090,7 @@
TransfoXLCorpus,
TransfoXLTokenizer,
)
from .models.vision_encoder_decoder import VisionEncoderDecoderConfig
from .models.visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
from .models.vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
from .models.wav2vec2 import (
@@ -2844,7 +2847,8 @@
TransfoXLPreTrainedModel,
load_tf_weights_in_transfo_xl,
)
from .models.visual_bert import ( # load_tf_weights_in_visual_bert,
from .models.vision_encoder_decoder import VisionEncoderDecoderModel
from .models.visual_bert import (
VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
VisualBertForMultipleChoice,
VisualBertForPreTraining,
40 changes: 40 additions & 0 deletions src/transformers/models/vision_encoder_decoder/__init__.py
@@ -0,0 +1,40 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_torch_available


_import_structure = {
"configuration_vision_encoder_decoder": ["VisionEncoderDecoderConfig"],
}

if is_torch_available():
_import_structure["modeling_vision_encoder_decoder"] = ["VisionEncoderDecoderModel"]

if TYPE_CHECKING:
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig

if is_torch_available():
from .modeling_vision_encoder_decoder import VisionEncoderDecoderModel

else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
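For reference, a minimal usage sketch of how this lazy registration is typically exercised (not part of the diff; assumes a standard `transformers` install). The config is always importable, while the model is only importable when PyTorch is present, mirroring the `is_torch_available()` guard above:

```python
# Illustrative sketch only: exercises the lazy-module registration added above.
from transformers import VisionEncoderDecoderConfig, is_torch_available

config = VisionEncoderDecoderConfig  # config class is always exposed

if is_torch_available():
    # The model class is only registered when torch is installed;
    # the actual module is loaded lazily on first attribute access.
    from transformers import VisionEncoderDecoderModel
```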
119 changes: 119 additions & 0 deletions src/transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py
@@ -0,0 +1,119 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import AutoConfig


logger = logging.get_logger(__name__)


class VisionEncoderDecoderConfig(PretrainedConfig):
r"""
:class:`~transformers.VisionEncoderDecoderConfig` is the configuration class to store the configuration of a
:class:`~transformers.VisionEncoderDecoderModel`. It is used to instantiate an Encoder Decoder model according to
the specified arguments, defining the encoder and decoder configs.

Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

Args:
kwargs (`optional`):
Dictionary of keyword arguments. Notably:

- **encoder** (:class:`~transformers.PretrainedConfig`, `optional`) -- An instance of a configuration
object that defines the encoder config.
- **decoder** (:class:`~transformers.PretrainedConfig`, `optional`) -- An instance of a configuration
object that defines the decoder config.

Examples::

>>> from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel

>>> # Initializing a ViT & BERT style configuration
>>> config_encoder = ViTConfig()
>>> config_decoder = BertConfig()

>>> config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

>>> # Initializing a ViTBert model from a ViT & bert-base-uncased style configurations
>>> model = VisionEncoderDecoderModel(config=config)

>>> # Accessing the model configuration
>>> config_encoder = model.config.encoder
>>> config_decoder = model.config.decoder
>>> # set decoder config to causal lm
>>> config_decoder.is_decoder = True
>>> config_decoder.add_cross_attention = True

>>> # Saving the model, including its configuration
>>> model.save_pretrained('my-model')

>>> # loading model and config from pretrained folder
>>> encoder_decoder_config = VisionEncoderDecoderConfig.from_pretrained('my-model')
>>> model = VisionEncoderDecoderModel.from_pretrained('my-model', config=encoder_decoder_config)
"""
model_type = "vision-encoder-decoder"
is_composition = True

def __init__(self, **kwargs):
super().__init__(**kwargs)
if "encoder" not in kwargs or "decoder" not in kwargs:
raise ValueError(
f"A configuraton of type {self.model_type} cannot be instantiated because not both `encoder` and `decoder` sub-configurations are passed, but only {kwargs}"
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
)

encoder_config = kwargs.pop("encoder")
encoder_model_type = encoder_config.pop("model_type")
decoder_config = kwargs.pop("decoder")
decoder_model_type = decoder_config.pop("model_type")

self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)
self.is_encoder_decoder = True

@classmethod
def from_encoder_decoder_configs(
cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
) -> PretrainedConfig:
r"""
Instantiate a :class:`~transformers.VisionEncoderDecoderConfig` (or a derived class) from a pre-trained encoder
model configuration and decoder model configuration.

Returns:
:class:`VisionEncoderDecoderConfig`: An instance of a configuration object
"""
logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True

return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)

def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig`.

Returns:
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
output["encoder"] = self.encoder.to_dict()
output["decoder"] = self.decoder.to_dict()
output["model_type"] = self.__class__.model_type
return output
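To make the nesting behavior concrete, here is a small illustrative round trip (not part of the diff; it assumes the class behaves as defined in this draft): `to_dict()` emits nested `encoder`/`decoder` dictionaries carrying their own `model_type`, and `__init__` pops them back out and rebuilds the sub-configs through `AutoConfig.for_model`.

```python
# Hypothetical round-trip sketch for the composite config defined above.
from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig

config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(ViTConfig(), BertConfig())

# Serialization: sub-configs become nested dicts, each with its own model_type.
as_dict = config.to_dict()
assert as_dict["encoder"]["model_type"] == "vit"
assert as_dict["decoder"]["model_type"] == "bert"

# Deserialization: __init__ pops "encoder"/"decoder" and rebuilds them via AutoConfig.for_model;
# the decoder keeps the flags set by from_encoder_decoder_configs.
restored = VisionEncoderDecoderConfig(**as_dict)
assert restored.decoder.is_decoder and restored.decoder.add_cross_attention
```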
178 changes: 178 additions & 0 deletions (new file: conversion script for TrOCR checkpoints)
@@ -0,0 +1,178 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert TrOCR checkpoints from the unilm repository."""


import argparse
from pathlib import Path

import torch
from PIL import Image

import requests
from transformers import (
RobertaConfig,
RobertaModel,
VisionEncoderDecoderModel,
ViTConfig,
ViTFeatureExtractor,
ViTModel,
)
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config):
rename_keys = []
for i in range(config.num_hidden_layers):
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
rename_keys.append((f"blocks.{i}.norm1.weight", f"deit.encoder.layer.{i}.layernorm_before.weight"))
rename_keys.append((f"blocks.{i}.norm1.bias", f"deit.encoder.layer.{i}.layernorm_before.bias"))
rename_keys.append((f"blocks.{i}.attn.proj.weight", f"deit.encoder.layer.{i}.attention.output.dense.weight"))
rename_keys.append((f"blocks.{i}.attn.proj.bias", f"deit.encoder.layer.{i}.attention.output.dense.bias"))
rename_keys.append((f"blocks.{i}.norm2.weight", f"deit.encoder.layer.{i}.layernorm_after.weight"))
rename_keys.append((f"blocks.{i}.norm2.bias", f"deit.encoder.layer.{i}.layernorm_after.bias"))
rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"deit.encoder.layer.{i}.intermediate.dense.weight"))
rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"deit.encoder.layer.{i}.intermediate.dense.bias"))
rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"deit.encoder.layer.{i}.output.dense.weight"))
rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"deit.encoder.layer.{i}.output.dense.bias"))

# projection layer + position embeddings
rename_keys.extend(
[
("cls_token", "deit.embeddings.cls_token"),
("patch_embed.proj.weight", "deit.embeddings.patch_embeddings.projection.weight"),
("patch_embed.proj.bias", "deit.embeddings.patch_embeddings.projection.bias"),
]
)

return rename_keys


# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config):
for i in range(config.num_hidden_layers):
prefix = "deit."
# queries, keys and values
in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
q_bias = state_dict.pop(f"blocks.{i}.attn.q_bias")
v_bias = state_dict.pop(f"blocks.{i}.attn.v_bias")

state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
: config.hidden_size, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = q_bias
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
config.hidden_size : config.hidden_size * 2, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
-config.hidden_size :, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = v_bias

# gamma_1 and gamma_2 (layer scale parameters)
# we store them as lambda_1/lambda_2, because parameter names containing "gamma" are automatically renamed when using .from_pretrained
gamma_1 = state_dict.pop(f"blocks.{i}.gamma_1")
gamma_2 = state_dict.pop(f"blocks.{i}.gamma_2")

state_dict[f"{prefix}encoder.layer.{i}.lambda_1"] = gamma_1
state_dict[f"{prefix}encoder.layer.{i}.lambda_2"] = gamma_2


def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val


# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im


@torch.no_grad()
def convert_tr_ocr_checkpoint(checkpoint_url, pytorch_dump_folder_path):
"""
Copy/paste/tweak model's weights to our VisionEncoderDecoderModel structure.
"""
# define encoder and decoder configs based on checkpoint_url
encoder_config = ViTConfig()
decoder_config = RobertaConfig.from_pretrained("roberta-large")

# size of the architecture
if "base" in checkpoint_url:
pass
elif "large" in checkpoint_url:
encoder_config.hidden_size = 1024
encoder_config.intermediate_size = 4096
encoder_config.num_hidden_layers = 24
encoder_config.num_attention_heads = 16
else:
raise ValueError("Should either find 'base' or 'large' in checkpoint URL")

# load HuggingFace model
encoder = ViTModel(encoder_config)
decoder = RobertaModel(decoder_config)
model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
model.eval()

# load state_dict of original model, remove and rename some keys
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)["model"]
rename_keys = create_rename_keys(encoder_config)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
read_in_q_k_v(state_dict, encoder_config)

# load state dict
model.load_state_dict(state_dict)

# Check outputs on an image
feature_extractor = ViTFeatureExtractor(
size=encoder_config.image_size, resample=Image.BILINEAR, do_center_crop=False
)
encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
pixel_values = encoding["pixel_values"]

outputs = model(pixel_values)
logits = outputs.logits

# TODO verify logits
expected_shape = torch.Size([1, 1000])
assert logits.shape == expected_shape, "Shape of logits not as expected"

Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving feature extractor to {pytorch_dump_folder_path}")
feature_extractor.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument(
"--checkpoint_url",
default="https://layoutlm.blob.core.windows.net/trocr/model_zoo/fairseq/trocr-base-handwritten.pt",
type=str,
help="URL to the original PyTorch checkpoint (.pth file).",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
)
args = parser.parse_args()
convert_tr_ocr_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
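For completeness, a hedged sketch of how the exported folder could then be loaded back and checked (paths are placeholders; the forward pass mirrors the sanity check inside the draft script above and does not yet perform text decoding):

```python
# Hypothetical follow-up: load the folder produced by the conversion script and run one forward pass.
import requests
import torch
from PIL import Image

from transformers import ViTFeatureExtractor, VisionEncoderDecoderModel

dump_folder = "path/to/pytorch_dump_folder"  # placeholder path
model = VisionEncoderDecoderModel.from_pretrained(dump_folder)
feature_extractor = ViTFeatureExtractor.from_pretrained(dump_folder)

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values

with torch.no_grad():
    outputs = model(pixel_values)
print(outputs.logits.shape)  # same shape check as in the conversion script
```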