diff --git a/data/mm_data/ocr_dataset.py b/data/mm_data/ocr_dataset.py index cef176c9..f338cc31 100644 --- a/data/mm_data/ocr_dataset.py +++ b/data/mm_data/ocr_dataset.py @@ -9,7 +9,7 @@ import warnings import random import functools - +import numpy as np import torch import base64 from torchvision import transforms diff --git a/evaluate.py b/evaluate.py index e41d8a30..03da82e8 100644 --- a/evaluate.py +++ b/evaluate.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 -u -# Copyright 2022 The OFA-Sys Team. +# Copyright 2022 The OFA-Sys Team. # All rights reserved. -# This source code is licensed under the Apache 2.0 license +# This source code is licensed under the Apache 2.0 license # found in the LICENSE file in the root directory. import logging @@ -42,7 +42,7 @@ def main(cfg: DictConfig, **kwargs): logger.info(cfg) assert ( - cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None + cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None ), "Must specify batch size either with --max-tokens or --batch-size" # Fix seed for stochastic decoding @@ -82,7 +82,8 @@ def main(cfg: DictConfig, **kwargs): num_shards=cfg.checkpoint.checkpoint_shard_count, ) - # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config + # loading the dataset should happen after the checkpoint has been loaded + # so we can give it the saved task config task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task) if cfg.generation.lm_path is not None: @@ -104,10 +105,13 @@ def main(cfg: DictConfig, **kwargs): lms = [None] # Move models to GPU - for model, ckpt_path in zip(models, utils.split_paths(cfg.common_eval.path)): + for model, ckpt_path in zip( + models, utils.split_paths( + cfg.common_eval.path)): if kwargs['ema_eval']: logger.info("loading EMA weights from {}".format(ckpt_path)) - model.load_state_dict(checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model']) + model.load_state_dict( + checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model']) model.eval() if use_fp16: model.half() @@ -135,7 +139,8 @@ def main(cfg: DictConfig, **kwargs): itr, log_format=cfg.common.log_format, log_interval=cfg.common.log_interval, - default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + default_log_format=( + "tqdm" if not cfg.common.no_progress_bar else "simple"), ) # Initialize generator @@ -149,12 +154,15 @@ def main(cfg: DictConfig, **kwargs): if "net_input" not in sample: continue sample = utils.move_to_cuda(sample) if use_cuda else sample - sample = utils.apply_to_sample(apply_half, sample) if cfg.common.fp16 else sample + sample = utils.apply_to_sample( + apply_half, sample) if cfg.common.fp16 else sample with torch.no_grad(): if kwargs["zero_shot"]: - result, scores = zero_shot_step(task, generator, models, sample) + result, scores = zero_shot_step( + task, generator, models, sample) else: - result, scores = eval_step(task, generator, models, sample, **kwargs) + result, scores = eval_step( + task, generator, models, sample, **kwargs) results += result if scores and isinstance(scores[0], tuple): score_sum += sum([s[0] for s in scores]) @@ -170,13 +178,23 @@ def main(cfg: DictConfig, **kwargs): def cli_main(): parser = options.get_generation_parser() - parser.add_argument("--ema-eval", action='store_true', help="Use EMA weights to make evaluation.") - parser.add_argument("--beam-search-vqa-eval", action='store_true', help="Use beam search for vqa evaluation (faster inference speed but sub-optimal 
result), if not specified, we compute scores for each answer in the candidate set, which is slower but can obtain best result.") + parser.add_argument( + "--ema-eval", + action='store_true', + help="Use EMA weights to make evaluation.") + parser.add_argument( + "--beam-search-vqa-eval", + action='store_true', + help="Use beam search for vqa evaluation (faster inference speed but sub-optimal result), if not specified, we compute scores for each answer in the candidate set, which is slower but can obtain best result.") parser.add_argument("--zero-shot", action='store_true') args = options.parse_args_and_arch(parser) cfg = convert_namespace_to_omegaconf(args) distributed_utils.call_main( - cfg, main, ema_eval=args.ema_eval, beam_search_vqa_eval=args.beam_search_vqa_eval, zero_shot=args.zero_shot + cfg, + main, + ema_eval=args.ema_eval, + beam_search_vqa_eval=args.beam_search_vqa_eval, + zero_shot=args.zero_shot, ) diff --git a/models/ofa/unify_transformer.py b/models/ofa/unify_transformer.py index 3f2d04d2..a2eae7b2 100644 --- a/models/ofa/unify_transformer.py +++ b/models/ofa/unify_transformer.py @@ -1,6 +1,6 @@ -# Copyright 2022 The OFA-Sys Team. +# Copyright 2022 The OFA-Sys Team. # All rights reserved. -# This source code is licensed under the Apache 2.0 license +# This source code is licensed under the Apache 2.0 license # found in the LICENSE file in the root directory. import math @@ -35,6 +35,7 @@ from .unify_transformer_layer import TransformerEncoderLayer, TransformerDecoderLayer from .resnet import ResNet from .frozen_bn import FrozenBatchNorm2d +from .vit import VisionTransformer DEFAULT_MAX_SOURCE_POSITIONS = 1024 @@ -50,16 +51,25 @@ def BatchNorm2d(out_chan, momentum=0.1, eps=1e-3): ) -def make_token_bucket_position(bucket_size, max_position=DEFAULT_MAX_SOURCE_POSITIONS): +def make_token_bucket_position(bucket_size, + max_position=DEFAULT_MAX_SOURCE_POSITIONS): context_pos = torch.arange(max_position, dtype=torch.long)[:, None] memory_pos = torch.arange(max_position, dtype=torch.long)[None, :] relative_pos = context_pos - memory_pos sign = torch.sign(relative_pos) mid = bucket_size // 2 - abs_pos = torch.where((relative_pos -mid), mid-1, torch.abs(relative_pos)) - log_pos = torch.ceil(torch.log(abs_pos/mid)/math.log((max_position-1)/mid) * (mid-1)) + mid + abs_pos = torch.where( + (relative_pos < mid) & ( + relative_pos > -mid), + mid - 1, + torch.abs(relative_pos)) + log_pos = torch.ceil(torch.log(abs_pos / mid) / + math.log((max_position - 1) / mid) * (mid - 1)) + mid log_pos = log_pos.int() - bucket_pos = torch.where(abs_pos.le(mid), relative_pos, log_pos*sign).long() + bucket_pos = torch.where( + abs_pos.le(mid), + relative_pos, + log_pos * sign).long() return bucket_pos + bucket_size - 1 @@ -68,12 +78,18 @@ def make_image_bucket_position(bucket_size, num_relative_distance): coords_w = torch.arange(bucket_size) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords = coords_flatten[:, :, None] - \ + coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += bucket_size - 1 # shift to start from 0 relative_coords[:, :, 1] += bucket_size - 1 relative_coords[:, :, 0] *= 2 * bucket_size - 1 - relative_position_index = 
torch.zeros(size=(bucket_size * bucket_size + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index = torch.zeros( + size=( + bucket_size * bucket_size + 1, + ) * 2, + dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = num_relative_distance - 3 relative_position_index[0:, 0] = num_relative_distance - 2 @@ -306,7 +322,8 @@ def add_args(parser): parser.add_argument('--interpolate-position', action='store_true', help='interpolate position') - parser.add_argument('--resnet-type', choices=['resnet50', 'resnet101', 'resnet152'], + parser.add_argument('--resnet-type', + choices=['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'ViT-B/16'], help='resnet type') parser.add_argument('--resnet-model-path', type=str, metavar='STR', help='path to load resnet') @@ -354,7 +371,8 @@ def build_model(cls, args, task): if args.share_all_embeddings: if src_dict != tgt_dict: - raise ValueError("--share-all-embeddings requires a joined dictionary") + raise ValueError( + "--share-all-embeddings requires a joined dictionary") if args.encoder_embed_dim != args.decoder_embed_dim: raise ValueError( "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" @@ -377,11 +395,33 @@ def build_model(cls, args, task): decoder_embed_tokens = cls.build_embedding( args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path ) - if getattr(args, "freeze_encoder_embedding", False) or getattr( - args, "encoder_prompt", False) or getattr(args, "decoder_prompt", False) or getattr(args, "adapter", False): + if getattr( + args, + "freeze_encoder_embedding", + False) or getattr( + args, + "encoder_prompt", + False) or getattr( + args, + "decoder_prompt", + False) or getattr( + args, + "adapter", + False): encoder_embed_tokens.weight.requires_grad = False - if getattr(args, "freeze_decoder_embedding", False) or getattr( - args, "encoder_prompt", False) or getattr(args, "decoder_prompt", False) or getattr(args, "adapter", False): + if getattr( + args, + "freeze_decoder_embedding", + False) or getattr( + args, + "encoder_prompt", + False) or getattr( + args, + "decoder_prompt", + False) or getattr( + args, + "adapter", + False): decoder_embed_tokens.weight.requires_grad = False if getattr(args, "offload_activations", False): args.checkpoint_activations = True # offloading implies checkpointing @@ -399,7 +439,7 @@ def build_model(cls, args, task): for idx, layer in enumerate(encoder.layers): layer.adapter.requires_grad_(True) for idx, layer in enumerate(decoder.layers): - layer.adapter.requires_grad_(True) + layer.adapter.requires_grad_(True) if not args.share_all_embeddings: min_params_to_wrap = getattr( args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP @@ -454,8 +494,9 @@ def forward( which are not supported by TorchScript. 
""" encoder_out = self.encoder( - src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens - ) + src_tokens, + src_lengths=src_lengths, + return_all_hiddens=return_all_hiddens) decoder_out = self.decoder( prev_output_tokens, encoder_out=encoder_out, @@ -478,7 +519,8 @@ def get_normalized_probs( sample: Optional[Dict[str, Tensor]] = None, ): """Get normalized probabilities (or log probs) from a net's output.""" - return self.get_normalized_probs_scriptable(net_output, log_probs, sample) + return self.get_normalized_probs_scriptable( + net_output, log_probs, sample) class TransformerEncoder(FairseqEncoder): @@ -496,7 +538,7 @@ def __init__(self, args, dictionary, embed_tokens): self.args = args super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) - + if getattr(args, "encoder_prompt", False): self.encoder_prompt_encoder = PromptEncoder( type=args.encoder_prompt_type, @@ -507,7 +549,7 @@ def __init__(self, args, dictionary, embed_tokens): layers=args.encoder_layers, vocab_size=args.vocab_size) self.encoder_dropout = nn.Dropout(p=0.2) - + self.dropout_module = FairseqDropout( args.dropout, module_name=self.__class__.__name__ ) @@ -520,7 +562,8 @@ def __init__(self, args, dictionary, embed_tokens): self.embed_tokens = embed_tokens - self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) + self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt( + embed_dim) if getattr(args, "layernorm_embedding", False): self.layernorm_embedding = LayerNorm(embed_dim) @@ -540,15 +583,45 @@ def __init__(self, args, dictionary, embed_tokens): else: norm_layer = None - if args.resnet_type == 'resnet101': - self.embed_images = ResNet([3, 4, 23], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate) - elif args.resnet_type == 'resnet152': - self.embed_images = ResNet([3, 8, 36], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate) + if args.resnet_type == "resnet18": + self.embed_images = ResNet( + [2, 2, 2], + norm_layer=norm_layer, + drop_path_rate=args.resnet_drop_path_rate, + ) + elif args.resnet_type == "resnet34": + self.embed_images = ResNet( + [3, 4, 6], + norm_layer=norm_layer, + drop_path_rate=args.resnet_drop_path_rate, + ) elif args.resnet_type == 'resnet50': - self.embed_images = ResNet([3, 4, 6], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate) + self.embed_images = ResNet( + [3, 4, 6], + norm_layer=norm_layer, + drop_path_rate=args.resnet_drop_path_rate + ) + elif args.resnet_type == 'resnet101': + self.embed_images = ResNet( + [3, 4, 23], + norm_layer=norm_layer, + drop_path_rate=args.resnet_drop_path_rate + ) + elif args.resnet_type == 'resnet152': + self.embed_images = ResNet( + [3, 8, 36], + norm_layer=norm_layer, + drop_path_rate=args.resnet_drop_path_rate + ) + elif args.resnet_type == "ViT-B/16": + self.embed_images = VisionTransformer(224, 16, 768, 9, 12) else: raise NotImplementedError - self.image_proj = Linear(1024, embed_dim) + + if args.resnet_type == "ViT-B/16": + self.image_proj = Linear(768, embed_dim) + else: + self.image_proj = Linear(1024, embed_dim) if getattr(args, "resnet_model_path", None): print("load resnet {}".format(args.resnet_model_path)) resnet_state_dict = torch.load(self.args.resnet_model_path) @@ -558,11 +631,14 @@ def __init__(self, args, dictionary, embed_tokens): else: self.patch_layernorm_embedding = None - self.embed_positions = Embedding(args.max_source_positions + 2, embed_dim) - self.embed_image_positions = Embedding(args.image_bucket_size ** 2 + 
1, embed_dim) + self.embed_positions = Embedding( + args.max_source_positions + 2, embed_dim) + self.embed_image_positions = Embedding( + args.image_bucket_size ** 2 + 1, embed_dim) self.pos_ln = LayerNorm(embed_dim) self.image_pos_ln = LayerNorm(embed_dim) - self.pos_scaling = float(embed_dim / args.encoder_attention_heads * args.attn_scale_factor) ** -0.5 + self.pos_scaling = float( + embed_dim / args.encoder_attention_heads * args.attn_scale_factor) ** -0.5 self.pos_q_linear = nn.Linear(embed_dim, embed_dim) self.pos_k_linear = nn.Linear(embed_dim, embed_dim) @@ -580,10 +656,13 @@ def __init__(self, args, dictionary, embed_tokens): else: self.layers = nn.ModuleList([]) - dpr = [x.item() for x in torch.linspace(0, args.encoder_drop_path_rate, args.encoder_layers)] - self.layers.extend( - [self.build_encoder_layer(args, drop_path_rate=dpr[i]) for i in range(args.encoder_layers)] - ) + dpr = [ + x.item() for x in torch.linspace( + 0, + args.encoder_drop_path_rate, + args.encoder_layers)] + self.layers.extend([self.build_encoder_layer( + args, drop_path_rate=dpr[i]) for i in range(args.encoder_layers)]) self.num_layers = len(self.layers) if args.encoder_normalize_before: @@ -595,15 +674,25 @@ def __init__(self, args, dictionary, embed_tokens): token_num_rel_dis = 2 * token_bucket_size - 1 token_rp_bucket = make_token_bucket_position(token_bucket_size) self.token_rel_pos_table_list = nn.ModuleList( - [Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.encoder_layers)] - ) + [ + Embedding( + token_num_rel_dis, + self.num_attention_heads, + zero_init=True) for _ in range( + args.encoder_layers)]) image_bucket_size = args.image_bucket_size - image_num_rel_dis = (2 * image_bucket_size - 1) * (2 * image_bucket_size - 1) + 3 - image_rp_bucket = make_image_bucket_position(image_bucket_size, image_num_rel_dis) + image_num_rel_dis = (2 * image_bucket_size - 1) * \ + (2 * image_bucket_size - 1) + 3 + image_rp_bucket = make_image_bucket_position( + image_bucket_size, image_num_rel_dis) self.image_rel_pos_table_list = nn.ModuleList( - [Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.encoder_layers)] - ) + [ + Embedding( + image_num_rel_dis, + self.num_attention_heads, + zero_init=True) for _ in range( + args.encoder_layers)]) self.patch_image_size = args.patch_image_size self.orig_patch_image_size = args.orig_patch_image_size @@ -613,8 +702,10 @@ def __init__(self, args, dictionary, embed_tokens): self.entangle_position_embedding = args.entangle_position_embedding def build_encoder_layer(self, args, drop_path_rate=0.0): - layer = TransformerEncoderLayer(args, drop_path_rate=drop_path_rate, \ - use_adapter=getattr(args, "adapter", False), adapter_dim=getattr(args, "adapter_dim", 200)) + layer = TransformerEncoderLayer( + args, drop_path_rate=drop_path_rate, use_adapter=getattr( + args, "adapter", False), adapter_dim=getattr( + args, "adapter_dim", 200)) checkpoint = getattr(args, "checkpoint_activations", False) if checkpoint: offload_to_cpu = getattr(args, "offload_activations", False) @@ -631,7 +722,9 @@ def build_encoder_layer(self, args, drop_path_rate=0.0): def get_rel_pos_bias(self, x, idx): seq_len = x.size(1) rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] - values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight) + values = F.embedding( + rp_bucket, + self.token_rel_pos_table_list[idx].weight) values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1) values = values.permute([0, 3, 1, 2]) 
return values.contiguous() @@ -643,8 +736,10 @@ def get_image_rel_pos_bias(self, image_position_ids, idx): rp_bucket = self.image_rp_bucket.unsqueeze(0).expand( bsz, rp_bucket_size, rp_bucket_size ).gather(1, image_position_ids[:, :, None].expand(bsz, seq_len, rp_bucket_size) - ).gather(2, image_position_ids[:, None, :].expand(bsz, seq_len, seq_len)) - values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight) + ).gather(2, image_position_ids[:, None, :].expand(bsz, seq_len, seq_len)) + values = F.embedding( + rp_bucket, + self.image_rel_pos_table_list[idx].weight) values = values.permute(0, 3, 1, 2) return values @@ -652,11 +747,13 @@ def get_patch_images_info(self, patch_images, sample_patch_num, device): image_embed = self.embed_images(patch_images) h, w = image_embed.shape[-2:] image_num_patches = h * w - image_padding_mask = patch_images.new_zeros((patch_images.size(0), image_num_patches)).bool() + image_padding_mask = patch_images.new_zeros( + (patch_images.size(0), image_num_patches)).bool() image_position_idx = torch.arange(w).unsqueeze(0).expand(h, w) + \ - torch.arange(h).unsqueeze(1) * self.args.image_bucket_size + 1 + torch.arange(h).unsqueeze(1) * self.args.image_bucket_size + 1 image_position_idx = image_position_idx.view(-1).to(device) - image_position_ids = image_position_idx[None, :].expand(patch_images.size(0), image_num_patches) + image_position_ids = image_position_idx[None, :].expand( + patch_images.size(0), image_num_patches) image_embed = image_embed.flatten(2).transpose(1, 2) if sample_patch_num is not None: @@ -665,23 +762,32 @@ def get_patch_images_info(self, patch_images, sample_patch_num, device): for _ in range(patch_images.size(0)) ] patch_orders = torch.LongTensor(patch_orders).to(device) - image_embed = image_embed.gather( - 1, patch_orders.unsqueeze(2).expand(-1, -1, image_embed.size(2)) - ) + image_embed = image_embed.gather(1, patch_orders.unsqueeze( + 2).expand(-1, -1, image_embed.size(2))) image_num_patches = sample_patch_num image_padding_mask = image_padding_mask.gather(1, patch_orders) image_position_ids = image_position_ids.gather(1, patch_orders) orig_num_patches = (self.orig_patch_image_size // 16) ** 2 - orig_hw= self.orig_patch_image_size // 16 - if getattr(self.args, "interpolate_position", False) and image_num_patches > orig_num_patches: - old_image_position_ids = torch.arange(orig_hw).unsqueeze(0).expand(orig_hw, orig_hw) + \ - torch.arange(orig_hw).unsqueeze(1) * self.args.image_bucket_size + 1 + orig_hw = self.orig_patch_image_size // 16 + if getattr( + self.args, + "interpolate_position", + False) and image_num_patches > orig_num_patches: + old_image_position_ids = torch.arange(orig_hw).unsqueeze(0).expand( + orig_hw, orig_hw) + torch.arange(orig_hw).unsqueeze(1) * self.args.image_bucket_size + 1 old_image_position_ids = old_image_position_ids.to(device) - old_image_pos_embed = self.embed_image_positions(old_image_position_ids) - old_image_pos_embed = old_image_pos_embed.reshape(1, orig_hw, orig_hw, -1).permute(0, 3, 1, 2) - image_pos_embed = F.interpolate(old_image_pos_embed, size=(h, w), mode='bilinear') - image_pos_embed = image_pos_embed.permute(0, 2, 3, 1).reshape(1, image_num_patches, -1) - image_pos_embed = image_pos_embed.expand(patch_images.size(0), -1, -1) + old_image_pos_embed = self.embed_image_positions( + old_image_position_ids) + old_image_pos_embed = old_image_pos_embed.reshape( + 1, orig_hw, orig_hw, -1).permute(0, 3, 1, 2) + image_pos_embed = F.interpolate( + old_image_pos_embed, size=( + h, w), 
mode='bilinear') + image_pos_embed = image_pos_embed.permute( + 0, 2, 3, 1).reshape( + 1, image_num_patches, -1) + image_pos_embed = image_pos_embed.expand( + patch_images.size(0), -1, -1) else: image_pos_embed = self.embed_image_positions(image_position_ids) @@ -700,7 +806,7 @@ def get_encoder_prompt(self, prompt_tokens): past_key_values = self.encoder_dropout(past_key_values) past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) return past_key_values - + def forward_embedding( self, src_tokens, @@ -732,7 +838,8 @@ def forward_embedding( if self.entangle_position_embedding and image_pos_embed is not None: image_x += image_pos_embed if self.type_embedding is not None: - image_x += self.type_embedding(src_tokens.new_ones(image_x.size()[:2])) + image_x += self.type_embedding( + src_tokens.new_ones(image_x.size()[:2])) if self.patch_layernorm_embedding is not None: image_x = self.patch_layernorm_embedding(image_x) image_x = self.dropout_module(image_x) @@ -748,7 +855,10 @@ def forward_embedding( if self.entangle_position_embedding and image_pos_embed_2 is not None: image_x_2 += image_pos_embed_2 if self.type_embedding is not None: - image_x_2 += self.type_embedding(src_tokens.new_full(image_x_2.size()[:2], fill_value=2)) + image_x_2 += self.type_embedding( + src_tokens.new_full( + image_x_2.size()[ + :2], fill_value=2)) if self.patch_layernorm_embedding is not None: image_x_2 = self.patch_layernorm_embedding(image_x_2) image_x_2 = self.dropout_module(image_x_2) @@ -851,7 +961,8 @@ def forward_scriptable( 0, self.args.encoder_prompt_length).to( src_tokens.device) prompt_tokens = prompt_tokens.unsqueeze(0).expand(bsz, -1) - prompt_padding_mask = torch.zeros_like(prompt_tokens).to(prompt_tokens.device) + prompt_padding_mask = torch.zeros_like( + prompt_tokens).to(prompt_tokens.device) prompt_kv_list = self.get_encoder_prompt(prompt_tokens) image_embed = None image_embed_2 = None @@ -868,10 +979,13 @@ def forward_scriptable( encoder_padding_mask = src_tokens.eq(self.padding_idx) if patch_images is not None: - encoder_padding_mask = torch.cat([image_padding_mask, encoder_padding_mask], dim=1) + encoder_padding_mask = torch.cat( + [image_padding_mask, encoder_padding_mask], dim=1) if patch_images_2 is not None: - encoder_padding_mask = torch.cat([image_padding_mask_2, encoder_padding_mask], dim=1) - has_pads = (src_tokens.device.type == "xla" or encoder_padding_mask.any()) + encoder_padding_mask = torch.cat( + [image_padding_mask_2, encoder_padding_mask], dim=1) + has_pads = (src_tokens.device.type == + "xla" or encoder_padding_mask.any()) pos_embed = self.embed_positions(utils.new_arange(src_tokens)) x, encoder_embedding = self.forward_embedding( @@ -908,20 +1022,25 @@ def forward_scriptable( encoder_states.append(x) if prompt_padding_mask is not None: - encoder_padding_mask = torch.cat([prompt_padding_mask, encoder_padding_mask], dim=1) + encoder_padding_mask = torch.cat( + [prompt_padding_mask, encoder_padding_mask], dim=1) # encoder layers for idx, layer in enumerate(self.layers): self_attn_bias = abs_pos_bias.clone() - self_attn_bias[:, :, -src_tokens.size(1):, -src_tokens.size(1):] += self.get_rel_pos_bias(src_tokens, idx) + self_attn_bias[:, :, - + src_tokens.size(1):, - + src_tokens.size(1):] += self.get_rel_pos_bias(src_tokens, idx) if patch_images_2 is not None: self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \ self.get_image_rel_pos_bias(image_position_ids_2, idx) - self_attn_bias[:, :, image_num_patches_2:image_num_patches_2+image_num_patches, 
image_num_patches_2:image_num_patches_2+image_num_patches] += \ - self.get_image_rel_pos_bias(image_position_ids, idx) + self_attn_bias[:, :, image_num_patches_2:image_num_patches_2 + + image_num_patches, image_num_patches_2:image_num_patches_2 + + image_num_patches] += self.get_image_rel_pos_bias(image_position_ids, idx) elif patch_images is not None: - self_attn_bias[:, :, :x.size(0) - src_tokens.size(1), :x.size(0) - src_tokens.size(1)] += \ - self.get_image_rel_pos_bias(image_position_ids, idx) - self_attn_bias = self_attn_bias.reshape(-1, self_attn_bias.size(2), self_attn_bias.size(2)) + self_attn_bias[:, :, :x.size(0) - src_tokens.size(1), :x.size( + 0) - src_tokens.size(1)] += self.get_image_rel_pos_bias(image_position_ids, idx) + self_attn_bias = self_attn_bias.reshape( + -1, self_attn_bias.size(2), self_attn_bias.size(2)) if self.args.encoder_prompt: if self.args.encoder_prompt_type != "prompt": prompt_kv = prompt_kv_list[idx] @@ -931,9 +1050,12 @@ def forward_scriptable( else: prompt_kv = None else: - prompt_kv = None - x = layer(x, encoder_padding_mask=encoder_padding_mask if has_pads else None, \ - self_attn_bias=self_attn_bias, prompt_kv=prompt_kv) + prompt_kv = None + x = layer( + x, + encoder_padding_mask=encoder_padding_mask if has_pads else None, + self_attn_bias=self_attn_bias, + prompt_kv=prompt_kv) if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) @@ -941,7 +1063,8 @@ def forward_scriptable( if self.layer_norm is not None: x = self.layer_norm(x) if self.args.encoder_prompt: - encoder_padding_mask = encoder_padding_mask[:, prompt_tokens.size(1):] + encoder_padding_mask = encoder_padding_mask[:, prompt_tokens.size( + 1):] # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in # `forward` so we use a dictionary instead. # TorchScript does not support mixed values so the values are all lists. @@ -957,7 +1080,8 @@ def forward_scriptable( } @torch.jit.export - def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): + def reorder_encoder_out( + self, encoder_out: Dict[str, List[Tensor]], new_order): """ Reorder encoder output according to *new_order*. 
@@ -971,13 +1095,15 @@ def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): if len(encoder_out["encoder_out"]) == 0: new_encoder_out = [] else: - new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] + new_encoder_out = [ + encoder_out["encoder_out"][0].index_select( + 1, new_order)] if len(encoder_out["encoder_padding_mask"]) == 0: new_encoder_padding_mask = [] else: new_encoder_padding_mask = [ - encoder_out["encoder_padding_mask"][0].index_select(0, new_order) - ] + encoder_out["encoder_padding_mask"][0].index_select( + 0, new_order)] if len(encoder_out["encoder_embedding"]) == 0: new_encoder_embedding = [] else: @@ -988,17 +1114,23 @@ def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): if len(encoder_out["src_tokens"]) == 0: new_src_tokens = [] else: - new_src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)] + new_src_tokens = [ + (encoder_out["src_tokens"][0]).index_select( + 0, new_order)] if len(encoder_out["src_lengths"]) == 0: new_src_lengths = [] else: - new_src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)] + new_src_lengths = [ + (encoder_out["src_lengths"][0]).index_select( + 0, new_order)] if len(encoder_out["position_embeddings"]) == 0: new_position_embeddings = [] else: - new_position_embeddings = [(encoder_out["position_embeddings"][0]).index_select(0, new_order)] + new_position_embeddings = [ + (encoder_out["position_embeddings"][0]).index_select( + 0, new_order)] encoder_states = encoder_out["encoder_states"] if len(encoder_states) > 0: @@ -1049,11 +1181,18 @@ def upgrade_state_dict_named(self, state_dict, name): if (prefix + param_name) not in state_dict: state_dict[prefix + param_name] = self.state_dict()[param_name] - if len(state_dict["encoder.embed_image_positions.weight"]) < len(self.state_dict()["embed_image_positions.weight"]): - num_posids_to_add = len(self.state_dict()["embed_image_positions.weight"]) - len(state_dict["encoder.embed_image_positions.weight"]) - embed_dim = state_dict["encoder.embed_image_positions.weight"].size(1) + if len(state_dict["encoder.embed_image_positions.weight"]) < len( + self.state_dict()["embed_image_positions.weight"]): + num_posids_to_add = len( + self.state_dict()["embed_image_positions.weight"]) - len( + state_dict["encoder.embed_image_positions.weight"]) + embed_dim = state_dict["encoder.embed_image_positions.weight"].size( + 1) new_pos_embed_to_add = torch.zeros(num_posids_to_add, embed_dim) - nn.init.normal_(new_pos_embed_to_add, mean=0, std=embed_dim ** -0.5) + nn.init.normal_( + new_pos_embed_to_add, + mean=0, + std=embed_dim ** -0.5) new_pos_embed_to_add = new_pos_embed_to_add.to( dtype=state_dict["encoder.embed_image_positions.weight"].dtype, ) @@ -1117,7 +1256,8 @@ def __init__( self.embed_tokens = embed_tokens - self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) + self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt( + embed_dim) if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( @@ -1141,11 +1281,14 @@ def __init__( self.window_size = args.code_image_size // 8 - self.embed_positions = Embedding(args.max_target_positions + 2, embed_dim) - self.embed_image_positions = Embedding(args.image_bucket_size ** 2 + 1, embed_dim) + self.embed_positions = Embedding( + args.max_target_positions + 2, embed_dim) + self.embed_image_positions = Embedding( + args.image_bucket_size ** 2 + 1, embed_dim) self.pos_ln = LayerNorm(embed_dim) 
self.image_pos_ln = LayerNorm(embed_dim) - self.pos_scaling = float(embed_dim / self.num_attention_heads * args.attn_scale_factor) ** -0.5 + self.pos_scaling = float( + embed_dim / self.num_attention_heads * args.attn_scale_factor) ** -0.5 self.self_pos_q_linear = nn.Linear(embed_dim, embed_dim) self.self_pos_k_linear = nn.Linear(embed_dim, embed_dim) self.cross_pos_q_linear = nn.Linear(embed_dim, embed_dim) @@ -1156,14 +1299,19 @@ def __init__( else: self.code_layernorm_embedding = None - self.cross_self_attention = getattr(args, "cross_self_attention", False) + self.cross_self_attention = getattr( + args, "cross_self_attention", False) if self.decoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.decoder_layerdrop) else: self.layers = nn.ModuleList([]) - dpr = [x.item() for x in torch.linspace(0, args.decoder_drop_path_rate, args.decoder_layers)] + dpr = [ + x.item() for x in torch.linspace( + 0, + args.decoder_drop_path_rate, + args.decoder_layers)] self.layers.extend( [ self.build_decoder_layer(args, no_encoder_attn, drop_path_rate=dpr[i]) @@ -1192,19 +1340,33 @@ def __init__( token_num_rel_dis = 2 * token_bucket_size - 1 token_rp_bucket = make_token_bucket_position(token_bucket_size) self.token_rel_pos_table_list = nn.ModuleList( - [Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.decoder_layers)] - ) + [ + Embedding( + token_num_rel_dis, + self.num_attention_heads, + zero_init=True) for _ in range( + args.decoder_layers)]) image_bucket_size = args.image_bucket_size - image_num_rel_dis = (2 * image_bucket_size - 1) * (2 * image_bucket_size - 1) + 3 - image_rp_bucket = make_image_bucket_position(image_bucket_size, image_num_rel_dis) - image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \ - torch.arange(self.window_size).unsqueeze(1) * image_bucket_size + 1 - image_position_idx = torch.cat([torch.tensor([0]), image_position_idx.view(-1)]) - image_position_idx = torch.cat([image_position_idx, torch.tensor([1024] * 769)]) + image_num_rel_dis = (2 * image_bucket_size - 1) * \ + (2 * image_bucket_size - 1) + 3 + image_rp_bucket = make_image_bucket_position( + image_bucket_size, image_num_rel_dis) + image_position_idx = torch.arange( + self.window_size).unsqueeze(0).expand( + self.window_size, self.window_size) + torch.arange( + self.window_size).unsqueeze(1) * image_bucket_size + 1 + image_position_idx = torch.cat( + [torch.tensor([0]), image_position_idx.view(-1)]) + image_position_idx = torch.cat( + [image_position_idx, torch.tensor([1024] * 769)]) self.image_rel_pos_table_list = nn.ModuleList( - [Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.decoder_layers)] - ) + [ + Embedding( + image_num_rel_dis, + self.num_attention_heads, + zero_init=True) for _ in range( + args.decoder_layers)]) self.register_buffer("token_rp_bucket", token_rp_bucket) self.register_buffer("image_rp_bucket", image_rp_bucket) @@ -1230,7 +1392,9 @@ def build_output_projection(self, args, dictionary, embed_tokens): self.adaptive_softmax = AdaptiveSoftmax( len(dictionary), self.output_embed_dim, - utils.eval_str_list(args.adaptive_softmax_cutoff, type=int), + utils.eval_str_list( + args.adaptive_softmax_cutoff, + type=int), dropout=args.adaptive_softmax_dropout, adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, factor=args.adaptive_softmax_factor, @@ -1248,15 +1412,23 @@ def build_output_projection(self, args, dictionary, embed_tokens): 
self.output_embed_dim, len(dictionary), bias=False ) nn.init.normal_( - self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5 - ) + self.output_projection.weight, + mean=0, + std=self.output_embed_dim ** -0.5) num_base_layers = getattr(args, "base_layers", 0) for i in range(num_base_layers): - self.layers.insert(((i+1) * args.decoder_layers) // (num_base_layers + 1), BaseLayer(args)) + self.layers.insert(((i + 1) * args.decoder_layers) // + (num_base_layers + 1), BaseLayer(args)) - def build_decoder_layer(self, args, no_encoder_attn=False, drop_path_rate=0.0): - layer = TransformerDecoderLayer(args, no_encoder_attn, drop_path_rate= \ - drop_path_rate, use_adapter=getattr(args, "adapter", False), adapter_dim=getattr(args, "adapter_dim", 200)) + def build_decoder_layer( + self, + args, + no_encoder_attn=False, + drop_path_rate=0.0): + layer = TransformerDecoderLayer( + args, no_encoder_attn, drop_path_rate=drop_path_rate, use_adapter=getattr( + args, "adapter", False), adapter_dim=getattr( + args, "adapter_dim", 200)) checkpoint = getattr(args, "checkpoint_activations", False) if checkpoint: offload_to_cpu = getattr(args, "offload_activations", False) @@ -1273,22 +1445,33 @@ def build_decoder_layer(self, args, no_encoder_attn=False, drop_path_rate=0.0): def get_rel_pos_bias(self, x, idx): seq_len = x.size(1) rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] - values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight) + values = F.embedding( + rp_bucket, + self.token_rel_pos_table_list[idx].weight) values = values.permute([2, 0, 1]) return values.contiguous() def get_image_rel_pos_bias(self, x, idx): seq_len = x.size(1) image_position_idx = self.image_position_idx[:seq_len] - rp_bucket = self.image_rp_bucket[image_position_idx][:, image_position_idx] - values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight) + rp_bucket = self.image_rp_bucket[image_position_idx][:, + image_position_idx] + values = F.embedding( + rp_bucket, + self.image_rel_pos_table_list[idx].weight) values = values.permute(2, 0, 1) return values - def get_pos_info(self, tokens, tgt_pos_embed, src_pos_embed=None, use_image=False): + def get_pos_info( + self, + tokens, + tgt_pos_embed, + src_pos_embed=None, + use_image=False): batch_size = tokens.size(0) tgt_len = tokens.size(1) - tgt_pos_embed = self.image_pos_ln(tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed) + tgt_pos_embed = self.image_pos_ln( + tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed) if src_pos_embed is not None: src_len = src_pos_embed.size(1) pos_q = self.cross_pos_q_linear(tgt_pos_embed).view( @@ -1419,7 +1602,8 @@ def extract_features_scriptable( 0, self.args.decoder_prompt_length).to( prev_output_tokens.device) prompt_tokens = prompt_tokens.unsqueeze(0).expand(bsz, -1) - prompt_padding_mask = torch.zeros_like(prompt_tokens).to(prompt_tokens.device) + prompt_padding_mask = torch.zeros_like( + prompt_tokens).to(prompt_tokens.device) prompt_kv_list = self.get_decoder_prompt(prompt_tokens) bs, slen = prev_output_tokens.size() if alignment_layer is None: @@ -1432,28 +1616,36 @@ def extract_features_scriptable( assert ( enc.size()[1] == bs ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}" - if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0: + if encoder_out is not None and len( + encoder_out["encoder_padding_mask"]) > 0: padding_mask = encoder_out["encoder_padding_mask"][0] bsz, tgt_len = prev_output_tokens.shape token_position_idx = 
utils.new_arange(prev_output_tokens) tgt_pos_embed = self.embed_positions(token_position_idx) if code_masks is not None and torch.any(code_masks): - image_position_idx = self.image_position_idx[:prev_output_tokens.size(1)].unsqueeze(0).expand(bsz, tgt_len) - tgt_pos_embed[code_masks] = self.embed_image_positions(image_position_idx)[code_masks] + image_position_idx = self.image_position_idx[:prev_output_tokens.size( + 1)].unsqueeze(0).expand(bsz, tgt_len) + tgt_pos_embed[code_masks] = self.embed_image_positions(image_position_idx)[ + code_masks] # self attn position bias - self_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=False) + self_abs_pos_bias = self.get_pos_info( + prev_output_tokens, tgt_pos_embed, use_image=False) if code_masks is not None and torch.any(code_masks): - self_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=True) + self_image_abs_pos_bias = self.get_pos_info( + prev_output_tokens, tgt_pos_embed, use_image=True) self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks] # cross attn position bias src_pos_embed = encoder_out['position_embeddings'][0] - cross_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed) + cross_abs_pos_bias = self.get_pos_info( + prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed) if code_masks is not None and torch.any(code_masks): - cross_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed, use_image=True) + cross_image_abs_pos_bias = self.get_pos_info( + prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed, use_image=True) cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[code_masks] - cross_abs_pos_bias = cross_abs_pos_bias.reshape(-1, *cross_abs_pos_bias.size()[-2:]) + cross_abs_pos_bias = cross_abs_pos_bias.reshape( + -1, *cross_abs_pos_bias.size()[-2:]) all_prev_output_tokens = prev_output_tokens.clone() if incremental_state is not None: @@ -1474,7 +1666,8 @@ def extract_features_scriptable( x += tgt_pos_embed if self.layernorm_embedding is not None: - if code_masks is None or not code_masks.any() or not getattr(self, "code_layernorm_embedding", False): + if code_masks is None or not code_masks.any() or not getattr( + self, "code_layernorm_embedding", False): x = self.layernorm_embedding(x) elif code_masks is not None and code_masks.all(): x = self.code_layernorm_embedding(x) @@ -1488,10 +1681,12 @@ def extract_features_scriptable( x = x.transpose(0, 1) self_attn_padding_mask: Optional[Tensor] = None - if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): + if self.cross_self_attention or prev_output_tokens.eq( + self.padding_idx).any(): self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) if prompt_padding_mask is not None: - self_attn_padding_mask = torch.cat([prompt_padding_mask, self_attn_padding_mask], dim=1) + self_attn_padding_mask = torch.cat( + [prompt_padding_mask, self_attn_padding_mask], dim=1) # decoder layers attn: Optional[Tensor] = None @@ -1501,20 +1696,27 @@ def extract_features_scriptable( self_attn_mask = self.buffered_future_mask(x) if self.args.decoder_prompt: seq_len, prompt_len = x.size(0), prompt_tokens.size(1) - prompt_mask = torch.zeros([seq_len, prompt_len]).to(x.device) - self_attn_mask = torch.cat([prompt_mask, self_attn_mask], dim=1) + prompt_mask = torch.zeros( + [seq_len, prompt_len]).to(x.device) + self_attn_mask = torch.cat( + [prompt_mask, self_attn_mask], dim=1) else: 
self_attn_mask = None self_attn_bias = self_abs_pos_bias.clone() if code_masks is None or not code_masks.any(): - self_attn_bias += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) + self_attn_bias += self.get_rel_pos_bias( + all_prev_output_tokens, idx).unsqueeze(0) elif code_masks is not None and code_masks.all(): - self_attn_bias += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) + self_attn_bias += self.get_image_rel_pos_bias( + all_prev_output_tokens, idx).unsqueeze(0) else: - self_attn_bias[~code_masks] += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) - self_attn_bias[code_masks] += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) - self_attn_bias = self_attn_bias.reshape(-1, *self_attn_bias.size()[-2:]) + self_attn_bias[~code_masks] += self.get_rel_pos_bias( + all_prev_output_tokens, idx).unsqueeze(0) + self_attn_bias[code_masks] += self.get_image_rel_pos_bias( + all_prev_output_tokens, idx).unsqueeze(0) + self_attn_bias = self_attn_bias.reshape( + -1, *self_attn_bias.size()[-2:]) if incremental_state is not None: self_attn_bias = self_attn_bias[:, -1:, :] @@ -1580,7 +1782,8 @@ def max_positions(self): def buffered_future_mask(self, tensor): dim = tensor.size(0) - # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. + # self._future_mask.device != tensor.device is not working in + # TorchScript. This is a workaround. if ( self._future_mask.size(0) == 0 or (not self._future_mask.device == tensor.device) @@ -1635,11 +1838,18 @@ def upgrade_state_dict_named(self, state_dict, name): if (prefix + param_name) not in state_dict: state_dict[prefix + param_name] = self.state_dict()[param_name] - if len(state_dict["decoder.embed_image_positions.weight"]) < len(self.state_dict()["embed_image_positions.weight"]): - num_posids_to_add = len(self.state_dict()["embed_image_positions.weight"]) - len(state_dict["decoder.embed_image_positions.weight"]) - embed_dim = state_dict["decoder.embed_image_positions.weight"].size(1) + if len(state_dict["decoder.embed_image_positions.weight"]) < len( + self.state_dict()["embed_image_positions.weight"]): + num_posids_to_add = len( + self.state_dict()["embed_image_positions.weight"]) - len( + state_dict["decoder.embed_image_positions.weight"]) + embed_dim = state_dict["decoder.embed_image_positions.weight"].size( + 1) new_pos_embed_to_add = torch.zeros(num_posids_to_add, embed_dim) - nn.init.normal_(new_pos_embed_to_add, mean=0, std=embed_dim ** -0.5) + nn.init.normal_( + new_pos_embed_to_add, + mean=0, + std=embed_dim ** -0.5) new_pos_embed_to_add = new_pos_embed_to_add.to( dtype=state_dict["decoder.embed_image_positions.weight"].dtype, ) @@ -1649,7 +1859,11 @@ def upgrade_state_dict_named(self, state_dict, name): return state_dict -def Embedding(num_embeddings, embedding_dim, padding_idx=None, zero_init=False): +def Embedding( + num_embeddings, + embedding_dim, + padding_idx=None, + zero_init=False): m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) if padding_idx is not None: @@ -1674,23 +1888,28 @@ def base_architecture(args): args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) args.encoder_layers = getattr(args, "encoder_layers", 6) args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) - args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_normalize_before = getattr( + 
args, "encoder_normalize_before", False) args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) args.decoder_embed_path = getattr(args, "decoder_embed_path", None) - args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_embed_dim = getattr( + args, "decoder_embed_dim", args.encoder_embed_dim) args.decoder_ffn_embed_dim = getattr( args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim ) args.decoder_layers = getattr(args, "decoder_layers", 6) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) - args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_normalize_before = getattr( + args, "decoder_normalize_before", False) args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) args.attention_dropout = getattr(args, "attention_dropout", 0.0) args.activation_dropout = getattr(args, "activation_dropout", 0.0) args.activation_fn = getattr(args, "activation_fn", "relu") args.dropout = getattr(args, "dropout", 0.1) - args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) - args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.adaptive_softmax_cutoff = getattr( + args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr( + args, "adaptive_softmax_dropout", 0) args.share_decoder_input_output_embed = getattr( args, "share_decoder_input_output_embed", False ) @@ -1705,24 +1924,30 @@ def base_architecture(args): args.decoder_output_dim = getattr( args, "decoder_output_dim", args.decoder_embed_dim ) - args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + args.decoder_input_dim = getattr( + args, "decoder_input_dim", args.decoder_embed_dim) args.encoder_prompt = getattr(args, "encoder_prompt", False) args.encoder_prompt_length = getattr(args, "encoder_prompt_length", 100) args.encoder_prompt_type = getattr(args, "encoder_prompt_type", "prefix") - args.encoder_prompt_projection = getattr(args, "encoder_prompt_projection", False) - args.encoder_prompt_dim = getattr(args, "encoder_prompt_dim", 2 * args.encoder_embed_dim) + args.encoder_prompt_projection = getattr( + args, "encoder_prompt_projection", False) + args.encoder_prompt_dim = getattr( + args, "encoder_prompt_dim", 2 * args.encoder_embed_dim) args.decoder_prompt = getattr(args, "decoder_prompt", False) args.decoder_prompt_length = getattr(args, "decoder_prompt_length", 100) args.decoder_prompt_type = getattr(args, "decoder_prompt_type", "prefix") - args.decoder_prompt_projection = getattr(args, "decoder_prompt_projection", False) - args.decoder_prompt_dim = getattr(args, "decoder_prompt_dim", 2 * args.encoder_embed_dim) + args.decoder_prompt_projection = getattr( + args, "decoder_prompt_projection", False) + args.decoder_prompt_dim = getattr( + args, "decoder_prompt_dim", 2 * args.encoder_embed_dim) args.no_scale_embedding = getattr(args, "no_scale_embedding", False) args.layernorm_embedding = getattr(args, "layernorm_embedding", False) args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) - args.checkpoint_activations = getattr(args, "checkpoint_activations", False) + args.checkpoint_activations = getattr( + args, "checkpoint_activations", False) args.offload_activations = getattr(args, "offload_activations", False) if args.offload_activations: args.checkpoint_activations = True @@ -1731,5 +1956,6 @@ def base_architecture(args): args.encoder_layerdrop = getattr(args, 
"encoder_layerdrop", 0) args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) - args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_pq_block_size = getattr( + args, "quant_noise_pq_block_size", 8) args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) diff --git a/models/ofa/vit.py b/models/ofa/vit.py new file mode 100644 index 00000000..db2cd021 --- /dev/null +++ b/models/ofa/vit.py @@ -0,0 +1,113 @@ +from collections import OrderedDict +import torch +import torch.nn.functional as F +from torch import nn +from fairseq.modules import LayerNorm + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)), + ] + ) + ) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = ( + self.attn_mask.to(dtype=x.dtype, device=x.device) + if self.attn_mask is not None + else None + ) + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential( + *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)] + ) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__( + self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + ): + super().__init__() + self.input_resolution = input_resolution + self.patch_size = patch_size + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + + scale = width**-0.5 + # self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter( + scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width) + ) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + def forward(self, x: torch.Tensor): + resolution = x.shape[-2] + height, width = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + if resolution != 224: + old_pe = self.positional_embedding[1:] + old_pe = old_pe.reshape(1, 14, 14, -1).permute(0, 3, 1, 2) + new_pe = F.interpolate(old_pe, size=(height, width), mode="bilinear") + new_pe = new_pe.permute(0, 2, 3, 1).reshape(height * width, -1) + x = x + new_pe.to(x.dtype) + else: + x = x + self.positional_embedding[1:].to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + bz, seq, hidden = x.shape + x = x.transpose(1, 
diff --git a/run_scripts/ocr/evaluate_ocr.sh b/run_scripts/ocr/evaluate_ocr.sh
new file mode 100644
index 00000000..acba785f
--- /dev/null
+++ b/run_scripts/ocr/evaluate_ocr.sh
@@ -0,0 +1,31 @@
+#!/usr/bin/env bash
+
+# The port for communication. Note that if you want to run multiple tasks on the same machine,
+# you need to specify different port numbers.
+export MASTER_PORT=1091
+
+user_dir=../../ofa_module
+bpe_dir=../../utils/BPE
+
+data=../../dataset/caption_data/ocr_scene_test.tsv
+path=../../checkpoints/ofa_cn_ocr_large.pt
+result_path=../../results/ocr
+selected_cols=0,1,2
+split='test'
+
+CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --master_port=${MASTER_PORT} ../../evaluate.py \
+    ${data} \
+    --path=${path} \
+    --user-dir=${user_dir} \
+    --task=ocr \
+    --batch-size=8 \
+    --log-format=simple --log-interval=10 \
+    --seed=7 \
+    --gen-subset=${split} \
+    --results-path=${result_path} \
+    --beam=5 \
+    --max-len-b=64 \
+    --no-repeat-ngram-size=0 \
+    --fp16 \
+    --num-workers=0 \
+    --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\",\"resnet_model_path\":None}"
\ No newline at end of file
diff --git a/run_scripts/ocr/evaluate_ocr_base.sh b/run_scripts/ocr/evaluate_ocr_base.sh
new file mode 100644
index 00000000..5c653d56
--- /dev/null
+++ b/run_scripts/ocr/evaluate_ocr_base.sh
@@ -0,0 +1,31 @@
+#!/usr/bin/env bash
+
+# The port for communication. Note that if you want to run multiple tasks on the same machine,
+# you need to specify different port numbers.
+export MASTER_PORT=1091
+
+user_dir=../../ofa_module
+bpe_dir=../../utils/BPE
+
+data=../../dataset/caption_data/ocr_scene_test.tsv
+path=../../checkpoints/ofa_cn_ocr_base.pt
+result_path=../../results/ocr
+selected_cols=0,1,2
+split='test'
+
+CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --master_port=${MASTER_PORT} ../../evaluate.py \
+    ${data} \
+    --path=${path} \
+    --user-dir=${user_dir} \
+    --task=ocr \
+    --batch-size=16 \
+    --log-format=simple --log-interval=10 \
+    --seed=7 \
+    --gen-subset=${split} \
+    --results-path=${result_path} \
+    --beam=5 \
+    --max-len-b=64 \
+    --no-repeat-ngram-size=0 \
+    --fp16 \
+    --num-workers=0 \
+    --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\",\"resnet_model_path\":None}"
\ No newline at end of file
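
Note on the new ViT backbone: with this patch, passing --resnet-type "ViT-B/16" makes the encoder build VisionTransformer(224, 16, 768, 9, 12) from models/ofa/vit.py and project its 768-dim patch features through Linear(768, embed_dim). A minimal sketch of the tensor shapes involved (not part of the patch; it assumes the repository root is on PYTHONPATH, and the 480x480 input and embed_dim=768 are illustrative values only):

import torch
from torch import nn
from models.ofa.vit import VisionTransformer  # new module added by this patch

# Same configuration the encoder builds for --resnet-type "ViT-B/16":
# 224px reference resolution, 16px patches, width 768, 9 layers, 12 heads.
vit = VisionTransformer(224, 16, 768, 9, 12)

images = torch.randn(2, 3, 480, 480)   # illustrative batch; 480 != 224, so the
feats = vit(images)                    # positional embedding is interpolated
print(feats.shape)                     # torch.Size([2, 768, 30, 30]) patch-feature grid

# The encoder then flattens the grid into a patch sequence and projects it to the
# transformer width, mirroring image_proj / get_patch_images_info above.
embed_dim = 768                        # illustrative; the real value is args.encoder_embed_dim
image_proj = nn.Linear(768, embed_dim)
patch_tokens = image_proj(feats.flatten(2).transpose(1, 2))
print(patch_tokens.shape)              # torch.Size([2, 900, 768])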