
Improve mindspore data process pipeline setting, set default values f… #59


Merged (1 commit, Mar 15, 2023)
30 changes: 10 additions & 20 deletions configs/det/test.yaml
@@ -57,15 +57,14 @@ optimizer:
  loss_scale: 1.0

train:
-  dataset_sink_mode: True
+  dataset_sink_mode: False
  ckpt_save_dir: './tmp_det'
  dataset:
    type: DetDataset
-    #dataset_root: /data/ocr_datasets
-    dataset_root: /Users/Samit/Data/datasets
+    dataset_root: /data/ocr_datasets
    data_dir: ic15/det/train/ch4_training_images
    label_file: ic15/det/train/det_gt.txt
-    sample_ratio: 1.0
+    sample_ratio: 0.5
    shuffle: True
    transform_pipeline:
      - DecodeImage:
@@ -103,24 +102,23 @@ train:
          std: imagenet
      - ToCHWImage:
    # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
-    output_keys: ['image', 'shrink_map', 'shrink_mask', 'threshold_map', 'threshold_mask'] #'img_path']
+    output_keys: ['image', 'shrink_map', 'shrink_mask', 'threshold_map', 'threshold_mask']
    #output_keys: ['image'] # for debug op performance
    num_keys_to_net: 1 # num inputs for network forward func in output_keys
    # keys_for_loss: 4 # num labels for loss func

  loader:
-    shuffle: True # TODO: tbc
-    batch_size: 20
-    drop_remainder: False
-    max_rowsize: 20
-    num_workers: 1 # TODO: large value may lead to OOM
+    shuffle: True
+    batch_size: 16
+    num_workers: 8 # TODO: large value may lead to OOM
+    prefetch_size: 16
+    drop_remainder: True

eval:
  dataset_sink_mode: False
  dataset:
    type: DetDataset
-    #dataset_root: /data/ocr_datasets/ # needed for openi
-    dataset_root: /Users/Samit/Data/datasets
+    dataset_root: /data/ocr_datasets/
    data_dir: ic15/det/test/ch4_test_images
    label_file: ic15/det/test/det_gt.txt
    sample_ratio: 1.0
@@ -147,12 +145,4 @@ eval:
    shuffle: False
    batch_size: 1 # TODO: due to dynamic shape of polygons (num of boxes varies), BS has to be 1
    drop_remainder: False
-    max_rowsize: 20
-    num_workers: 1

-modelarts: # TODO: for running on modelarts or openi. Not making effect currently.
-  enable_modelarts: False
-  data_url: /cache/data/ # path to dataset
-  multi_data_url: /cache/data/ # path to multi dataset
-  ckpt_url: /cache/output/ # pretrained model path
-  train_url: /cache/output/ # model save folder
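Why the training loader now defaults to `drop_remainder: True`: under graph compilation, a smaller final batch changes tensor shapes and can force recompilation or fail outright. A toy calculation (hypothetical sample count, not the real IC15 split) shows what each setting does:

```python
# Hypothetical numbers for illustration only.
num_samples, batch_size = 1000, 16

steps_drop = num_samples // batch_size      # drop_remainder=True  -> 62 full batches (8 samples skipped)
steps_keep = -(-num_samples // batch_size)  # drop_remainder=False -> 63 batches, last one holds only 8
print(steps_drop, steps_keep)               # 62 63
```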
53 changes: 39 additions & 14 deletions mindocr/data/builder.py
@@ -17,6 +17,8 @@ def build_dataset(
        **kwargs,
):
    '''
+    Build dataset
+
    Args:
        dataset_config (dict): dataset reading and processing configuration containing keys:
            - type: dataset type, 'DetDataset', 'RecDataset'
@@ -34,8 +36,13 @@

    Return:
        data_loader (Dataset): dataloader to generate data batch
-    '''

+    Notes:
+        - The main data processing pipeline in MindSpore consists of 3 parts: 1) load the data files and generate the source dataset, 2) perform per-data-row mapping such as image augmentation, 3) group rows into batches and apply batch mapping. A runnable sketch of these three stages follows this file's diff.
+        - Each of the three steps supports multiprocessing. The detailed mechanism is described at https://www.mindspore.cn/docs/zh-CN/r2.0.0-alpha/api_python/mindspore.dataset.html
+        - A data row is a data tuple containing multiple elements, such as (image_i, mask_i, label_i). A data column corresponds to one element of the tuple, like 'image' or 'label'.
+        - The total `num_parallel_workers` used for data loading and processing should not exceed the number of CPU threads available; otherwise the workers compete for resources and add overhead. This matters especially in distributed training, where `num_parallel_workers` should be kept moderate to avoid thread contention.
+    '''
    ## check and process dataset_root, data_dir, and label_file.
    if 'dataset_root' in dataset_config:
        if isinstance(dataset_config['data_dir'], str):
@@ -47,7 +54,7 @@
        if isinstance(dataset_config['label_file'], str):
            dataset_config['label_file'] = os.path.join(dataset_config['dataset_root'], dataset_config['label_file'])
        else:
-            dataset_config['label_file'] = [os.path.join(dataset_config['dataset_root'], lf) for lf in dataset_confg['label_file']]
+            dataset_config['label_file'] = [os.path.join(dataset_config['dataset_root'], lf) for lf in dataset_config['label_file']]

    # build datasets
    dataset_class_name = dataset_config.pop('type')
@@ -61,27 +68,45 @@
    dataset_column_names = dataset.get_column_names()
    print('==> Dataset columns: \n\t', dataset_column_names)

-    # TODO: the optimal value for prefetch. * num_workers?
-    #ms.dataset.config.set_prefetch_size(int(loader_config['batch_size']))
-    #print('prfectch size:', ms.dataset.config.get_prefetch_size())
+    # TODO: find the optimal settings automatically according to the number of CPU cores
+    num_workers = loader_config.get("num_workers", 8) # number of subprocesses used to fetch the dataset / map data rows / generate batches in parallel
+    prefetch_size = loader_config.get("prefetch_size", 16) # length of the per-worker cache queue in the data pipeline; reduces waiting time, but larger values consume more memory. Default: 16
+    max_rowsize = loader_config.get("max_rowsize", 64) # MB of shared memory between processes used to copy data

+    ms.dataset.config.set_prefetch_size(prefetch_size)
+    #print('Prefetch size: ', ms.dataset.config.get_prefetch_size())

-    # TODO: config multiprocess and shared memory
+    # auto-tune num_workers and prefetch (this conflicts with the profiler)
+    #ms.dataset.config.set_autotune_interval(5)
+    #ms.dataset.config.set_enable_autotune(True, "./dataproc_autotune_out")

+    # 1. generate the source dataset (source w.r.t. the dataset.map pipeline) from the python callable numpy dataset, in parallel
    ds = ms.dataset.GeneratorDataset(
                    dataset,
                    column_names=dataset_column_names,
-                    num_parallel_workers=loader_config['num_workers'],
+                    num_parallel_workers=num_workers,
                    num_shards=num_shards,
                    shard_id=shard_id,
-                    python_multiprocessing=True,
-                    max_rowsize =loader_config['max_rowsize'],
+                    python_multiprocessing=True, # keep True to improve performance for heavy computation
+                    max_rowsize=max_rowsize,
                    shuffle=loader_config['shuffle'],
                    )

-    # TODO: set default value for drop_remainder and max_rowsize
-    dataloader = ds.batch(loader_config['batch_size'],
-                          drop_remainder=loader_config['drop_remainder'],
-                          max_rowsize=loader_config['max_rowsize'],
-                          #num_parallel_workers=loader_config['num_workers'],
+    # 2. per-data-row mapping (high-performance transformation)
+    #ds = ds.map(operations=transform_list, input_columns=['image', 'label'], num_parallel_workers=8, python_multiprocessing=True)

+    # 3. create batches by collecting batch_size consecutive data rows and applying batch operations
+    drop_remainder = loader_config.get('drop_remainder', is_train)
+    if is_train and drop_remainder == False:
+        print('WARNING: drop_remainder should be True for training, otherwise the last batch may lead to training failure.')
+    dataloader = ds.batch(
+                    loader_config['batch_size'],
+                    drop_remainder=drop_remainder,
+                    num_parallel_workers=min(num_workers, 2), # keep this small since batching is light computation
+                    #input_columns=input_columns,
+                    #output_columns=batch_column,
+                    #per_batch_map=per_batch_map, # uncomment to use inner-batch transformation
+                    )

    #steps_per_epoch = dataset.get_dataset_size()
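For readers unfamiliar with the MindSpore pipeline described in the docstring Notes, here is a minimal, self-contained sketch of the three stages (load, map, batch). `ToyDataset`, the transform, and every parameter value are illustrative choices, not part of this PR:

```python
import numpy as np
import mindspore as ms

class ToyDataset:
    """Stand-in source; any object with __getitem__ and __len__ works."""
    def __len__(self):
        return 100
    def __getitem__(self, idx):
        return np.ones((32, 32, 3), np.float32), np.array(idx % 10, np.int32)

ms.dataset.config.set_prefetch_size(16)  # per-worker cache queue length

# 1) load: wrap the python dataset; rows are fetched by parallel subprocesses
ds = ms.dataset.GeneratorDataset(ToyDataset(), column_names=['image', 'label'],
                                 num_parallel_workers=4, python_multiprocessing=True,
                                 max_rowsize=16, shuffle=True)

# 2) map: per-row transforms (e.g. augmentation) run here
ds = ds.map(operations=[lambda img: img * 2.0], input_columns=['image'],
            num_parallel_workers=4)

# 3) batch: light work, so a small worker count is enough
ds = ds.batch(8, drop_remainder=True, num_parallel_workers=2)

for image, label in ds.create_tuple_iterator():
    print(image.shape, label.shape)  # (8, 32, 32, 3) (8,)
    break
```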
2 changes: 2 additions & 0 deletions tests/st/test_train_eval_dummy.py
@@ -77,6 +77,8 @@ def test_train_eval(task):
    out, err = p.communicate()
    # assert ret==0, 'Validation fails'
    print(out)
+
+    p.kill()
    '''
    if check_acc:
        res = out.decode()
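The added `p.kill()` makes sure the spawned process cannot outlive the test. A minimal sketch of the pattern, with a placeholder command and timeout (per the Python docs, `kill()` does nothing if the process has already exited):

```python
import subprocess

# Launch a child process, capture its output, and guarantee cleanup.
p = subprocess.Popen(['python', '-c', 'print("ok")'],
                     stdout=subprocess.PIPE, stderr=subprocess.PIPE)
try:
    out, err = p.communicate(timeout=600)  # placeholder timeout
    print(out.decode())
finally:
    p.kill()  # no-op if the process already exited
```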
7 changes: 4 additions & 3 deletions tests/ut/test_datasets.py
@@ -20,7 +20,6 @@
@pytest.mark.parametrize('task', ['det', 'rec'])
#@pytest.mark.parametrize('phase', ['train', 'eval'])
def test_build_dataset(task='det', phase='train', verbose=True, visualize=False):
-    # TODO: download sample test data automatically
    #data_dir = '/data/ocr_datasets/ic15/text_localization/train'
    #annot_file = '/data/ocr_datasets/ic15/text_localization/train/train_icdar15_label.txt'
    '''
@@ -59,9 +58,9 @@
    '''

    if task == 'rec':
-        yaml_fp = 'configs/rec/crnn_icdar15.yaml'
+        yaml_fp = 'configs/rec/test.yaml'
    else:
-        yaml_fp = 'configs/det/db_test.yaml'
+        yaml_fp = 'configs/det/test.yaml'

    with open(yaml_fp) as fp:
        cfg = yaml.safe_load(fp)
@@ -77,9 +76,11 @@ def test_build_dataset(task='det', phase='train', verbose=True, visualize=False):
    dataset_config = cfg[phase]['dataset']
    loader_config = cfg[phase]['loader']

    dl = build_dataset(dataset_config, loader_config, is_train=(phase=='train'))
    num_batches = dl.get_dataset_size()

+    ms.set_context(mode=0)
    #batch = next(dl.create_tuple_iterator())
    num_tries = 3
    start = time.time()
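A self-contained version of the timing check this test performs, using an in-memory stand-in for `build_dataset` (shapes and sizes are made up; `mode=0` selects graph mode):

```python
import time
import numpy as np
import mindspore as ms

ms.set_context(mode=0)  # 0 = GRAPH_MODE, 1 = PYNATIVE_MODE

# Tiny in-memory dataloader standing in for build_dataset(...).
data = {'image': np.ones((64, 8, 8), np.float32), 'label': np.arange(64, dtype=np.int32)}
dl = ms.dataset.NumpySlicesDataset(data, shuffle=False).batch(16)

num_tries = 3
start = time.time()
for i, batch in enumerate(dl.create_tuple_iterator()):
    if i + 1 >= num_tries:
        break
print(f'~{(time.time() - start) / num_tries:.4f}s per batch')
```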
4 changes: 2 additions & 2 deletions tools/train.py
@@ -53,7 +53,7 @@ def main(cfg):
        rank_id = None

    set_seed(cfg.system.seed)
-    cv2.setNumThreads(2) # TODO: by default, num threads = num cpu cores
+    #cv2.setNumThreads(2) # TODO: by default, num threads = num cpu cores
    is_main_device = rank_id in [None, 0]

    # train pipeline
@@ -136,7 +136,7 @@ def main(cfg):
            f.write(args_text)

    # training
-    loss_monitor = LossMonitor(10) #num_batches // 10)
+    loss_monitor = LossMonitor(min(num_batches // 10, 100))
    time_monitor = TimeMonitor()

    model = ms.Model(train_net)
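The new `LossMonitor` interval logs roughly ten times per epoch, capped at every 100 steps. A sketch with a hypothetical epoch length; the `max(1, ...)` guard for very short epochs is an addition here, not in the PR:

```python
from mindspore.train.callback import LossMonitor, TimeMonitor

num_batches = 6250  # hypothetical steps per epoch
interval = max(1, min(num_batches // 10, 100))  # -> 100 here; max(1, ...) guards tiny epochs
callbacks = [LossMonitor(interval), TimeMonitor()]
```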