fix ut #143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · merged 1 commit · Apr 3, 2023
83 changes: 83 additions & 0 deletions tests/ut/_common.py
@@ -0,0 +1,83 @@
import os
import glob
import yaml
from mindcv.utils.download import DownLoad

def gen_dummpy_data(task):
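    """Download the Canidae sample images (if absent) and synthesize per-split
    `{task}_gt.txt` annotation files from the dummy labels checked in under
    tests/st/dummy_labels. Returns the dataset root directory."""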
    # prepare dummy images
    data_dir = "data/Canidae"
    dataset_url = (
        "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"
    )
    if not os.path.exists(data_dir):
        DownLoad().download_and_extract_archive(dataset_url, "./")

    # prepare dummy labels
    for split in ['train', 'val']:
        label_path = f'tests/st/dummy_labels/{task}_{split}_gt.txt'
        image_dir = f'{data_dir}/{split}/dogs'
        new_label_path = f'data/Canidae/{split}/{task}_gt.txt'
        img_paths = glob.glob(os.path.join(image_dir, '*.JPEG'))
        with open(new_label_path, 'w') as f_w:
            with open(label_path, 'r') as f_r:
                i = 0
                for line in f_r:
                    _, label = line.strip().split('\t')
                    img_name = os.path.basename(img_paths[i])
                    new_img_label = img_name + '\t' + label
                    f_w.write(new_img_label + '\n')
                    i += 1
        print(f'Dummy annotation file is generated in {new_label_path}')

    return data_dir

def update_config_for_CI(config_fp, task, val_while_train=False):
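    """Shrink the config at `config_fp` for CI: single device, no dataset sink
    mode, dummy Canidae data, tiny batch sizes, and a 2-epoch schedule. Writes
    the result to tests/st/<name>_dummpy.yaml and returns that path."""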
    with open(config_fp) as fp:
        config = yaml.safe_load(fp)
    config['system']['distribute'] = False
    config['system']['val_while_train'] = val_while_train
    #if 'common' in config:
    #    config['batch_size'] = 8
    config['train']['dataset_sink_mode'] = False

    config['train']['dataset']['dataset_root'] = 'data/Canidae/'
    config['train']['dataset']['data_dir'] = 'train/dogs'
    config['train']['dataset']['label_file'] = f'train/{task}_gt.txt'
    config['train']['dataset']['sample_ratio'] = 0.1  # TODO: 120 training samples in total; the batch size must not exceed the sampled count
    config['train']['loader']['num_workers'] = 1  # GitHub CI runners support at most 2 workers
    #if config['train']['loader']['batch_size'] > 120:
    config['train']['loader']['batch_size'] = 2  # to save memory
    config['train']['loader']['max_rowsize'] = 16  # to save memory
    config['train']['loader']['prefetch_size'] = 2  # to save memory
    if 'common' in config:
        config['common']['batch_size'] = 2
    if 'batch_size' in config['loss']:
        config['loss']['batch_size'] = 2

    config['eval']['dataset']['dataset_root'] = 'data/Canidae/'
    config['eval']['dataset']['data_dir'] = 'val/dogs'
    config['eval']['dataset']['label_file'] = f'val/{task}_gt.txt'
    config['eval']['dataset']['sample_ratio'] = 0.1
    config['eval']['loader']['num_workers'] = 1  # GitHub CI runners support at most 2 workers
    config['eval']['loader']['batch_size'] = 1
    config['eval']['loader']['max_rowsize'] = 16  # to save memory
    config['eval']['loader']['prefetch_size'] = 2  # to save memory

    config['eval']['ckpt_load_path'] = os.path.join(config['train']['ckpt_save_dir'], 'best.ckpt')

    config['scheduler']['num_epochs'] = 2
    config['scheduler']['warmup_epochs'] = 1
    config['scheduler']['decay_epochs'] = 1

    dummpy_config_fp = os.path.join('tests/st', os.path.basename(config_fp.replace('.yaml', '_dummpy.yaml')))
    with open(dummpy_config_fp, 'w') as f:
        args_text = yaml.safe_dump(config, default_flow_style=False, sort_keys=False)
        f.write(args_text)
    print('Generated yaml: ')
    print(args_text)

    return dummpy_config_fp


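Together, the two helpers give each unit test a one-call setup path: materialize the dummy data, then rewrite the target config for CI. A minimal usage sketch, mirroring how test_build_dataset below wires them together (the surrounding test body here is illustrative, not part of the diff):

    # Illustrative wiring of the helpers above (mirrors tests/ut/test_datasets.py).
    import yaml
    from _common import gen_dummpy_data, update_config_for_CI

    task = 'det'
    data_dir = gen_dummpy_data(task)  # downloads images, writes data/Canidae/<split>/det_gt.txt
    yaml_fp = update_config_for_CI('configs/det/dbnet/db_r50_icdar15.yaml', task)
    with open(yaml_fp) as fp:
        cfg = yaml.safe_load(fp)      # CI-sized config, ready for build_dataset/build_model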
131 changes: 14 additions & 117 deletions tests/ut/test_datasets.py
@@ -1,3 +1,5 @@
from _common import gen_dummpy_data, update_config_for_CI

import sys
sys.path.append('.')

@@ -14,32 +16,28 @@
from mindocr.data.transforms.transforms_factory import transforms_dbnet_icdar15
from mindocr.data.rec_dataset import RecDataset
from mindspore import load_checkpoint, load_param_into_net

from mindocr.utils.visualize import show_img, draw_bboxes, show_imgs, recover_image
from mindcv.utils.download import DownLoad


@pytest.mark.parametrize('task', ['det', 'rec'])
#@pytest.mark.parametrize('phase', ['train', 'eval'])
def test_build_dataset(task='det', phase='train', verbose=True, visualize=False):
    if task == 'rec':
        yaml_fp = 'configs/rec/test.yaml'
    else:
        yaml_fp = 'configs/det/test.yaml'
def test_build_dataset(task, phase='train', verbose=False, visualize=False):
    # modify ocr predefined yaml for minimum test
    if task == 'det':
        config_fp = 'configs/det/dbnet/db_r50_icdar15.yaml'
    elif task == 'rec':
        config_fp = 'configs/rec/crnn/crnn_icdar15.yaml'

    data_dir = gen_dummpy_data(task)
    yaml_fp = update_config_for_CI(config_fp, task)

    with open(yaml_fp) as fp:
        cfg = yaml.safe_load(fp)

    if task == 'rec':
        from mindocr.data.transforms.rec_transforms import RecCTCLabelEncode
        dict_path = cfg['common']['character_dict_path']
        # read dict path and get class nums
        rec_info = RecCTCLabelEncode(character_dict_path=dict_path)
        #config['model']['head']['out_channels'] = num_classes
        print('=> num classes (valid chars + special tokens): ', rec_info.num_classes)

    dataset_config = cfg[phase]['dataset']
    loader_config = cfg[phase]['loader']

    dl = build_dataset(dataset_config, loader_config, is_train=(phase=='train'))
    #dl.repeat(300)
    num_batches = dl.get_dataset_size()
@@ -67,115 +65,14 @@ def test_build_dataset(task='det', phase='train', verbose=True, visualize=False)
            polys = batch['polys'][0].asnumpy()
            img_polys = draw_bboxes(recover_image(img), polys)
            show_img(img_polys)

        start = time.time()

    WU = 2
    tot = sum(times[WU:])  # skip warmup
    mean = tot / (num_tries-WU)
    print('Avg batch loading time: ', mean)

#@pytest.mark.parametrize('model_name', all_model_names)
def test_det_dataset():
    data_dir = '/data/ocr_datasets/ic15/text_localization/train'
    annot_file = '/data/ocr_datasets/ic15/text_localization/train/train_icdar15_label.txt'
    transform_pipeline = transforms_dbnet_icdar15(phase='train')
    ds = DetDataset(is_train=True, data_dir=data_dir, label_file=annot_file, sample_ratio=0.5, transform_pipeline=transform_pipeline, shuffle=False)

    print('num data: ', len(ds))
    for i in [223]:
        data_tuple = ds.__getitem__(i)

        # recover data from tuple to dict
        data = {k: data_tuple[i] for i, k in enumerate(ds.get_column_names())}

        print(data.keys())
        #print(data['image'])
        print(data['img_path'])
        print(data['image'].shape)
        print(data['polys'])
        print(data['texts'])
        #print(data['mask'])
        #print(data['threshold_map'])
        #print(data['threshold_mask'])
        for k in data:
            print(k, data[k])
            if isinstance(data[k], np.ndarray):
                print(data[k].shape)

        #show_img(data['image'], 'BGR')
        #result_img1 = draw_bboxes(data['ori_image'], data['polys'])
        img_polys = draw_bboxes(recover_image(data['image']), data['polys'])
        #show_img(result_img2, show=False, save_path='/data/det_trans.png')

        mask_polys = draw_bboxes(data['shrink_mask'], data['polys'])
        thrmap_polys = draw_bboxes(data['threshold_map'], data['polys'])
        thrmask_polys = draw_bboxes(data['threshold_mask'], data['polys'])
        show_imgs([img_polys, mask_polys, thrmap_polys, thrmask_polys], show=False, save_path='/data/ocr_ic15_debug2.png')

        np.savez('./det_db_label_samples.npz',
                 image=data['image'],
                 polys=data['polys'],
                 texts=data['texts'],
                 ignore_tags=data['ignore_tags'],
                 shrink_map=data['shrink_map'],
                 shrink_mask=data['shrink_mask'],
                 threshold_map=data['threshold_map'],
                 threshold_mask=data['threshold_mask'],
                 )

    # TODO: check transformed image and label correctness

def test_rec_dataset(visualize=True):

    yaml_fp = 'configs/rec/crnn_icdar15.yaml'
    with open(yaml_fp) as fp:
        cfg = yaml.safe_load(fp)

    data_dir = '/Users/Samit/Data/datasets/ic15/rec/ch4_training_word_images_gt'
    label_path = '/Users/Samit/Data/datasets/ic15/rec/rec_gt_train.txt'
    ds = RecDataset(is_train=True,
                    data_dir=data_dir,
                    label_file=label_path,
                    sample_ratio=1.0,
                    shuffle=False,
                    transform_pipeline=cfg['train']['dataset']['transform_pipeline'],
                    output_columns=None)
    # visualize to check correctness
    from mindocr.utils.visualize import show_img, draw_bboxes, show_imgs, recover_image
    print('num data: ', len(ds))
    for i in [3]:
        data_tuple = ds.__getitem__(i)
        print('output columns: ', ds.get_column_names())
        # recover data from tuple to dict
        data = {k: data_tuple[i] for i, k in enumerate(ds.get_column_names())}

        print(data['img_path'])
        print(data['image'].shape)
        print('text: ', data['text'])
        print('\t Shapes: ', {k: data[k].shape for k in data if isinstance(data[k], np.ndarray)})
        print('label: ', data['label_ace'])
        print('label_ace: ', data['label_ace'])
        image = recover_image(data['image'])
        show_img(image, show=True)  #, save_path='/data/ocr_ic15_debug2.png')

    '''
    dl = build_dataset(
        cfg['train']['dataset'],
        cfg['train']['loader'],
        is_train=True)
    batch = next(dl.create_dict_iterator())
    print(len(batch))
    for item in batch:
        print(item.shape)
    '''


if __name__ == '__main__':
    #test_build_dataset(task='det', phase='eval', visualize=True)
    #test_build_dataset(task='det', phase='train', visualize=False)
    test_build_dataset(task='rec', phase='train', visualize=False)
    #test_build_dataset(task='rec')
    #test_det_dataset()
    #test_rec_dataset()
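One idiom from the deleted standalone tests is worth noting, since the datasets still expose it: `__getitem__` returns a plain tuple, and `get_column_names()` supplies the matching keys. A minimal sketch, assuming `ds` is any DetDataset/RecDataset instance:

    # Sketch: recover a {column_name: value} dict from a dataset's tuple output.
    data_tuple = ds[0]
    data = {name: data_tuple[i] for i, name in enumerate(ds.get_column_names())}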
42 changes: 11 additions & 31 deletions tests/ut/test_models.py
@@ -14,14 +14,17 @@
all_model_names = mindocr.list_models()
print('Registered models: ', all_model_names)

all_yamls = glob.glob('configs/*/*.yaml')
#all_yamls = glob.glob('configs/*/*.yaml')
all_yamls = ['configs/det/dbnet/db_r50_icdar15.yaml', 'configs/rec/crnn/crnn_icdar15.yaml']
print('All config yamls: ', all_yamls)

def _infer_dummy(model, task='det', verbose=True):
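    """Run a random dummy batch (sized per `task`) through `model` as a smoke test."""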
    import mindspore as ms
    import time
    import numpy as np

    print(task)

    bs = 8
    if task == 'rec':
        c, h, w = 3, 32, 100
@@ -65,18 +68,18 @@ def test_model_by_name(model_name):


@pytest.mark.parametrize('yaml_fp', all_yamls)
def test_model_by_yaml(yaml_fp='configs/det/dbnet/db_r50_icdar15.yaml'):
def test_model_by_yaml(yaml_fp):
    print(yaml_fp)
    with open(yaml_fp) as fp:
        config = yaml.safe_load(fp)

    task = yaml_fp.split('/')[-2]
    task = yaml_fp.split('/')[1]

    if task == 'rec':
        from mindocr.postprocess.rec_postprocess import CTCLabelDecode
        from mindocr.postprocess.rec_postprocess import RecCTCLabelDecode
        dict_path = config['common']['character_dict_path']
        # read dict path and get class nums
        rec_info = CTCLabelDecode(character_dict_path=dict_path)
        rec_info = RecCTCLabelDecode(character_dict_path=dict_path)
        num_classes = len(rec_info.character)
        config['model']['head']['out_channels'] = num_classes
        print('num characters: ', num_classes)
@@ -85,33 +88,9 @@ def test_model_by_yaml(yaml_fp='configs/det/dbnet/db_r50_icdar15.yaml'):
    model = build_model(model_config)
    _infer_dummy(model, task=task)

'''
model_config = {
    "backbone": {
        'name': 'det_resnet50',
        'pretrained': False
    },
    "neck": {
        "name": 'FPN',
        "out_channels": 256,
    },
    "head": {
        "name": 'ConvHead',
        "out_channels": 2,
        "k": 50
    }
}
'''

''' TODO: check loading
ckpt_path = None
if ckpt_path is not None:
    param_dict = load_checkpoint(os.path.join(path, os.path.basename(default_cfg['url'])))
    load_param_into_net(model, param_dict)
'''

if __name__ == '__main__':
    test_model_by_yaml(all_yamls[1])
    '''
    import argparse
    parser = argparse.ArgumentParser(description='model config', add_help=False)
    parser.add_argument('-c', '--config', type=str, default='configs/det/dbnet/db_r50_icdar15.yaml',
@@ -121,3 +100,4 @@ def test_model_by_yaml(yaml_fp='configs/det/dbnet/db_r50_icdar15.yaml'):
    #test_backbone()
    #test_model_by_name('dbnet_r50')
    test_model_by_yaml(args.config)
    '''
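For reference, the visible rec branch of `_infer_dummy` amounts to pushing one random NCHW batch through the freshly built network. A minimal standalone sketch under that assumption (the det branch's input size is collapsed in the diff above, so it is omitted; `model` is assumed to come from build_model):

    # Sketch of _infer_dummy's rec path: one random batch through the model.
    import numpy as np
    import mindspore as ms

    bs, c, h, w = 8, 3, 32, 100  # recognition input size from the diff
    x = ms.Tensor(np.random.rand(bs, c, h, w), dtype=ms.float32)
    out = model(x)  # output structure depends on the head (tensor or tuple of tensors)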
30 changes: 0 additions & 30 deletions tests/ut/test_postprocess.py

This file was deleted.