Refactor lmdb dataset loader and resnet backbone #92

Merged · 3 commits · Mar 24, 2023

14 changes: 8 additions & 6 deletions configs/rec/crnn/crnn_resnet34.yaml
@@ -69,9 +69,10 @@ train:
dataset_sink_mode: False
dataset:
type: LMDBDataset
data_dir: path/to/datadir/train/
# label_files: /data/ocr_datasets/ic15/word_recognition/rec_gt_train.txt
sample_ratios: [1.0]
dataset_root: path/to/data_lmdb_release/
data_dir: training/
# label_files: # not required when using LMDBDataset
sample_ratios: 1.0
shuffle: True
transform_pipeline:
- DecodeImage:
@@ -110,9 +111,10 @@ eval:
dataset_sink_mode: False
dataset:
type: LMDBDataset
data_dir: path/to/datadir/validation/
# label_files: /data/ocr_datasets/ic15/word_recognition/rec_gt_train.txt
sample_ratios: [1.0]
dataset_root: path/to/data_lmdb_release/
data_dir: validation/
# label_files: # not required when using LMDBDataset
sample_ratios: 1.0
shuffle: False
transform_pipeline:
- DecodeImage:
14 changes: 8 additions & 6 deletions configs/rec/crnn/crnn_vgg7.yaml
@@ -69,9 +69,10 @@ train:
dataset_sink_mode: False
dataset:
type: LMDBDataset
data_dir: path/to/datadir/train/
# label_files: /data/ocr_datasets/ic15/word_recognition/rec_gt_train.txt
sample_ratios: [1.0]
dataset_root: path/to/data_lmdb_release/
data_dir: training/
# label_files: # not required when using LMDBDataset
sample_ratios: 1.0
shuffle: True
transform_pipeline:
- DecodeImage:
@@ -110,9 +111,10 @@ eval:
dataset_sink_mode: False
dataset:
type: LMDBDataset
data_dir: path/to/datadir/validation/
# label_files: /data/ocr_datasets/ic15/word_recognition/rec_gt_train.txt
sample_ratios: [1.0]
dataset_root: path/to/data_lmdb_release/
data_dir: validation/
# label_files: # not required when using LMDBDataset
sample_ratios: 1.0
shuffle: False
transform_pipeline:
- DecodeImage:
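
Both CRNN configs make the same change: the old absolute data_dir is replaced by a shared dataset_root plus a relative data_dir (training/ or validation/), label_files is commented out because LMDB data needs no separate label file, and sample_ratios becomes a plain float. Below is a minimal sketch of how the two path fields are presumably combined; the join itself is not part of this diff, so treat it as an assumption about the intended behaviour rather than the actual implementation.

    import os

    dataset_root = "path/to/data_lmdb_release/"            # shared root from the yaml
    train_dir = os.path.join(dataset_root, "training/")    # -> path/to/data_lmdb_release/training/
    eval_dir = os.path.join(dataset_root, "validation/")   # -> path/to/data_lmdb_release/validation/

    # sample_ratios: 1.0 is now a scalar, matching the simplified
    # `sample_ratio: float = 1.0` signature in the loader below.
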
54 changes: 30 additions & 24 deletions mindocr/data/rec_lmdb_dataset.py
@@ -14,9 +14,9 @@ class LMDBDataset(BaseDataset):
The annotation format is required to be aligned with paddle, which can be done using the `converter.py` script.

Args:
is_train:
data_dir:
shuffle, Optional, if not given, shuffle = is_train
is_train: whether the dataset is for training
data_dir: data root directory for lmdb dataset(s)
shuffle: Optional, if not given, shuffle = is_train
transform_pipeline: list of dict, key - transform class name, value - a dict of param config.
e.g., [{'DecodeImage': {'img_mode': 'BGR', 'channel_first': False}}]
- if None, default transform pipeline for text detection will be taken.
@@ -43,7 +43,7 @@ class LMDBDataset(BaseDataset):
def __init__(self,
is_train: bool = True,
data_dir: str = '',
sample_ratio: Union[List, float] = 1.0,
sample_ratio: float = 1.0,
shuffle: bool = None,
transform_pipeline: List[dict] = None,
output_columns: List[str] = None,
@@ -55,8 +55,7 @@ def __init__(self,
assert isinstance(shuffle, bool), f'type error of {shuffle}'
shuffle = shuffle if shuffle is not None else is_train

sample_ratio = sample_ratio[0] if isinstance(sample_ratio, list) else sample_ratio
self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
self.lmdb_sets = self.load_list_of_hierarchical_lmdb_dataset(data_dir)
self.data_idx_order_list = self.get_dataset_idx_orders(sample_ratio, shuffle)

# create transform
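
With sample_ratio now a plain float and data_dir accepting either a single path or a list of paths (see load_list_of_hierarchical_lmdb_dataset further down in this diff), constructing the dataset looks roughly as below. This is a hedged sketch: the import path simply follows the file location in this PR, and the one-entry transform pipeline is only illustrative, taken from the docstring example above.

    from mindocr.data.rec_lmdb_dataset import LMDBDataset

    dataset = LMDBDataset(
        is_train=True,
        data_dir="path/to/data_lmdb_release/training/",  # a str or a list of dirs
        sample_ratio=1.0,      # plain float, no longer a one-element list
        shuffle=True,          # per the docstring, defaults to is_train when not given
        transform_pipeline=[{'DecodeImage': {'img_mode': 'BGR', 'channel_first': False}}],
    )
    print(len(dataset))        # number of indexed samples across all LMDB envs
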
@@ -88,23 +87,31 @@ def __init__(self,
else:
raise ValueError(f'Key {k} does not exist in data (available keys: {_data.keys()}). Please check the name or the completeness of the transformation pipeline.')

def load_list_of_hierarchical_lmdb_dataset(self, data_dir):
if isinstance(data_dir, str):
results = self.load_hierarchical_lmdb_dataset(data_dir)
elif isinstance(data_dir, list):
results = {}
for sub_data_dir in data_dir:
start_idx = len(results)
lmdb_sets = self.load_hierarchical_lmdb_dataset(sub_data_dir, start_idx)
results.update(lmdb_sets)
else:
results = {}

return results

def load_hierarchical_lmdb_dataset(self, data_dir, start_idx=0):

lmdb_sets = {}
dataset_idx = start_idx
for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
if not dirnames:
env = lmdb.open(
dirpath,
max_readers=32,
readonly=True,
lock=False,
readahead=False,
meminit=False)
for rootdir, dirs, _ in os.walk(data_dir + '/'):
if not dirs:
env = lmdb.open(rootdir, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
txn = env.begin(write=False)
data_size = int(txn.get('num-samples'.encode()))
lmdb_sets[dataset_idx] = {
"dirpath":dirpath,
"rootdir":rootdir,
"env":env,
"txn":txn,
"data_size":data_size
@@ -135,24 +142,23 @@ def get_dataset_idx_orders(self, sample_ratio, shuffle):

return data_idx_order_list

def get_lmdb_sample_info(self, txn, index):
label_key = 'label-%09d'.encode() % index
def get_lmdb_sample_info(self, txn, idx):
label_key = 'label-%09d'.encode() % idx
label = txn.get(label_key)
if label is None:
return None
label = label.decode('utf-8')
img_key = 'image-%09d'.encode() % index
img_key = 'image-%09d'.encode() % idx
imgbuf = txn.get(img_key)
return imgbuf, label

def __getitem__(self, idx):
lmdb_idx, file_idx = self.data_idx_order_list[idx]
lmdb_idx = int(lmdb_idx)
file_idx = int(file_idx)
sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
file_idx)
sample_info = self.get_lmdb_sample_info(self.lmdb_sets[int(lmdb_idx)]['txn'],
int(file_idx))
if sample_info is None:
return self.__getitem__(np.random.randint(self.__len__()))
random_idx = np.random.randint(self.__len__())
return self.__getitem__(random_idx)

data = {
"img_lmdb": sample_info[0],
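
Summary of the loader refactor: data_dir may be a single directory or a list of directories; load_list_of_hierarchical_lmdb_dataset walks each entry, opens every leaf directory (one with no sub-directories) as a read-only LMDB environment, and samples are then addressed by (lmdb_idx, file_idx) pairs, with each environment storing a 'num-samples' count and 'label-%09d' / 'image-%09d' keys. A minimal read sketch mirroring get_lmdb_sample_info; the leaf path is hypothetical and 1-based sample indices are assumed, as in the common data_lmdb_release layout:

    import lmdb

    # hypothetical leaf directory under the dataset root
    env = lmdb.open("path/to/data_lmdb_release/training/MJ",
                    max_readers=32, readonly=True, lock=False,
                    readahead=False, meminit=False)
    txn = env.begin(write=False)

    num_samples = int(txn.get("num-samples".encode()))
    idx = 1                                                  # assuming 1-based keys
    label = txn.get("label-%09d".encode() % idx).decode("utf-8")
    imgbuf = txn.get("image-%09d".encode() % idx)            # raw encoded image bytes
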
76 changes: 37 additions & 39 deletions mindocr/models/backbones/rec_resnet.py
@@ -5,41 +5,40 @@
__all__ = ['RecResNet', 'rec_resnet34']


class ConvBNLayer(nn.Cell):
class ConvNormLayer(nn.Cell):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
is_vd_mode=False,
act=False):
super(ConvBNLayer, self).__init__()
super(ConvNormLayer, self).__init__()

self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2d(
self.pool2d_avg = nn.AvgPool2d(
kernel_size=stride, stride=stride, pad_mode="same")
self._conv = nn.Conv2d(
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=1 if is_vd_mode else stride,
pad_mode='pad',
padding=(kernel_size - 1) // 2,
)
self._batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=1e-5, momentum=0.9,
self.norm = nn.BatchNorm2d(num_features=out_channels, eps=1e-5, momentum=0.9,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
self._act = nn.ReLU()
self.act_func = nn.ReLU()
self.act = act

def construct(self, inputs):
def construct(self, x):
if self.is_vd_mode:
inputs = self._pool2d_avg(inputs)
y = self._conv(inputs)
y = self._batch_norm(y)
x = self.pool2d_avg(x)
y = self.conv(x)
y = self.norm(y)
if self.act:
y = self._act(y)
y = self.act_func(y)
return y
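
The renamed ConvNormLayer keeps the behaviour of the old ConvBNLayer: an optional vd-mode average pool, then convolution, BatchNorm, and an optional ReLU. A small usage sketch, assuming the class can be imported from the module path in this PR and using an arbitrary 32x100 text-crop shape:

    import numpy as np
    import mindspore as ms
    from mindocr.models.backbones.rec_resnet import ConvNormLayer

    layer = ConvNormLayer(in_channels=3, out_channels=32, kernel_size=3, stride=1, act=True)
    x = ms.Tensor(np.random.rand(1, 3, 32, 100).astype(np.float32))  # NCHW dummy input
    y = layer(x)  # spatial size unchanged: stride 1 with padding (kernel_size - 1) // 2
    print(y.shape)  # (1, 32, 32, 100)
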


@@ -53,20 +52,20 @@ def __init__(self,
):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
self.conv0 = ConvNormLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act=True)
self.conv1 = ConvBNLayer(
self.conv1 = ConvNormLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=False)

if not shortcut:
self.short = ConvBNLayer(
self.short = ConvNormLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
@@ -76,14 +75,14 @@ def __init__(self,
self.shortcut = shortcut
self.relu = nn.ReLU()

def construct(self, inputs):
y = self.conv0(inputs)
def construct(self, x):
y = self.conv0(x)
conv1 = self.conv1(y)

if self.shortcut:
short = inputs
short = x
else:
short = self.short(inputs)
short = self.short(x)
y = short + conv1
y = self.relu(y)
return y
@@ -97,65 +96,64 @@ def __init__(self, in_channels=3, layers=34, **kwargs):
self.out_channels = 512
self.layers = layers
supported_layers = [34]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
assert layers in supported_layers, "only support {} layers but input layer is {}".format(
supported_layers, layers)

depth = [3, 4, 6, 3]
num_channels = [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]

self.conv1_1 = ConvBNLayer(
self.conv1_1 = ConvNormLayer(
in_channels=in_channels,
out_channels=32,
kernel_size=3,
stride=1,
act=True)
self.conv1_2 = ConvBNLayer(
self.conv1_2 = ConvNormLayer(
in_channels=32,
out_channels=32,
kernel_size=3,
stride=1,
act=True)
self.conv1_3 = ConvBNLayer(
self.conv1_3 = ConvNormLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
act=True)
self.pool2d_max = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
self.maxpool2d_1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

self.block_list = []
for block in range(len(depth)):
for block_id in range(len(depth)):
shortcut = False
for i in range(depth[block]):
if i == 0 and block != 0:
for i in range(depth[block_id]):
if i == 0 and block_id != 0:
stride = (2, 1)
else:
stride = (1, 1)

is_first = block_id == i == 0
in_channels = num_channels[block_id] if i == 0 else num_filters[block_id]
basic_block = BasicBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block],
out_channels=num_filters[block],
in_channels=in_channels,
out_channels=num_filters[block_id],
stride=stride,
shortcut=shortcut,
if_first=block == i == 0
if_first=is_first
)
shortcut = True
self.block_list.append(basic_block)
self.out_channels = num_filters[block]

self.block_list = nn.SequentialCell(self.block_list)
self.out_pool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')
self.maxpool2d_2 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')

def construct(self, inputs):
y = self.conv1_1(inputs)
def construct(self, x):
y = self.conv1_1(x)
y = self.conv1_2(y)
y = self.conv1_3(y)
y = self.pool2d_max(y)
y = self.maxpool2d_1(y)
y = self.block_list(y)
y = self.out_pool(y)
y = self.maxpool2d_2(y)
return [y]

# TODO: load pretrained weight in build_backbone or use a unified wrapper to load
@@ -169,4 +167,4 @@ def rec_resnet34(pretrained: bool = True, **kwargs):
if pretrained:
raise NotImplementedError

return model
return model
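
For reference, a sketch of building and running the refactored backbone. The import path follows this PR's file layout, pretrained weights are not implemented yet (pretrained=True raises NotImplementedError), and the shape note is derived from the strides and pools shown above:

    import numpy as np
    import mindspore as ms
    from mindocr.models.backbones.rec_resnet import rec_resnet34

    model = rec_resnet34(pretrained=False)
    x = ms.Tensor(np.random.rand(1, 3, 32, 100).astype(np.float32))  # typical 32x100 text crop
    feats = model(x)[0]  # construct returns a single-element list
    # Stride (2, 1) at the first block of every stage after the first halves the height
    # but keeps the width, so with the two max-pools a 32x100 input should come out
    # as (1, 512, 1, 25): height collapsed to 1, width reduced to W/4.
    print(feats.shape)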