Skip to content

Add comments to dataset_root and data_dir #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions mindocr/data/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ def build_dataset(
Args:
dataset_config (dict): dataset reading and processing configuration containing keys:
- type: dataset type, 'DetDataset', 'RecDataset'
- data_dir Union[str, List]: folder to the dataset.
- label_file (optional for recognition): file path(s) to the annotation file
- transform_pipeline (list[dict]): config dict for image and label transformation
- dataset_root (str): the root directory to store the (multiple) dataset(s)
- data_dir (Union[str, List[str]]): directory to the data, which is a subfolder path relative to `dataset_root`. For multiple datasets, it is a list of subfolder paths.
- label_file (Union[str, List[str]]): file path to the annotation file, relative to `dataset_root`. For multiple datasets, it is a list of relative file paths.
- transform_pipeline (list[dict]): each element corresponds to a transform operation on image and/or label

loader_config (dict): dataloader configuration containing keys:
- batch_size: batch size for data loader
- drop_remainder: whether to drop the data in the last batch when the total of data can not be divided by the batch_size
Expand All @@ -33,13 +35,11 @@ def build_dataset(
Return:
data_loader (Dataset): dataloader to generate data batch
'''
# build datasets
dataset_class_name = dataset_config.pop('type')
assert dataset_class_name in supported_dataset_types, "Invalid dataset name"
## convert data_dir and to abs path. TODO: do it inside dataset class init?

## check and process dataset_root, data_dir, and label_file.
if 'dataset_root' in dataset_config:
if isinstance(dataset_config['data_dir'], str):
dataset_config['data_dir'] = os.path.join(dataset_config['dataset_root'], dataset_config['data_dir'])
dataset_config['data_dir'] = os.path.join(dataset_config['dataset_root'], dataset_config['data_dir']) # to absolute path
else:
dataset_config['data_dir'] = [os.path.join(dataset_config['dataset_root'], dd) for dd in dataset_config['data_dir']]

Expand All @@ -49,11 +49,11 @@ def build_dataset(
else:
dataset_config['label_file'] = [os.path.join(dataset_config['dataset_root'], lf) for lf in dataset_config['label_file']]

# get dataset class
# build datasets
dataset_class_name = dataset_config.pop('type')
assert dataset_class_name in supported_dataset_types, "Invalid dataset name"
dataset_class = eval(dataset_class_name)

#print('dataset config', dataset_config)

dataset_args = dict(is_train=is_train, **dataset_config)
dataset = dataset_class(**dataset_args)

Expand Down
7 changes: 4 additions & 3 deletions mindocr/models/heads/rec_ctc_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def __init__(self,
self.out_channels = out_channels
self.mid_channels = mid_channels
self.return_feats = return_feats

if weight_init == "crnn_customised":
weight_init = crnn_head_initialization(in_channels)

if bias_init == "crnn_customised":
bias_init = crnn_head_initialization(in_channels)

Expand Down Expand Up @@ -75,7 +75,8 @@ def construct(self, x):
h = self.dense2(h)

if not self.training:
h = ops.softmax(h, axis=2)
#h = ops.softmax(h, axis=2) # not support on ms 1.8.1
h = ops.Softmax(axis=2)(h)

pred = {'head_out': h}
return pred
Expand Down
2 changes: 2 additions & 0 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ def main(cfg):
print(f'INFO: datasets found: {os.listdir(dataset_root)} \n'
f'INFO: dataset_root is changed to {dataset_root}'
)
# update dataset root dir to cache
assert 'dataset_root' in config['train']['dataset'], f'`dataset_root` must be provided in the yaml file for training on ModelArts or OpenI, but not found in {yaml_fp}. Please add `dataset_root` to `train:dataset` and `eval:dataset` in the yaml file'
config.train.dataset.dataset_root = dataset_root
config.eval.dataset.dataset_root = dataset_root

Expand Down