Add RobustScanner rec model #444

Merged (1 commit, Jul 6, 2023)
2 changes: 1 addition & 1 deletion configs/rec/crnn/README_CN.md
@@ -287,7 +287,7 @@ eval:
```

**Note:**
- Since the global batch size (batch_size x num_devices) is important for reproducing results, when the number of GPU/NPU devices changes, adjust `batch_size` to keep the global batch size unchanged, or linearly adjust the learning rate to the new global batch size
- Since the global batch size (batch_size x num_devices) is important for reproducing results, when the number of GPU/NPU devices changes, adjust `batch_size` to keep the global batch size unchanged, or linearly scale the learning rate according to the new global batch size
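The linear learning-rate scaling mentioned in this note can be sketched as follows; this is a minimal illustration, and the specific numbers below are hypothetical rather than taken from the configs.

```python
def scale_lr(base_lr: float, base_global_batch: int,
             new_batch_size: int, new_num_devices: int) -> float:
    """Linearly rescale the learning rate for a new global batch size."""
    new_global_batch = new_batch_size * new_num_devices
    return base_lr * new_global_batch / base_global_batch

# e.g. a config tuned at lr=0.001 with batch_size=64 on 4 devices:
# doubling the device count (same per-device batch) doubles the lr.
print(scale_lr(0.001, 64 * 4, 64, 8))  # 0.002
```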


### 3.2 Model Training
2 changes: 1 addition & 1 deletion configs/rec/master/README_CN.md
@@ -297,7 +297,7 @@ eval:
```

**Note:**
- Since the global batch size (batch_size x num_devices) is important for reproducing results, when the number of GPU/NPU devices changes, adjust `batch_size` to keep the global batch size unchanged, or linearly adjust the learning rate to the new global batch size
- Since the global batch size (batch_size x num_devices) is important for reproducing results, when the number of GPU/NPU devices changes, adjust `batch_size` to keep the global batch size unchanged, or linearly scale the learning rate according to the new global batch size


### 3.2 Model Training
2 changes: 1 addition & 1 deletion configs/rec/rare/README_CN.md
@@ -262,7 +262,7 @@ eval:
```

**Note:**
- Since the global batch size (batch_size x num_devices) is important for reproducing results, when the number of GPU/NPU devices changes, adjust `batch_size` to keep the global batch size unchanged, or linearly adjust the learning rate to the new global batch size
- Since the global batch size (batch_size x num_devices) is important for reproducing results, when the number of GPU/NPU devices changes, adjust `batch_size` to keep the global batch size unchanged, or linearly scale the learning rate according to the new global batch size


### 3.2 Model Training
389 changes: 389 additions & 0 deletions configs/rec/robustscanner/README.md

Large diffs are not rendered by default.

394 changes: 394 additions & 0 deletions configs/rec/robustscanner/README_CN.md

Large diffs are not rendered by default.

140 changes: 140 additions & 0 deletions configs/rec/robustscanner/robustscanner_resnet31.yaml
@@ -0,0 +1,140 @@
system:
mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
distribute: True
amp_level: 'O0'
seed: 42
log_interval: 100
val_while_train: True
drop_overflow_update: False

common:
character_dict_path: &character_dict_path mindocr/utils/dict/en_dict90.txt
max_text_len: &max_text_len 40
use_space_char: &use_space_char False
batch_size: &batch_size 64

model:
type: rec
transform: null
backbone:
name: rec_resnet31
pretrained: False
head:
name: RobustScannerHead
out_channels: 93 # 90 + unknown + start + padding
enc_outchannles: 128
hybrid_dec_rnn_layers: 2
hybrid_dec_dropout: 0.
position_dec_rnn_layers: 2
start_idx: 91
mask: True
padding_idx: 92
encode_value: False
max_text_len: *max_text_len

postprocess:
name: SARLabelDecode
character_dict_path: *character_dict_path
use_space_char: *use_space_char
rm_symbol: True

metric:
name: RecMetric
main_indicator: acc
character_dict_path: *character_dict_path
ignore_space: True
print_flag: False

loss:
name: SARLoss
ignore_index: 92

scheduler:
scheduler: multi_step_decay
milestones: [6, 8]
decay_rate: 0.1
lr: 0.001
num_epochs: 10
warmup_epochs: 0

optimizer:
opt: adamW
beta1: 0.9
beta2: 0.999

loss_scaler:
type: static
loss_scale: 512

train:
ema: True
ckpt_save_dir: './tmp_rec'
dataset_sink_mode: False
dataset:
type: LMDBDataset
dataset_root: path/to/data/ # Optional, if set, dataset_root will be used as a prefix for data_dir
data_dir: training/
# label_files: # not required when using LMDBDataset
sample_ratio: 1.0
shuffle: True
random_choice_if_none: True # Randomly choose another sample if the data transform returns None
transform_pipeline:
- DecodeImage:
img_mode: BGR
to_float32: False
- SARLabelEncode: # Class handling label
max_text_len: *max_text_len
character_dict_path: *character_dict_path
use_space_char: *use_space_char
lower: True
- RobustScannerRecResizeImg:
image_shape: [ 3, 48, 48, 160 ] # h:48 w:[48,160]
width_downsample_ratio: 0.25
max_text_len: *max_text_len
output_columns: ['image', 'label', 'valid_width_mask', 'word_positions']
net_input_column_index: [0, 1, 2, 3] # input indices for network forward func in output_columns
label_column_index: [1] # input indices marked as label
#keys_for_loss: 4 # num labels for loss func

loader:
shuffle: True # TODO: tbc
batch_size: *batch_size
drop_remainder: True
max_rowsize: 12
num_workers: 8

eval:
ckpt_load_path: ./tmp_rec/best.ckpt
dataset_sink_mode: False
dataset:
type: LMDBDataset
dataset_root: path/to/data/
data_dir: evaluation/
# label_files: # not required when using LMDBDataset
sample_ratio: 1.0
shuffle: False
transform_pipeline:
- DecodeImage:
img_mode: BGR
to_float32: False
- SARLabelEncode: # Class handling label
max_text_len: *max_text_len
# character_dict_path: *character_dict_path
use_space_char: *use_space_char
is_training: False
lower: True
- RobustScannerRecResizeImg:
image_shape: [ 3, 48, 48, 160 ]
width_downsample_ratio: 0.25
max_text_len: *max_text_len
# the order of the dataloader list, matching the network inputs and the labels for the loss function, plus optional data for debug/visualization
output_columns: [ 'image', 'valid_width_mask', 'word_positions', 'text_padded', 'text_length' ]
net_input_column_index: [ 0, 1, 2 ] # input indices for network forward func in output_columns
label_column_index: [3, 4]

loader:
shuffle: False # TODO: tbc
batch_size: 64
drop_remainder: True
max_rowsize: 12
num_workers: 8
20 changes: 19 additions & 1 deletion mindocr/data/rec_lmdb_dataset.py
@@ -27,6 +27,7 @@ class LMDBDataset(BaseDataset):
if None, all data keys will be used for return.
filter_max_len (bool): Filter the records where the label is longer than the `max_text_len`.
max_text_len (int): The maximum text length the dataloader expected.
random_choice_if_none (bool): Randomly choose another sample if the data transform returns None.

Returns:
data (tuple): Depending on the transform pipeline, __get_item__ returns a tuple for the specified data item.
@@ -54,11 +55,13 @@ def __init__(
output_columns: Optional[List[str]] = None,
filter_max_len: bool = False,
max_text_len: Optional[int] = None,
random_choice_if_none: bool = False,
**kwargs: Any,
):
self.data_dir = data_dir
self.filter_max_len = filter_max_len
self.max_text_len = max_text_len
self.random_choice_if_none = random_choice_if_none

shuffle = shuffle if shuffle is not None else is_train

@@ -197,10 +200,25 @@ def __getitem__(self, idx):
lmdb_idx, file_idx = self.data_idx_order_list[idx]
sample_info = self.get_lmdb_sample_info(self.lmdb_sets[int(lmdb_idx)]["txn"], int(file_idx))

if sample_info is None and self.random_choice_if_none:
_logger.warning("sample_info is None, randomly choosing another sample.")
random_idx = np.random.randint(self.__len__())
return self.__getitem__(random_idx)

data = {"img_lmdb": sample_info[0], "label": sample_info[1]}

# perform transformation on data
data = run_transforms(data, transforms=self.transforms)
try:
data = run_transforms(data, transforms=self.transforms)
except Exception as e:
if self.random_choice_if_none:
_logger.warning("Data transform failed; randomly choosing another sample.")
random_idx = np.random.randint(self.__len__())
return self.__getitem__(random_idx)
else:
_logger.warning(f"Error occurred during preprocess.\n {e}")
raise e

output_tuple = tuple(data[k] for k in self.output_columns)

return output_tuple
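The random-resample-on-failure behavior added above can be sketched in isolation. The helper name and the retry cap below are illustrative; the PR's version recurses via `__getitem__` without an explicit cap.

```python
import random

def fetch_with_resample(n_samples, load_fn, max_retries=20):
    """Load a sample; on failure (None result or exception), retry at a random index."""
    idx = random.randrange(n_samples)
    for _ in range(max_retries):
        try:
            sample = load_fn(idx)
            if sample is not None:
                return sample
        except Exception:
            pass  # corrupt record: fall through and resample
        idx = random.randrange(n_samples)
    raise RuntimeError("no valid sample found after retries")

# usage: a toy loader where index 0 is "corrupt"
data = [None] + [f"sample-{i}" for i in range(1, 8)]
print(fetch_with_resample(len(data), lambda i: data[i]))
```

A retry cap is one way to avoid unbounded recursion when a dataset contains many bad records.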
154 changes: 154 additions & 0 deletions mindocr/data/transforms/rec_transforms.py
@@ -19,6 +19,8 @@
"SVTRRecResizeImg",
"Rotate90IfVertical",
"ClsLabelEncode",
"SARLabelEncode",
"RobustScannerRecResizeImg",
]
_logger = logging.getLogger(__name__)

@@ -665,3 +667,155 @@ def __call__(self, data):
data["label"] = label

return data


class SARLabelEncode(object):
"""Convert between text-label and text-index"""
> **Review comment (Collaborator):** SARLabelEncode is very similar to RecAttnLabelEncode; consider merging them.
>
> **Reply (Author):** Checked: SARLabelEncode and RecAttnLabelEncode handle the special cases that return None differently and use different special characters, so merging is not recommended.

def __init__(self, max_text_len, character_dict_path=None, use_space_char=False, lower=False, is_training=True):
self.max_text_len = max_text_len
self.beg_str = "sos"
self.end_str = "eos"
self.lower = lower
self.is_training = is_training

if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
self.lower = True
if self.is_training:
_logger.warning("The character_dict_path is None, model can only recognize number and lower letters")
else:
self.character_str = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode("utf-8").strip("\n").strip("\r\n")
self.character_str.append(line)
if use_space_char:
self.character_str.append(" ")
dict_character = list(self.character_str)
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character

def encode(self, text):
"""Convert a text label into a list of character indices.
input:
text: the text label of one image.

output:
text_list: a list of character indices, or None if the text is empty,
longer than max_text_len, or contains no in-dictionary characters.
"""
if len(text) == 0 or len(text) > self.max_text_len:
return None
if self.lower:
text = text.lower()
text_list = []
for char in text:
if char not in self.dict:
continue
text_list.append(self.dict[char])
if len(text_list) == 0:
return None
return text_list

def add_special_char(self, dict_character):
beg_end_str = "<BOS/EOS>"
unknown_str = "<UKN>"
padding_str = "<PAD>"
dict_character = dict_character + [unknown_str]
self.unknown_idx = len(dict_character) - 1
dict_character = dict_character + [beg_end_str]
self.start_idx = len(dict_character) - 1
self.end_idx = len(dict_character) - 1
dict_character = dict_character + [padding_str]
self.padding_idx = len(dict_character) - 1

return dict_character
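With the 90-character `en_dict90` dictionary, the index layout produced by `add_special_char` lines up with the constants in the YAML (`out_channels: 93`, `start_idx: 91`, `padding_idx: 92`). A standalone sketch of the same logic (the placeholder characters stand in for the real dictionary):

```python
def add_special_tokens(dict_character):
    """Append <UKN>, <BOS/EOS>, <PAD> and record their indices (mirrors add_special_char)."""
    dict_character = dict_character + ["<UKN>"]
    unknown_idx = len(dict_character) - 1
    dict_character = dict_character + ["<BOS/EOS>"]
    start_idx = len(dict_character) - 1  # BOS and EOS share one index
    dict_character = dict_character + ["<PAD>"]
    padding_idx = len(dict_character) - 1
    return dict_character, unknown_idx, start_idx, padding_idx

# stand-in for the 90 characters of en_dict90
chars, unknown_idx, start_idx, padding_idx = add_special_tokens([str(i) for i in range(90)])
print(len(chars), unknown_idx, start_idx, padding_idx)  # 93 90 91 92
```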

def __call__(self, data):
text = data["label"]
text_str = text
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len - 1:
return None
data["text_length"] = np.array(len(text))
target = [self.start_idx] + text + [self.end_idx]
padded_text = [self.padding_idx for _ in range(self.max_text_len)]

padded_text[: len(target)] = target
data["label"] = np.array(padded_text)
data["text_padded"] = text_str + " " * (self.max_text_len - len(text_str))

return data
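The label padding performed in `__call__` can be traced by hand with the YAML values (`max_text_len: 40`, start/end index 91, padding index 92); the encoded character indices below are made up for illustration:

```python
import numpy as np

start_idx = end_idx = 91   # <BOS/EOS> shares one index
padding_idx = 92           # <PAD>
max_text_len = 40

text = [10, 11, 12]  # hypothetical encoded characters
target = [start_idx] + text + [end_idx]
padded = [padding_idx] * max_text_len
padded[:len(target)] = target
label = np.array(padded)
print(label[:6].tolist(), len(label))  # [91, 10, 11, 12, 91, 92] 40
```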

def get_ignored_tokens(self):
return [self.padding_idx]


class RobustScannerRecResizeImg(object):
def __init__(self, image_shape, max_text_len, width_downsample_ratio=0.25, **kwargs):
self.image_shape = image_shape
self.width_downsample_ratio = width_downsample_ratio
self.max_text_len = max_text_len

def __call__(self, data):
img = data["image"]
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
img, self.image_shape, self.width_downsample_ratio
)
valid_ratio = np.array(valid_ratio, dtype=np.float32)
width_downsampled = int(self.image_shape[-1] * self.width_downsample_ratio)
valid_width_mask = np.full([1, width_downsampled], 1)
valid_width = min(width_downsampled, int(width_downsampled * valid_ratio + 0.5))
valid_width_mask[:, valid_width:] = 0
word_positions = np.array(range(0, self.max_text_len)).astype("int64")
data["image"] = norm_img
data["resized_shape"] = resize_shape
data["pad_shape"] = pad_shape
data["valid_ratio"] = valid_ratio
data["valid_width_mask"] = valid_width_mask
data["word_positions"] = word_positions
return data
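The `valid_width_mask` computed above marks which downsampled feature columns come from real image content rather than padding. A small worked example with the YAML shapes (`w_max=160`, downsample ratio 0.25) and an assumed `valid_ratio`:

```python
import numpy as np

width_downsampled = int(160 * 0.25)  # 40 feature columns after downsampling
valid_ratio = 0.6                    # assumed: image fills 60% of the padded width

valid_width = min(width_downsampled, int(width_downsampled * valid_ratio + 0.5))
valid_width_mask = np.full([1, width_downsampled], 1)
valid_width_mask[:, valid_width:] = 0
print(valid_width, int(valid_width_mask.sum()))  # 24 24
```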


def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
w = img.shape[1]
valid_ratio = 1.0
# make sure new_width is an integral multiple of width_divisor.
width_divisor = int(1 / width_downsample_ratio)
# resize
ratio = w / float(h)
resize_w = math.ceil(imgH * ratio)
if resize_w % width_divisor != 0:
resize_w = round(resize_w / width_divisor) * width_divisor
if imgW_min is not None:
resize_w = max(imgW_min, resize_w)
if imgW_max is not None:
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
resize_w = min(imgW_max, resize_w)
resized_image = cv2.resize(img, (resize_w, imgH))
resized_image = resized_image.astype("float32")
# norm
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resize_shape = resized_image.shape
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
padding_im[:, :, 0:resize_w] = resized_image
pad_shape = padding_im.shape

return padding_im, resize_shape, pad_shape, valid_ratio
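The resize logic above can be exercised standalone to see how target widths are rounded to a multiple of the width divisor and clipped to `imgW_max`. The helper name below is ours; the shapes follow the YAML `image_shape: [3, 48, 48, 160]`.

```python
import math

imgH, imgW_min, imgW_max = 48, 48, 160
width_divisor = int(1 / 0.25)  # 4: the width must divide evenly after downsampling

def target_width(h, w):
    """Reproduce the width computation of resize_norm_img_sar for one image."""
    resize_w = math.ceil(imgH * (w / float(h)))
    if resize_w % width_divisor != 0:
        resize_w = round(resize_w / width_divisor) * width_divisor
    resize_w = max(imgW_min, resize_w)
    valid_ratio = min(1.0, resize_w / imgW_max)
    return min(imgW_max, resize_w), valid_ratio

print(target_width(32, 100))  # (152, 0.95): rounded to a multiple of 4
print(target_width(32, 300))  # (160, 1.0): clipped to imgW_max
```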
2 changes: 1 addition & 1 deletion mindocr/data/transforms/transforms_factory.py
@@ -71,7 +71,7 @@ def run_transforms(data, transforms=None, verbose=False):
)

if data is None:
raise RuntimeError("Empty result is returned from transform `{transform}`")
raise RuntimeError(f"Empty result is returned from transform `{transform}`")
return data

