diff --git a/data/mm_data/ocr_dataset.py b/data/mm_data/ocr_dataset.py
new file mode 100644
index 00000000..cef176c9
--- /dev/null
+++ b/data/mm_data/ocr_dataset.py
@@ -0,0 +1,204 @@
+# Copyright 2022 The OFA-Sys Team.
+# All rights reserved.
+# This source code is licensed under the Apache 2.0 license
+# found in the LICENSE file in the root directory.
+
+from io import BytesIO
+
+import logging
+import warnings
+import random
+import functools
+
+import numpy as np
+import torch
+import base64
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms import functional as F
+
+from PIL import Image, ImageFile
+
+from zhconv import convert
+import unicodedata
+
+from data import data_utils
+from data.ofa_dataset import OFADataset
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+ImageFile.MAX_IMAGE_PIXELS = None
+Image.MAX_IMAGE_PIXELS = None
+
+logger = logging.getLogger(__name__)
+warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
+
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+
+def collate(samples, pad_idx, eos_idx):
+    if len(samples) == 0:
+        return {}
+
+    def merge(key):
+        return data_utils.collate_tokens(
+            [s[key] for s in samples],
+            pad_idx,
+            eos_idx=eos_idx,
+        )
+
+    id = np.array([s["id"] for s in samples])
+    src_tokens = merge("source")
+    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
+
+    patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
+    patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
+
+    prev_output_tokens = None
+    target = None
+    if samples[0].get("target", None) is not None:
+        target = merge("target")
+        tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
+        ntokens = tgt_lengths.sum().item()
+
+        if samples[0].get("prev_output_tokens", None) is not None:
+            prev_output_tokens = merge("prev_output_tokens")
+    else:
+        ntokens = src_lengths.sum().item()
+
+    batch = {
+        "id": id,
+        "nsentences": len(samples),
+        "ntokens": ntokens,
+        "net_input": {
+            "src_tokens": src_tokens,
+            "src_lengths": src_lengths,
+            "patch_images": patch_images,
+            "patch_masks": patch_masks,
+            "prev_output_tokens": prev_output_tokens
+        },
+        "target": target,
+    }
+
+    return batch
+
+
+def ocr_resize(img, patch_image_size, is_document=False):
+    img = img.convert("RGB")
+    width, height = img.size
+
+    if is_document:
+        new_height, new_width = 64, 1920
+    else:
+        if width >= height:
+            new_width = max(64, patch_image_size)
+            new_height = max(64, int(patch_image_size * (height / width)))
+            top = random.randint(0, patch_image_size - new_height)
+            bottom = patch_image_size - new_height - top
+            left, right = 0, 0
+        else:
+            new_height = max(64, patch_image_size)
+            new_width = max(64, int(patch_image_size * (width / height)))
+            left = random.randint(0, patch_image_size - new_width)
+            right = patch_image_size - new_width - left
+            top, bottom = 0, 0
+
+    img_new = F.resize(
+        img,
+        [new_height, new_width],
+        interpolation=InterpolationMode.BICUBIC,
+    )
+
+    if is_document:
+        # split the 64x1920 strip into 4 pieces and stack them vertically (256x480);
+        # the assert below therefore expects patch_image_size == 480 in document mode
+        img_split = transforms.ToTensor()(img_new).chunk(4, dim=-1)
+        img_new = transforms.ToPILImage()(torch.cat(img_split, dim=-2))
+        new_width, new_height = img_new.size
+        top = random.randint(0, patch_image_size - new_height)
+        bottom = patch_image_size - new_height - top
+        left, right = 0, 0
+
+    img_new = F.pad(img_new, padding=[left, top, right, bottom], padding_mode="edge")
+    assert img_new.size == (patch_image_size, patch_image_size)
+
+    return img_new
+
+
+class OcrDataset(OFADataset):
+    def __init__(
+        self,
+        split,
+        dataset,
+        bpe,
+        src_dict,
+        tgt_dict=None,
+        max_src_length=80,
+        max_tgt_length=30,
+        patch_image_size=224,
+        imagenet_default_mean_and_std=False,
+        is_document=False,
+    ):
+        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
+        self.max_src_length = max_src_length
+        self.max_tgt_length = max_tgt_length
+        self.patch_image_size = patch_image_size
+
+        if imagenet_default_mean_and_std:
+            mean = IMAGENET_DEFAULT_MEAN
+            std = IMAGENET_DEFAULT_STD
+        else:
+            mean = [0.5, 0.5, 0.5]
+            std = [0.5, 0.5, 0.5]
+
+        self.patch_resize_transform = transforms.Compose(
+            [
+                lambda image: ocr_resize(
+                    image, patch_image_size, is_document=is_document
+                ),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=mean, std=std),
+            ]
+        )
+
+        self.bpe = bpe
+        if type(bpe).__name__ == 'GPT2BPE':
+            self.prompt = " what are the texts on the image?"
+        elif type(bpe).__name__ == 'BertBPE':
+            self.prompt = "图片上的文字是什么?"
+
+    def __getitem__(self, index):
+        uniq_id, image, caption = self.dataset[index]
+
+        image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
+        patch_image = self.patch_resize_transform(image)
+        patch_mask = torch.tensor([True])
+
+        caption = unicodedata.normalize("NFKC", convert(caption, "zh-hans"))
+        if type(self.bpe).__name__ == 'GPT2BPE':
+            caption_token_list = caption.lower().strip().split()
+            tgt_caption = ' '.join(caption_token_list[:self.max_tgt_length])
+        elif type(self.bpe).__name__ == 'BertBPE':
+            tgt_caption = caption[: self.max_tgt_length].lower()
+        src_item = self.encode_text(self.prompt)
+        tgt_item = self.encode_text(" {}".format(tgt_caption))
+
+        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
+        target_item = torch.cat([tgt_item, self.eos_item])
+        prev_output_item = torch.cat([self.bos_item, tgt_item])
+
+        example = {
+            "id": uniq_id,
+            "source": src_item,
+            "patch_image": patch_image,
+            "patch_mask": patch_mask,
+            "target": target_item,
+            "prev_output_tokens": prev_output_item,
+        }
+        return example
+
+    def collater(self, samples, pad_to_length=None):
+        """Merge a list of samples to form a mini-batch.
+        Args:
+            samples (List[dict]): samples to collate
+        Returns:
+            dict: a mini-batch containing the data required for the task
+        """
+        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
diff --git a/requirements.txt b/requirements.txt
index 00e0135c..52a98e96 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,4 +11,6 @@ datasets
 rouge_score
 soundfile
 editdistance
-librosa
\ No newline at end of file
+librosa
+python-Levenshtein
+zhconv
\ No newline at end of file
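
Note: the snippet below is an illustration only, not part of the diff. It replays the non-document branch of ocr_resize above on a made-up 400x100 text-line crop with patch_image_size=224, to show how an image is resized with its aspect ratio preserved and then edge-padded into a square patch.

import random
from PIL import Image
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F

patch_image_size = 224
img = Image.new("RGB", (400, 100), color="white")  # hypothetical wide text-line crop

width, height = img.size
new_width = max(64, patch_image_size)                           # 224
new_height = max(64, int(patch_image_size * (height / width)))  # 64
resized = F.resize(img, [new_height, new_width], interpolation=InterpolationMode.BICUBIC)

# pad the short side with edge pixels so the result is square
top = random.randint(0, patch_image_size - new_height)
bottom = patch_image_size - new_height - top
padded = F.pad(resized, padding=[0, top, 0, bottom], padding_mode="edge")
assert padded.size == (patch_image_size, patch_image_size)
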
diff --git a/tasks/mm_tasks/ocr.py b/tasks/mm_tasks/ocr.py
new file mode 100644
index 00000000..96d543b3
--- /dev/null
+++ b/tasks/mm_tasks/ocr.py
@@ -0,0 +1,146 @@
+# Copyright 2022 The OFA-Sys Team.
+# All rights reserved.
+# This source code is licensed under the Apache 2.0 license
+# found in the LICENSE file in the root directory.
+
+import torch
+from dataclasses import dataclass, field
+import json
+import logging
+from typing import Optional
+from argparse import Namespace
+import Levenshtein
+from fairseq import metrics, utils
+from fairseq.tasks import register_task
+
+from tasks.ofa_task import OFATask, OFAConfig
+from data.mm_data.ocr_dataset import OcrDataset
+from data.file_dataset import FileDataset
+
+EVAL_BLEU_ORDER = 4
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class OcrConfig(OFAConfig):
+    is_document: bool = field(
+        default=False, metadata={"help": "enable special resizing for document data."}
+    )
+    eval_args: Optional[str] = field(
+        default="{}",
+        metadata={
+            "help": 'generation args, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string'
+        },
+    )
+    eval_print_samples: bool = field(
+        default=False, metadata={"help": "print sample generations during validation"}
+    )
+
+
+@register_task("ocr", dataclass=OcrConfig)
+class OcrTask(OFATask):
+    def __init__(self, cfg: OcrConfig, src_dict, tgt_dict):
+        super().__init__(cfg, src_dict, tgt_dict)
+
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
+        paths = self.cfg.data.split(",")
+        assert len(paths) > 0
+
+        if split == 'train':
+            file_path = paths[(epoch - 1) % (len(paths) - 1)]
+        else:
+            file_path = paths[-1]
+        dataset = FileDataset(file_path, self.cfg.selected_cols)
+
+        self.datasets[split] = OcrDataset(
+            split,
+            dataset,
+            self.bpe,
+            self.src_dict,
+            self.tgt_dict,
+            max_src_length=self.cfg.max_src_length,
+            max_tgt_length=self.cfg.max_tgt_length,
+            patch_image_size=self.cfg.patch_image_size,
+            imagenet_default_mean_and_std=self.cfg.imagenet_default_mean_and_std,
+            is_document=self.cfg.is_document,
+        )
+
+    def build_model(self, cfg):
+        model = super().build_model(cfg)
+
+        gen_args = json.loads(self.cfg.eval_args)
+        self.sequence_generator = self.build_generator([model], Namespace(**gen_args))
+
+        return model
+
+    def valid_step(self, sample, model, criterion):
+        loss, sample_size, logging_output = criterion(model, sample)
+
+        model.eval()
+        hyps, refs = self._inference(self.sequence_generator, sample, model)
+        acc = [1.0 if hyp == ref else 0.0 for hyp, ref in zip(hyps, refs)]
+        # length-normalized edit distance; the extra 1 guards against empty strings
+        distance = [
+            Levenshtein.distance(hyp, ref) / max(len(hyp), len(ref), 1)
+            for hyp, ref in zip(hyps, refs)
+        ]
+        logging_output["_acc_sum"] = sum(acc)
+        logging_output["_acc_cnt"] = len(acc)
+        logging_output["_dist_sum"] = sum(distance)
+        logging_output["_dist_cnt"] = len(distance)
+
+        return loss, sample_size, logging_output
+
+    def reduce_metrics(self, logging_outputs, criterion):
+        super().reduce_metrics(logging_outputs, criterion)
+
+        def sum_logs(key):
+            result = sum(log.get(key, 0) for log in logging_outputs)
+            if torch.is_tensor(result):
+                result = result.cpu()
+            return result
+
+        def compute_acc(meters):
+            score = meters["_acc_sum"].sum / meters["_acc_cnt"].sum
+            score = score if isinstance(score, float) else score.item()
+            return round(score, 4)
+
+        def compute_ned(meters):
+            score = meters["_dist_sum"].sum / meters["_dist_cnt"].sum
+            score = score if isinstance(score, float) else score.item()
+            score = 1.0 - score
+            return round(score, 4)
+
+        if sum_logs("_acc_cnt") > 0:
+            metrics.log_scalar("_acc_sum", sum_logs("_acc_sum"))
+            metrics.log_scalar("_acc_cnt", sum_logs("_acc_cnt"))
+            metrics.log_derived("acc", compute_acc)
+            metrics.log_scalar("_dist_sum", sum_logs("_dist_sum"))
+            metrics.log_scalar("_dist_cnt", sum_logs("_dist_cnt"))
+            metrics.log_derived("ned", compute_ned)
+
+    def _inference(self, generator, sample, model):
+        def decode(toks, escape_unk=False):
+            s = self.tgt_dict.string(
+                toks.int().cpu(),
+                unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
+            )
+            if self.bpe:
+                s = self.bpe.decode(s)
+            return s
+
+        gen_out = self.inference_step(generator, [model], sample)
+        hyps, refs = [], []
+        for i in range(len(gen_out)):
+            decode_tokens = decode(gen_out[i][0]["tokens"])
+            hyps.append(decode_tokens.strip().replace(" ", ""))
+            refs.append(
+                decode(
+                    utils.strip_pad(sample["target"][i], self.tgt_dict.pad()),
+                    escape_unk=True,
+                )
+                .strip()
+                .replace(" ", "")
+            )
+        if self.cfg.eval_print_samples:
+            logger.info("example hypothesis: " + hyps[0])
+            logger.info("example reference: " + refs[0])
+
+        return hyps, refs
\ No newline at end of file
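
For reference (illustration only, with made-up strings): the validation metrics logged by OcrTask reduce to per-sample exact match ("acc") and one minus the length-normalized Levenshtein distance ("ned").

import Levenshtein  # provided by the python-Levenshtein requirement

hyps = ["你好世界", "hel1o"]  # hypothetical model outputs
refs = ["你好世界", "hello"]

acc = sum(1.0 if h == r else 0.0 for h, r in zip(hyps, refs)) / len(refs)
ned = sum(Levenshtein.distance(h, r) / max(len(h), len(r), 1)
          for h, r in zip(hyps, refs)) / len(refs)

print(acc)        # 0.5
print(1.0 - ned)  # 0.9 -- the "ned" score reported by reduce_metrics
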
diff --git a/utils/eval_utils.py b/utils/eval_utils.py
index a621004c..a8bd9f7c 100644
--- a/utils/eval_utils.py
+++ b/utils/eval_utils.py
@@ -11,6 +11,7 @@ import torch
 import torch.distributed as dist
 
+from fairseq import utils
 from data import data_utils
 from tasks.nlg_tasks.gigaword import fix_tokenization
@@ -54,6 +55,51 @@ def eval_caption(task, generator, models, sample, **kwargs):
     return results, None
 
 
+def eval_caption_cn(task, generator, models, sample, **kwargs):
+    hypos = task.inference_step(generator, models, sample)
+    results = []
+    for i, sample_id in enumerate(sample["id"].tolist()):
+        detok_hypo_str = decode_fn(
+            hypos[i][0]["tokens"], task.tgt_dict, task.bpe, generator
+        )
+        results.append(
+            {
+                "image_id": str(sample_id),
+                "caption": detok_hypo_str.strip(),
+            }
+        )
+    return results, None
+
+
+def eval_ocr(task, generator, models, sample, **kwargs):
+    gen_out = task.inference_step(generator, models, sample)
+    hyps, refs, results = [], [], []
+    for i, sample_id in enumerate(sample["id"].tolist()):
+        decode_tokens = decode_fn(gen_out[i][0]["tokens"], task.tgt_dict, task.bpe, generator).strip()
+        hyps.append(decode_tokens.replace(" ", ""))
+        if sample.get("target", None) is not None:
+            refs.append(
+                decode_fn(
+                    utils.strip_pad(sample["target"][i], task.tgt_dict.pad()),
+                    task.tgt_dict, task.bpe, generator
+                )
+                .strip()
+                .replace(" ", "")
+            )
+        results.append(
+            {
+                "image_id": str(sample_id),
+                "ocr": decode_tokens.replace(" ", ""),
+            }
+        )
+    if refs:
+        acc = [1.0 if hyp == ref else 0.0 for hyp, ref in zip(hyps, refs)]
+    else:
+        acc = None
+
+    return results, acc
+
+
 def eval_vqa_gen(task, generator, models, sample, **kwargs):
     if kwargs['beam_search_vqa_eval']:
         hypos = task.inference_step(generator, models, sample, prefix_tokens=sample['prefix_tokens'])
@@ -330,6 +376,10 @@ def eval_asr(task, generator, models, sample, **kwargs):
 def eval_step(task, generator, models, sample, **kwargs):
     if task.cfg._name == 'caption':
         return eval_caption(task, generator, models, sample, **kwargs)
+    elif task.cfg._name == "caption_cn":
+        return eval_caption_cn(task, generator, models, sample, **kwargs)
+    elif task.cfg._name == "ocr":
+        return eval_ocr(task, generator, models, sample, **kwargs)
     elif task.cfg._name == 'vqa_gen':
         return eval_vqa_gen(task, generator, models, sample, **kwargs)
     elif task.cfg._name == 'refcoco':
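
Usage note (a sketch under assumptions, not part of the diff): OcrDataset.__getitem__ expects each FileDataset row to yield (uniq_id, image, caption), with the image given as a urlsafe base64 string. Assuming the usual tab-separated OFA data files, a row for a hypothetical sample could be prepared like this; the placeholder image and caption are invented.

import base64
from io import BytesIO
from PIL import Image

img = Image.new("RGB", (120, 32), color="white")  # placeholder image
buf = BytesIO()
img.save(buf, format="JPEG")
img_b64 = base64.urlsafe_b64encode(buf.getvalue()).decode("utf-8")

# uniq_id, base64 image, caption -- matching the unpacking in OcrDataset.__getitem__
row = "\t".join(["0", img_b64, "你好世界"])
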