-
Notifications
You must be signed in to change notification settings - Fork 248
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #298 from OFA-Sys/feature/add_text_recognition
Feature/add text recognition
- Loading branch information
Showing
4 changed files
with
403 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
# Copyright 2022 The OFA-Sys Team. | ||
# All rights reserved. | ||
# This source code is licensed under the Apache 2.0 license | ||
# found in the LICENSE file in the root directory. | ||
|
||
from io import BytesIO | ||
|
||
import logging | ||
import warnings | ||
import random | ||
import functools | ||
|
||
import torch | ||
import base64 | ||
from torchvision import transforms | ||
from torchvision.transforms import InterpolationMode | ||
from torchvision.transforms import functional as F | ||
|
||
from PIL import Image, ImageFile | ||
|
||
from zhconv import convert | ||
import unicodedata | ||
|
||
from data import data_utils | ||
from data.ofa_dataset import OFADataset | ||
|
||
ImageFile.LOAD_TRUNCATED_IMAGES = True | ||
ImageFile.MAX_IMAGE_PIXELS = None | ||
Image.MAX_IMAGE_PIXELS = None | ||
|
||
logger = logging.getLogger(__name__) | ||
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) | ||
|
||
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) | ||
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) | ||
|
||
|
||
def collate(samples, pad_idx, eos_idx): | ||
if len(samples) == 0: | ||
return {} | ||
|
||
def merge(key): | ||
return data_utils.collate_tokens( | ||
[s[key] for s in samples], | ||
pad_idx, | ||
eos_idx=eos_idx, | ||
) | ||
|
||
id = np.array([s["id"] for s in samples]) | ||
src_tokens = merge("source") | ||
src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples]) | ||
|
||
patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0) | ||
patch_masks = torch.cat([sample['patch_mask'] for sample in samples]) | ||
|
||
prev_output_tokens = None | ||
target = None | ||
if samples[0].get("target", None) is not None: | ||
target = merge("target") | ||
tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples]) | ||
ntokens = tgt_lengths.sum().item() | ||
|
||
if samples[0].get("prev_output_tokens", None) is not None: | ||
prev_output_tokens = merge("prev_output_tokens") | ||
else: | ||
ntokens = src_lengths.sum().item() | ||
|
||
batch = { | ||
"id": id, | ||
"nsentences": len(samples), | ||
"ntokens": ntokens, | ||
"net_input": { | ||
"src_tokens": src_tokens, | ||
"src_lengths": src_lengths, | ||
"patch_images": patch_images, | ||
"patch_masks": patch_masks, | ||
"prev_output_tokens": prev_output_tokens | ||
}, | ||
"target": target, | ||
} | ||
|
||
return batch | ||
|
||
|
||
def ocr_resize(img, patch_image_size, is_document=False): | ||
img = img.convert("RGB") | ||
width, height = img.size | ||
|
||
if is_document: | ||
new_height, new_width = 64, 1920 | ||
else: | ||
if width >= height: | ||
new_width = max(64, patch_image_size) | ||
new_height = max(64, int(patch_image_size * (height / width))) | ||
top = random.randint(0, patch_image_size - new_height) | ||
bottom = patch_image_size - new_height - top | ||
left, right = 0, 0 | ||
else: | ||
new_height = max(64, patch_image_size) | ||
new_width = max(64, int(patch_image_size * (width / height))) | ||
left = random.randint(0, patch_image_size - new_width) | ||
right = patch_image_size - new_width - left | ||
top, bottom = 0, 0 | ||
|
||
img_new = F.resize( | ||
img, | ||
[new_height, new_width], | ||
interpolation=InterpolationMode.BICUBIC, | ||
) | ||
|
||
if is_document: | ||
img_split = transforms.ToTensor()(img_new).chunk(4, dim=-1) | ||
img_new = transforms.ToPILImage()(torch.cat(img_split, dim=-2)) | ||
new_width, new_height = img_new.size | ||
top = random.randint(0, patch_image_size - new_height) | ||
bottom = patch_image_size - new_height - top | ||
left, right = 0, 0 | ||
|
||
img_new = F.pad(img_new, padding=[left, top, right, bottom], padding_mode="edge") | ||
assert img_new.size == (patch_image_size, patch_image_size) | ||
|
||
return img_new | ||
|
||
|
||
class OcrDataset(OFADataset): | ||
def __init__( | ||
self, | ||
split, | ||
dataset, | ||
bpe, | ||
src_dict, | ||
tgt_dict=None, | ||
max_src_length=80, | ||
max_tgt_length=30, | ||
patch_image_size=224, | ||
imagenet_default_mean_and_std=False, | ||
is_document=False, | ||
): | ||
super().__init__(split, dataset, bpe, src_dict, tgt_dict) | ||
self.max_src_length = max_src_length | ||
self.max_tgt_length = max_tgt_length | ||
self.patch_image_size = patch_image_size | ||
|
||
if imagenet_default_mean_and_std: | ||
mean = IMAGENET_DEFAULT_MEAN | ||
std = IMAGENET_DEFAULT_STD | ||
else: | ||
mean = [0.5, 0.5, 0.5] | ||
std = [0.5, 0.5, 0.5] | ||
|
||
self.patch_resize_transform = transforms.Compose( | ||
[ | ||
lambda image: ocr_resize( | ||
image, patch_image_size, is_document=is_document | ||
), | ||
transforms.ToTensor(), | ||
transforms.Normalize(mean=mean, std=std), | ||
] | ||
) | ||
|
||
self.bpe = bpe | ||
if type(bpe).__name__ == 'GPT2BPE': | ||
self.prompt = " what are the texts on the image?" | ||
elif type(bpe).__name__ == 'BertBPE': | ||
self.prompt = "图片上的文字是什么?" | ||
|
||
def __getitem__(self, index): | ||
uniq_id, image, caption = self.dataset[index] | ||
|
||
image = Image.open(BytesIO(base64.urlsafe_b64decode(image))) | ||
patch_image = self.patch_resize_transform(image) | ||
patch_mask = torch.tensor([True]) | ||
|
||
caption = unicodedata.normalize("NFKC", convert(caption, "zh-hans")) | ||
if type(self.bpe).__name__ == 'GPT2BPE': | ||
caption_token_list = caption.lower().strip().split() | ||
tgt_caption = ' '.join(caption_token_list[:self.max_tgt_length]) | ||
elif type(self.bpe).__name__ == 'BertBPE': | ||
tgt_caption = caption[: self.max_tgt_length].lower() | ||
src_item = self.encode_text(self.prompt) | ||
tgt_item = self.encode_text(" {}".format(tgt_caption)) | ||
|
||
src_item = torch.cat([self.bos_item, src_item, self.eos_item]) | ||
target_item = torch.cat([tgt_item, self.eos_item]) | ||
prev_output_item = torch.cat([self.bos_item, tgt_item]) | ||
|
||
example = { | ||
"id": uniq_id, | ||
"source": src_item, | ||
"patch_image": patch_image, | ||
"patch_mask": patch_mask, | ||
"target": target_item, | ||
"prev_output_tokens": prev_output_item, | ||
} | ||
return example | ||
|
||
def collater(self, samples, pad_to_length=None): | ||
"""Merge a list of samples to form a mini-batch. | ||
Args: | ||
samples (List[dict]): samples to collate | ||
Returns: | ||
dict: a mini-batch containing the data required for the task | ||
""" | ||
return collate(samples, pad_idx=self.pad, eos_idx=self.eos) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,4 +8,6 @@ pycocoevalcap==1.2 | |
pytorch_lightning | ||
einops | ||
datasets | ||
rouge_score | ||
rouge_score | ||
python-Levenshtein | ||
zhconv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# Copyright 2022 The OFA-Sys Team. | ||
# All rights reserved. | ||
# This source code is licensed under the Apache 2.0 license | ||
# found in the LICENSE file in the root directory. | ||
|
||
import torch | ||
from dataclasses import dataclass, field | ||
import json | ||
import logging | ||
from typing import Optional | ||
from argparse import Namespace | ||
import Levenshtein | ||
from fairseq import metrics, utils | ||
from fairseq.tasks import register_task | ||
|
||
from tasks.ofa_task import OFATask, OFAConfig | ||
from data.mm_data.ocr_dataset import OcrDataset | ||
from data.file_dataset import FileDataset | ||
|
||
EVAL_BLEU_ORDER = 4 | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@dataclass | ||
class OcrConfig(OFAConfig): | ||
is_document: bool = field( | ||
default=False, metadata={"help": "enable special resizing for document data."} | ||
) | ||
eval_args: Optional[str] = field( | ||
default="{}", | ||
metadata={ | ||
"help": 'generation args, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string' | ||
}, | ||
) | ||
|
||
|
||
@register_task("ocr", dataclass=OcrConfig) | ||
class OcrTask(OFATask): | ||
def __init__(self, cfg: OcrConfig, src_dict, tgt_dict): | ||
super().__init__(cfg, src_dict, tgt_dict) | ||
|
||
def load_dataset(self, split, epoch=1, combine=False, **kwargs): | ||
paths = self.cfg.data.split(",") | ||
assert len(paths) > 0 | ||
|
||
if split == 'train': | ||
file_path = paths[(epoch - 1) % (len(paths) - 1)] | ||
else: | ||
file_path = paths[-1] | ||
dataset = FileDataset(file_path, self.cfg.selected_cols) | ||
|
||
self.datasets[split] = OcrDataset( | ||
split, | ||
dataset, | ||
self.bpe, | ||
self.src_dict, | ||
self.tgt_dict, | ||
max_src_length=self.cfg.max_src_length, | ||
max_tgt_length=self.cfg.max_tgt_length, | ||
patch_image_size=self.cfg.patch_image_size, | ||
imagenet_default_mean_and_std=self.cfg.imagenet_default_mean_and_std, | ||
is_document=self.cfg.is_document, | ||
) | ||
|
||
def build_model(self, cfg): | ||
model = super().build_model(cfg) | ||
|
||
gen_args = json.loads(self.cfg.eval_args) | ||
self.sequence_generator = self.build_generator([model], Namespace(**gen_args)) | ||
|
||
return model | ||
|
||
def valid_step(self, sample, model, criterion): | ||
loss, sample_size, logging_output = criterion(model, sample) | ||
|
||
model.eval() | ||
hyps, refs = self._inference(self.sequence_generator, sample, model) | ||
acc = [1.0 if hyp == ref else 0.0 for hyp, ref in zip(hyps, refs)] | ||
distance = [ | ||
Levenshtein.distance(hyp, ref) / max(len(hyp), len(ref)) | ||
for hyp, ref in zip(hyps, refs) | ||
] | ||
logging_output["_acc_sum"] = sum(acc) | ||
logging_output["_acc_cnt"] = len(acc) | ||
logging_output["_dist_sum"] = sum(distance) | ||
logging_output["_dist_cnt"] = len(distance) | ||
|
||
return loss, sample_size, logging_output | ||
|
||
def reduce_metrics(self, logging_outputs, criterion): | ||
super().reduce_metrics(logging_outputs, criterion) | ||
|
||
def sum_logs(key): | ||
result = sum(log.get(key, 0) for log in logging_outputs) | ||
if torch.is_tensor(result): | ||
result = result.cpu() | ||
return result | ||
|
||
def compute_acc(meters): | ||
score = meters["_acc_sum"].sum / meters["_acc_cnt"].sum | ||
score = score if isinstance(score, float) else score.item() | ||
return round(score, 4) | ||
|
||
def compute_ned(meters): | ||
score = meters["_dist_sum"].sum / meters["_dist_cnt"].sum | ||
score = score if isinstance(score, float) else score.item() | ||
score = 1.0 - score | ||
return round(score, 4) | ||
|
||
if sum_logs("_acc_cnt") > 0: | ||
metrics.log_scalar("_acc_sum", sum_logs("_acc_sum")) | ||
metrics.log_scalar("_acc_cnt", sum_logs("_acc_cnt")) | ||
metrics.log_derived("acc", compute_acc) | ||
metrics.log_scalar("_dist_sum", sum_logs("_dist_sum")) | ||
metrics.log_scalar("_dist_cnt", sum_logs("_dist_cnt")) | ||
metrics.log_derived("ned", compute_ned) | ||
|
||
def _inference(self, generator, sample, model): | ||
def decode(toks, escape_unk=False): | ||
s = self.tgt_dict.string( | ||
toks.int().cpu(), | ||
unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"), | ||
) | ||
if self.bpe: | ||
s = self.bpe.decode(s) | ||
return s | ||
|
||
gen_out = self.inference_step(generator, [model], sample) | ||
hyps, refs = [], [] | ||
for i in range(len(gen_out)): | ||
decode_tokens = decode(gen_out[i][0]["tokens"]) | ||
hyps.append(decode_tokens.strip().replace(" ", "")) | ||
refs.append( | ||
decode( | ||
utils.strip_pad(sample["target"][i], self.tgt_dict.pad()), | ||
escape_unk=True, | ||
) | ||
.strip() | ||
.replace(" ", "") | ||
) | ||
if self.cfg.eval_print_samples: | ||
logger.info("example hypothesis: " + hyps[0]) | ||
logger.info("example reference: " + ' && '.join(refs[0])) | ||
|
||
return hyps, refs |
Oops, something went wrong.