Skip to content

add ABINet [WIP, don't merge yet] #385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
313 changes: 313 additions & 0 deletions configs/rec/abinet/README.md

Large diffs are not rendered by default.

327 changes: 327 additions & 0 deletions configs/rec/abinet/README_CN.md

Large diffs are not rendered by default.

110 changes: 110 additions & 0 deletions configs/rec/abinet/abinet_resnet45_en.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
system:
  mode: 0  # 0 for graph mode, 1 for pynative mode in MindSpore
  distribute: True
  amp_level: 'O0'
  seed: 42
  log_interval: 100
  val_while_train: False
  drop_overflow_update: False

common:
  character_dict_path: &character_dict_path
  num_classes: &num_classes 37
  max_text_len: &max_text_len 25
  infer_mode: &infer_mode False
  use_space_char: &use_space_char False
  batch_size: &batch_size 96

model:
  type: rec
  pretrained: "./tmp_rec/pretrain.ckpt"
  transform: null
  backbone:
    name: abinet_backbone
    pretrained: False
    batchsize: *batch_size
  head:
    name: ABINetHead
    batchsize: *batch_size

postprocess:
  name: ABINetLabelDecode

metric:
  name: RecMetric
  main_indicator: acc
  character_dict_path: *character_dict_path
  ignore_space: True
  print_flag: False
  filter_ood: False

loss:
  name: ABINetLoss


scheduler:
  scheduler: step_decay
  decay_rate: 0.1
  decay_epochs: 6
  warmup_epochs: 0
  lr: 0.0001
  num_epochs: 10


optimizer:
  opt: adam


train:
  clip_grad: True
  clip_norm: 20.0
  ckpt_save_dir: './tmp_rec'
  dataset_sink_mode: False
  dataset:
    type: LMDBDataset
    dataset_root: path/to/data_lmdb_release/
    data_dir: train/
    # label_files: # not required when using LMDBDataset
    sample_ratio: 1.0
    shuffle: True
    transform_pipeline:
      - ABINetTransforms:
      - ABINetRecAug:
      - NormalizeImage:
          is_hwc: False
          mean: [0.485, 0.456, 0.406]
          # NOTE(review): std repeats the ImageNet *mean* values; the standard
          # ImageNet std is [0.229, 0.224, 0.225] — confirm this is intentional.
          std: [0.485, 0.456, 0.406]
    # the order of the dataloader list, matching the network input and the input
    # labels for the loss function, and optional data for debug/visualize
    output_columns: ['image','label','length','label_for_mask'] #'img_path']

  loader:
    shuffle: True  # TODO: tbc
    batch_size: *batch_size
    drop_remainder: True
    max_rowsize: 128
    num_workers: 20

eval:
  ckpt_load_path: ./tmp_rec/best.ckpt
  dataset_sink_mode: False
  dataset:
    type: LMDBDataset
    dataset_root: path/to/data_lmdb_release/
    data_dir: evaluation/
    # label_files: # not required when using LMDBDataset
    sample_ratio: 1.0
    shuffle: False
    transform_pipeline:
      - ABINetEvalTransforms:
      - ABINetEval:
    # the order of the dataloader list, matching the network input and the input
    # labels for the loss function, and optional data for debug/visualize
    output_columns: ['image','label','length','label_for_mask'] # TODO return text string padding w/ fixed length, and a scalar to indicate the length
    net_input_column_index: [0]  # input indices for network forward func in output_columns
    label_column_index: [1, 2]  # input indices marked as label

  loader:
    shuffle: False  # TODO: tbc
    batch_size: *batch_size
    drop_remainder: False
    max_rowsize: 128
    num_workers: 8
226 changes: 226 additions & 0 deletions mindocr/data/transforms/rec_abinet_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
"""
transform for text recognition tasks.
"""
import copy
import logging
import random
import re
import warnings

import cv2
import numpy as np
import PIL
import six
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not use six! six is needed for backward compatibility with python 2.x. Our minimum supported version of python is 3.7!

from PIL import Image

import mindspore.dataset as ds

from ...models.utils.abinet_layers import CharsetMapper, onehot
from .svtr_transform import (
CVColorJitter,
CVGaussianNoise,
CVMotionBlur,
CVRandomAffine,
CVRandomPerspective,
CVRandomRotation,
CVRescale,
)

_logger = logging.getLogger(__name__)
__all__ = ["ABINetTransforms", "ABINetRecAug", "ABINetEval", "ABINetEvalTransforms"]


class ABINetTransforms(object):
    """Decode an LMDB sample into the tensors expected by ABINet training.

    Filters the label to alphanumerics (truncated to 25 chars), decodes the
    raw image bytes to an RGB array, and builds the one-hot label plus the
    mask label consumed by ``ABINetLoss``.

    Input keys: ``img_lmdb`` (raw image bytes), ``label`` (text string).
    Output keys: ``image``, ``label`` (one-hot), ``length``, ``label_for_mask``.
    """

    def __init__(
        self,
    ):
        # ABINet_Transforms
        self.case_sensitive = False
        # 26 = max_text_len (25) + 1 slot for the end-of-sequence token.
        self.charset = CharsetMapper(max_length=26)

    def __call__(self, data: dict):
        img_lmdb = data["img_lmdb"]
        label = data["label"]
        # Normalize the label to a unicode string.
        label = label.encode("utf-8")
        label = str(label, "utf-8")
        try:
            label = re.sub("[^0-9a-zA-Z]+", "", label)
            if len(label) > 25 or len(label) <= 0:
                string_false2 = f"len(label) > 25 or len(label) <= 0: {label}, {len(label)}"
                _logger.warning(string_false2)
            label = label[:25]
            # io.BytesIO instead of six.BytesIO: six exists only for Python 2
            # compatibility and the project supports Python >= 3.7.
            buf = io.BytesIO(img_lmdb)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", UserWarning)
                image = PIL.Image.open(buf).convert("RGB")
            if not _check_image(image, pixels=6):
                string_false1 = f"_check_image false: {label}, {len(label)}"
                _logger.warning(string_false1)
        except Exception:
            # NOTE(review): when decoding fails before `image` is assigned,
            # np.array(image) below raises NameError — confirm whether
            # corrupted samples should be skipped instead of crashing.
            string_false = f"Corrupted image is found: {label}, {len(label)}"
            _logger.warning(string_false)

        image = np.array(image)

        text = label

        # +1 accounts for the end-of-sequence token appended by the charset.
        length = len(text) + 1
        length = float(length)

        label = self.charset.get_labels(text, case_sensitive=self.case_sensitive)
        label_for_mask = copy.deepcopy(label)
        label_for_mask[int(length - 1)] = 1
        label = onehot(label, self.charset.num_classes)
        data_dict = {"image": image, "label": label, "length": length, "label_for_mask": label_for_mask}
        return data_dict


class ABINetRecAug(object):
    """Training-time augmentation for ABINet.

    Applies random geometry, deterioration and color-jitter transforms,
    resizes the result to 128x32, and converts it to a CHW float tensor.
    """

    def __init__(self):
        geometry = CVGeometry(
            degrees=45,
            translate=(0.0, 0.0),
            scale=(0.5, 2.0),
            shear=(45, 15),
            distortion=0.5,
            p=0.5,
        )
        deterioration = CVDeterioration(var=20, degrees=6, factor=4, p=0.25)
        color_jitter = CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25)
        self.transforms = ds.transforms.Compose([geometry, deterioration, color_jitter])
        self.toTensor = ds.vision.ToTensor()
        # Target width/height expected by the ABINet backbone.
        self.w = 128
        self.h = 32

    def __call__(self, data):
        augmented = self.transforms(data["image"])
        resized = cv2.resize(augmented, (self.w, self.h))
        data["image"] = self.toTensor(resized)
        return data


