Merge branch 'main' into main
Maggione committed Dec 9, 2022
2 parents 4cd2a78 + 5fd17e7 commit 70e35fb
Showing 4 changed files with 403 additions and 1 deletion.
204 changes: 204 additions & 0 deletions data/mm_data/ocr_dataset.py
@@ -0,0 +1,204 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from io import BytesIO

import logging
import warnings
import random
import functools

import numpy as np  # used by collate() below
import torch
import base64
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F

from PIL import Image, ImageFile

from zhconv import convert
import unicodedata

from data import data_utils
from data.ofa_dataset import OFADataset

ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None

logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def collate(samples, pad_idx, eos_idx):
    if len(samples) == 0:
        return {}

    def merge(key):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=eos_idx,
        )

    id = np.array([s["id"] for s in samples])
    src_tokens = merge("source")
    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])

    patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
    patch_masks = torch.cat([sample['patch_mask'] for sample in samples])

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge("target")
        tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens")
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
            "patch_images": patch_images,
            "patch_masks": patch_masks,
            "prev_output_tokens": prev_output_tokens
        },
        "target": target,
    }

    return batch
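
For reference, a minimal sketch (hypothetical toy tensors, pad_idx=1, eos_idx=2) of how two samples shaped like the items OcrDataset builds below pass through collate:

# Hypothetical two-sample batch; each dict mirrors what OcrDataset.__getitem__ returns.
samples = [
    {"id": 0, "source": torch.tensor([0, 5, 2]), "patch_image": torch.zeros(3, 224, 224),
     "patch_mask": torch.tensor([True]), "target": torch.tensor([7, 8, 2]),
     "prev_output_tokens": torch.tensor([0, 7, 8])},
    {"id": 1, "source": torch.tensor([0, 6, 2]), "patch_image": torch.zeros(3, 224, 224),
     "patch_mask": torch.tensor([True]), "target": torch.tensor([9, 2]),
     "prev_output_tokens": torch.tensor([0, 9])},
]
batch = collate(samples, pad_idx=1, eos_idx=2)
# batch["net_input"]["src_tokens"] has shape (2, 3); the shorter target is right-padded
# with pad_idx, and batch["ntokens"] counts the non-pad target tokens (here 3 + 2 = 5).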


def ocr_resize(img, patch_image_size, is_document=False):
    img = img.convert("RGB")
    width, height = img.size

    if is_document:
        new_height, new_width = 64, 1920
    else:
        if width >= height:
            new_width = max(64, patch_image_size)
            new_height = max(64, int(patch_image_size * (height / width)))
            top = random.randint(0, patch_image_size - new_height)
            bottom = patch_image_size - new_height - top
            left, right = 0, 0
        else:
            new_height = max(64, patch_image_size)
            new_width = max(64, int(patch_image_size * (width / height)))
            left = random.randint(0, patch_image_size - new_width)
            right = patch_image_size - new_width - left
            top, bottom = 0, 0

    img_new = F.resize(
        img,
        [new_height, new_width],
        interpolation=InterpolationMode.BICUBIC,
    )

    if is_document:
        img_split = transforms.ToTensor()(img_new).chunk(4, dim=-1)
        img_new = transforms.ToPILImage()(torch.cat(img_split, dim=-2))
        new_width, new_height = img_new.size
        top = random.randint(0, patch_image_size - new_height)
        bottom = patch_image_size - new_height - top
        left, right = 0, 0

    img_new = F.pad(img_new, padding=[left, top, right, bottom], padding_mode="edge")
    assert img_new.size == (patch_image_size, patch_image_size)

    return img_new
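
For intuition, a minimal sketch (hypothetical image size) of the non-document path: the longer side is scaled to patch_image_size and the shorter side is edge-padded to a square at a random offset:

# Hypothetical 300x100 scene-text crop with patch_image_size=224.
img = Image.new("RGB", (300, 100), color=(255, 255, 255))
out = ocr_resize(img, 224, is_document=False)
# Width is scaled to 224, height to int(224 * 100 / 300) = 74, then padded to a 224x224 square.
assert out.size == (224, 224)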


class OcrDataset(OFADataset):
    def __init__(
        self,
        split,
        dataset,
        bpe,
        src_dict,
        tgt_dict=None,
        max_src_length=80,
        max_tgt_length=30,
        patch_image_size=224,
        imagenet_default_mean_and_std=False,
        is_document=False,
    ):
        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        self.patch_image_size = patch_image_size

        if imagenet_default_mean_and_std:
            mean = IMAGENET_DEFAULT_MEAN
            std = IMAGENET_DEFAULT_STD
        else:
            mean = [0.5, 0.5, 0.5]
            std = [0.5, 0.5, 0.5]

        self.patch_resize_transform = transforms.Compose(
            [
                lambda image: ocr_resize(
                    image, patch_image_size, is_document=is_document
                ),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
        )

        self.bpe = bpe
        if type(bpe).__name__ == 'GPT2BPE':
            self.prompt = " what are the texts on the image?"
        elif type(bpe).__name__ == 'BertBPE':
            self.prompt = "图片上的文字是什么?"

    def __getitem__(self, index):
        uniq_id, image, caption = self.dataset[index]

        image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
        patch_image = self.patch_resize_transform(image)
        patch_mask = torch.tensor([True])

        caption = unicodedata.normalize("NFKC", convert(caption, "zh-hans"))
        if type(self.bpe).__name__ == 'GPT2BPE':
            caption_token_list = caption.lower().strip().split()
            tgt_caption = ' '.join(caption_token_list[:self.max_tgt_length])
        elif type(self.bpe).__name__ == 'BertBPE':
            tgt_caption = caption[: self.max_tgt_length].lower()
        src_item = self.encode_text(self.prompt)
        tgt_item = self.encode_text(" {}".format(tgt_caption))

        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        target_item = torch.cat([tgt_item, self.eos_item])
        prev_output_item = torch.cat([self.bos_item, tgt_item])

        example = {
            "id": uniq_id,
            "source": src_item,
            "patch_image": patch_image,
            "patch_mask": patch_mask,
            "target": target_item,
            "prev_output_tokens": prev_output_item,
        }
        return example

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch containing the data required for the task
        """
        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
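
As a standalone reference for the label preprocessing in __getitem__ above, a minimal sketch (hypothetical caption text) of the Traditional-to-Simplified conversion followed by Unicode NFKC normalization applied before BPE encoding:

import unicodedata
from zhconv import convert

caption = "圖片上的文字：ＡＢＣ１２３"  # hypothetical raw label
caption = unicodedata.normalize("NFKC", convert(caption, "zh-hans"))
print(caption)  # -> 图片上的文字:ABC123 (simplified characters, half-width punctuation and letters)
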
4 changes: 3 additions & 1 deletion requirements.txt
@@ -11,4 +11,6 @@ datasets
rouge_score
soundfile
editdistance
librosa
librosa
python-Levenshtein
zhconv
146 changes: 146 additions & 0 deletions tasks/mm_tasks/ocr.py
@@ -0,0 +1,146 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import torch
from dataclasses import dataclass, field
import json
import logging
from typing import Optional
from argparse import Namespace
import Levenshtein
from fairseq import metrics, utils
from fairseq.tasks import register_task

from tasks.ofa_task import OFATask, OFAConfig
from data.mm_data.ocr_dataset import OcrDataset
from data.file_dataset import FileDataset

EVAL_BLEU_ORDER = 4

logger = logging.getLogger(__name__)


@dataclass
class OcrConfig(OFAConfig):
    is_document: bool = field(
        default=False, metadata={"help": "enable special resizing for document data."}
    )
    eval_args: Optional[str] = field(
        default="{}",
        metadata={
            "help": 'generation args, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string'
        },
    )
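
For illustration, a minimal sketch (hypothetical values; "beam" and "max_len_b" are standard fairseq generation options) of the JSON format that eval_args expects and how build_model below consumes it:

# Hypothetical generation settings; build_model parses self.cfg.eval_args the same way.
gen_args = json.loads('{"beam": 5, "max_len_b": 64}')
generator_args = Namespace(**gen_args)  # then handed to self.build_generator([model], ...)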


@register_task("ocr", dataclass=OcrConfig)
class OcrTask(OFATask):
    def __init__(self, cfg: OcrConfig, src_dict, tgt_dict):
        super().__init__(cfg, src_dict, tgt_dict)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        paths = self.cfg.data.split(",")
        assert len(paths) > 0

        if split == 'train':
            file_path = paths[(epoch - 1) % (len(paths) - 1)]
        else:
            file_path = paths[-1]
        dataset = FileDataset(file_path, self.cfg.selected_cols)

        self.datasets[split] = OcrDataset(
            split,
            dataset,
            self.bpe,
            self.src_dict,
            self.tgt_dict,
            max_src_length=self.cfg.max_src_length,
            max_tgt_length=self.cfg.max_tgt_length,
            patch_image_size=self.cfg.patch_image_size,
            imagenet_default_mean_and_std=self.cfg.imagenet_default_mean_and_std,
            is_document=self.cfg.is_document,
        )

    def build_model(self, cfg):
        model = super().build_model(cfg)

        gen_args = json.loads(self.cfg.eval_args)
        self.sequence_generator = self.build_generator([model], Namespace(**gen_args))

        return model

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = criterion(model, sample)

        model.eval()
        hyps, refs = self._inference(self.sequence_generator, sample, model)
        acc = [1.0 if hyp == ref else 0.0 for hyp, ref in zip(hyps, refs)]
        distance = [
            Levenshtein.distance(hyp, ref) / max(len(hyp), len(ref))
            for hyp, ref in zip(hyps, refs)
        ]
        logging_output["_acc_sum"] = sum(acc)
        logging_output["_acc_cnt"] = len(acc)
        logging_output["_dist_sum"] = sum(distance)
        logging_output["_dist_cnt"] = len(distance)

        return loss, sample_size, logging_output

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

        def sum_logs(key):
            result = sum(log.get(key, 0) for log in logging_outputs)
            if torch.is_tensor(result):
                result = result.cpu()
            return result

        def compute_acc(meters):
            score = meters["_acc_sum"].sum / meters["_acc_cnt"].sum
            score = score if isinstance(score, float) else score.item()
            return round(score, 4)

        def compute_ned(meters):
            score = meters["_dist_sum"].sum / meters["_dist_cnt"].sum
            score = score if isinstance(score, float) else score.item()
            score = 1.0 - score
            return round(score, 4)

        if sum_logs("_acc_cnt") > 0:
            metrics.log_scalar("_acc_sum", sum_logs("_acc_sum"))
            metrics.log_scalar("_acc_cnt", sum_logs("_acc_cnt"))
            metrics.log_derived("acc", compute_acc)
            metrics.log_scalar("_dist_sum", sum_logs("_dist_sum"))
            metrics.log_scalar("_dist_cnt", sum_logs("_dist_cnt"))
            metrics.log_derived("ned", compute_ned)

    def _inference(self, generator, sample, model):
        def decode(toks, escape_unk=False):
            s = self.tgt_dict.string(
                toks.int().cpu(),
                unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
            )
            if self.bpe:
                s = self.bpe.decode(s)
            return s

        gen_out = self.inference_step(generator, [model], sample)
        hyps, refs = [], []
        for i in range(len(gen_out)):
            decode_tokens = decode(gen_out[i][0]["tokens"])
            hyps.append(decode_tokens.strip().replace(" ", ""))
            refs.append(
                decode(
                    utils.strip_pad(sample["target"][i], self.tgt_dict.pad()),
                    escape_unk=True,
                )
                .strip()
                .replace(" ", "")
            )
        if self.cfg.eval_print_samples:
            logger.info("example hypothesis: " + hyps[0])
            logger.info("example reference: " + refs[0])  # refs[i] is a plain string, so log it directly

        return hyps, refs
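
For context, a minimal standalone sketch (hypothetical strings) of the exact-match accuracy and normalized edit distance accumulated in valid_step, and the ned value (1 minus the mean normalized distance) derived in reduce_metrics:

import Levenshtein

hyps = ["图片文字", "hello"]  # hypothetical decoded hypotheses
refs = ["图片文宇", "hello"]  # hypothetical references
acc = [1.0 if h == r else 0.0 for h, r in zip(hyps, refs)]
dist = [Levenshtein.distance(h, r) / max(len(h), len(r)) for h, r in zip(hyps, refs)]
print(sum(acc) / len(acc))          # 0.5   (one exact match out of two)
print(1.0 - sum(dist) / len(dist))  # 0.875 (mean normalized distance is 0.125)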