def _check_image(x, pixels=6):
if x.size[0] <= pixels or x.size[1] <= pixels:
return False
else:
return True


class ABINetEvalTransforms(object):
    """Decode an LMDB sample for ABINet evaluation.

    Filters the label to alphanumerics (truncated to 25 chars) and decodes
    the raw image bytes to an RGB array. Unlike the training transform, the
    label is returned as plain text rather than one-hot.

    Input keys: ``img_lmdb`` (raw image bytes), ``label`` (text string).
    Output keys: ``image``, ``label`` (text), ``length``.
    """

    def __init__(
        self,
    ):
        # ABINet_Transforms
        self.case_sensitive = False
        # 26 = max_text_len (25) + 1 slot for the end-of-sequence token.
        self.charset = CharsetMapper(max_length=26)

    def __call__(self, data: dict):
        img_lmdb = data["img_lmdb"]
        label = data["label"]
        # Normalize the label to a unicode string.
        label = label.encode("utf-8")
        label = str(label, "utf-8")
        try:
            label = re.sub("[^0-9a-zA-Z]+", "", label)
            if len(label) > 25 or len(label) <= 0:
                # Fixed message typo: "en(label)" -> "len(label)".
                string_false2 = f"len(label) > 25 or len(label) <= 0: {label}, {len(label)}"
                _logger.warning(string_false2)
            label = label[:25]
            # io.BytesIO instead of six.BytesIO: six exists only for Python 2
            # compatibility and the project supports Python >= 3.7.
            buf = io.BytesIO(img_lmdb)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", UserWarning)
                image = PIL.Image.open(buf).convert("RGB")
            if not _check_image(image, pixels=6):
                string_false1 = f"_check_image false: {label}, {len(label)}"
                _logger.warning(string_false1)
        except Exception:
            # NOTE(review): when decoding fails before `image` is assigned,
            # np.array(image) below raises NameError — confirm whether
            # corrupted samples should be skipped instead of crashing.
            string_false = f"Corrupted image is found: {label}, {len(label)}"
            _logger.warning(string_false)

        image = np.array(image)

        text = label
        # +1 accounts for the end-of-sequence token.
        length = len(text) + 1
        length = float(length)
        data_dict = {"image": image, "label": text, "length": length}
        return data_dict


class ABINetEval(object):
    """Evaluation preprocessing: resize to 128x32, convert to a CHW tensor,
    and cast the sample length to int."""

    def __init__(self):
        self.toTensor = ds.vision.ToTensor()
        # Target width/height expected by the ABINet backbone.
        self.w = 128
        self.h = 32

    def __call__(self, data):
        resized = cv2.resize(data["image"], (self.w, self.h))
        data["image"] = self.toTensor(resized)
        data["length"] = int(data["length"])
        return data


class CVGeometry(object):
    """Random geometric augmentation (rotation, affine, or perspective),
    applied with probability ``p``.

    NOTE(review): the concrete transform kind is drawn once at construction
    time, so a given instance always applies the same kind of geometry —
    confirm this matches the reference ABINet implementation.
    """

    def __init__(self, degrees=15, translate=(0.3, 0.3), scale=(0.5, 2.0), shear=(45, 15), distortion=0.5, p=0.5):
        self.p = p
        roll = random.random()
        if roll < 0.33:
            chosen = CVRandomRotation(degrees=degrees)
        elif roll < 0.66:
            chosen = CVRandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear)
        else:
            chosen = CVRandomPerspective(distortion=distortion)
        self.transforms = chosen

    def __call__(self, img):
        if random.random() >= self.p:
            return img
        return Image.fromarray(self.transforms(np.array(img)))


class CVDeterioration(object):
    """Image-degradation augmentation (noise, motion blur, rescale) applied
    in a randomly shuffled order with probability ``p``.

    Passing ``None`` for ``var``, ``degrees``, or ``factor`` disables the
    corresponding stage.
    """

    def __init__(self, var, degrees, factor, p=0.5):
        self.p = p
        stages = []
        if var is not None:
            stages.append(CVGaussianNoise(var=var))
        if degrees is not None:
            stages.append(CVMotionBlur(degrees=degrees))
        if factor is not None:
            stages.append(CVRescale(factor=factor))

        # Stage order is randomized once per instance.
        random.shuffle(stages)

        self.transforms = ds.transforms.Compose(stages)

    def __call__(self, img):
        if random.random() >= self.p:
            return img
        return Image.fromarray(self.transforms(np.array(img)))
1 change: 1 addition & 0 deletions mindocr/data/transforms/transforms_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .det_fce_transforms import *
from .det_transforms import *
from .general_transforms import *
from .rec_abinet_transforms import *
from .rec_transforms import *
from .svtr_transform import *

Expand Down
Loading