Improve postprocess APIs to receive more data info for processing; add batch size refinement in validation while training #207

Merged: 4 commits, Apr 24, 2023

2 changes: 1 addition & 1 deletion configs/det/dbnet/db_r50_icdar15.yaml
@@ -141,7 +141,7 @@ eval:
# the order of the dataloader list, matching the network input and the labels for evaluation
output_columns: [ 'image', 'polys', 'ignore_tags' ]
num_columns_to_net: 1 # num inputs for network forward func
-# num_keys_of_labels: 2 # num labels
+num_columns_of_labels: 2 # num labels

loader:
shuffle: False
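
These two count keys encode the column convention used throughout this PR: the first num_columns_to_net columns of a batch feed the network, the next num_columns_of_labels columns are labels, and any remaining columns are extra per-sample info. A minimal sketch of the slicing with stand-in values (plain Python, not MindOCR code):

batch = ('image_tensor', 'polys', 'ignore_tags')  # stand-ins for the output_columns above
num_columns_to_net = 1
num_columns_of_labels = 2

inputs = batch[:num_columns_to_net]  # ('image_tensor',)
labels = batch[num_columns_to_net:num_columns_to_net + num_columns_of_labels]  # ('polys', 'ignore_tags')
extras = batch[num_columns_to_net + num_columns_of_labels:]  # empty for this det config
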
6 changes: 3 additions & 3 deletions configs/rec/crnn/crnn_icdar15.yaml
@@ -101,9 +101,8 @@ train:
std : [127.0, 127.0, 127.0]
- ToCHWImage:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
-output_columns: ['image', 'text_seq'] #, 'length'] #'img_path']
+output_columns: ['image', 'text_seq']
num_columns_to_net: 1 # num inputs for network forward func in output_columns
-#keys_for_loss: 4 # num labels for loss func

loader:
shuffle: True # TODO: tbc
@@ -142,8 +141,9 @@ eval:
std : [127.0, 127.0, 127.0]
- ToCHWImage:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
-output_columns: ['image', 'text_padded', 'text_length'] # TODO: return text string padded to a fixed length, and a scalar indicating the valid length
+output_columns: ['image', 'text_padded', 'text_length', 'img_path'] # TODO: return text string padded to a fixed length, and a scalar indicating the valid length
num_columns_to_net: 1 # num inputs for network forward func
+num_columns_of_labels: 2 # num of label columns

loader:
shuffle: False
1 change: 1 addition & 0 deletions configs/rec/crnn/crnn_resnet34.yaml
@@ -143,6 +143,7 @@ eval:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visaulize
output_columns: ['image', 'text_padded', 'text_length'] # TODO: return text string padded to a fixed length, and a scalar indicating the valid length
num_columns_to_net: 1 # num inputs for network forward func
+num_columns_of_labels: 2 # num labels

loader:
shuffle: False # TODO: tbc
1 change: 1 addition & 0 deletions mindocr/data/builder.py
@@ -139,6 +139,7 @@ def build_dataset(
drop_remainder = loader_config.get('drop_remainder', is_train)
if is_train and drop_remainder == False:
print('WARNING: drop_remainder should be True for training, otherwise the last batch may lead to training fail in Graph mode')
+
if not is_train:
if drop_remainder:
print("WARNING: drop_remainder is forced to be False for evaluation to include the last batch for accurate evaluation." )
2 changes: 1 addition & 1 deletion mindocr/postprocess/det_postprocess.py
@@ -22,7 +22,7 @@ def __init__(self, binary_thresh=0.3, box_thresh=0.7, max_candidates=1000, expan
self._name = pred_name
self._names = {'binary': 0, 'thresh': 1, 'thresh_binary': 2}

-def __call__(self, pred):
+def __call__(self, pred, **kwargs):
"""
pred (Union[Tensor, Tuple[Tensor], np.ndarray]):
binary: text region segmentation map, with shape (N, 1, H, W)
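
With the widened signature, existing call sites that pass only pred keep working, while the Evaluator can forward extra batch information as keyword arguments. A toy sketch of the pattern (not the real DBPostprocess; the kwarg names mirror the data_info dict built in mindocr/utils/callbacks.py below):

import numpy as np

class ToyDetPostprocess:
    # Illustrates only the (pred, **kwargs) calling convention; box extraction is elided.
    def __call__(self, pred, **kwargs):
        img_shape = kwargs.get('img_shape')  # e.g. network input shape (N, C, H, W)
        meta_info = kwargs.get('meta_info')  # e.g. extra columns such as img_path
        # A real implementation would extract polygons from pred here, optionally
        # using img_shape to map boxes back to the original image scale.
        return {'polygons': [], 'scores': []}

post = ToyDetPostprocess()
post(np.zeros((1, 1, 640, 640)))  # old call style still works
post(np.zeros((1, 1, 640, 640)), img_shape=(1, 3, 640, 640))  # new style with extra info
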
54 changes: 34 additions & 20 deletions mindocr/utils/callbacks.py
@@ -17,10 +17,19 @@
class Evaluator:
"""
Args:
-metric:
+network: the network to evaluate
+dataloader: data loader that generates batches, where the data columns in a batch are defined by the transform pipeline and `output_columns`.
+loss_fn: loss function
+postprocessor: post-processor
+metrics: metrics to evaluate network performance
+num_columns_to_net: number of columns in the dataset output that feed the network forward function. Default is 1, i.e. the first column is the image.
+num_columns_of_labels: number of label columns in the dataset output. Default is None, in which case all columns after the image (data[1:]) are treated as labels.
+If not None, the num_columns_of_labels columns after the image (data[1:1+num_columns_of_labels]) are labels, and the remaining columns are additional info such as image_path.
"""

-def __init__(self, network, dataloader, loss_fn=None, postprocessor=None, metrics=None, num_epochs_train=-1, visualize=False, verbose=False,
+def __init__(self, network, dataloader, loss_fn=None, postprocessor=None, metrics=None,
+num_columns_to_net=1, num_columns_of_labels=None,
+visualize=False, verbose=False,
**kwargs):
self.net = network
self.postprocessor = postprocessor
@@ -37,17 +46,20 @@ def __init__(self, network, dataloader, loss_fn=None, postprocessor=None, metric
if loss_fn is not None:
eval_loss = True
self.loss_fn = loss_fn
+# TODO: add support for computing evaluation loss
assert eval_loss == False, 'not impl'

# create iterator
-self.iterator = dataloader.create_tuple_iterator(num_epochs=num_epochs_train, output_numpy=False, do_copy=False)
+self.iterator = dataloader.create_tuple_iterator(num_epochs=-1, output_numpy=False, do_copy=False)
self.num_batches_eval = dataloader.get_dataset_size()

-def eval(self, num_columns_to_net=1, num_keys_of_labels=None):
+# dataset output columns
+self.num_inputs = num_columns_to_net
+self.num_labels = num_columns_of_labels
+assert self.num_inputs==1, 'Only num_columns_to_net=1 (single input to network) is needed and supported for current networks.'

+def eval(self):
"""
Args:
dataloader (Dataset): data iterator which generates tuple of Tensor defined by the transform pipeline and 'output_columns'
"""
eval_res = {}

@@ -56,25 +68,26 @@ def eval(self, num_columns_to_net=1, num_keys_of_labels=None):
m.clear()

for i, data in tqdm(enumerate(self.iterator), total=self.num_batches_eval):
-# start = time.time()
-# TODO: if network input is not just an image.
-img = data[0] # ms.Tensor(batch[0])
-gt = data[1:] # ground truth, (polys, ignore_tags) for det,
-#print(i, img.shape, img.sum())

-net_preds = self.net(img)
-# net_inputs = data[:num_columns_to_net]
-# gt = data[num_columns_to_net:] # ground truth
-# preds = self.net(*net_inputs)
-# print('net predictions', preds)

+inputs = data[:self.num_inputs] # [imgs]
+gt = data[self.num_inputs:] if self.num_labels is None else data[self.num_inputs: self.num_inputs+self.num_labels]

+net_preds = self.net(*inputs)

if self.postprocessor is not None:
-preds = self.postprocessor(net_preds) # {'polygons':, 'scores':} for text det
+# additional info such as image path, original image size, pad shape, extracted in data processing
+meta_info = data[(self.num_inputs+self.num_labels):] if (self.num_labels is not None) else []
+data_info = {'labels': gt, 'img_shape': inputs[0].shape, 'meta_info': meta_info}
+preds = self.postprocessor(net_preds, **data_info)

# metric internal update
for m in self.metrics:
m.update(preds, gt)

# visualize
+if self.verbose:
+print('Eval data info: ', data_info)

if self.visualize:
img = img[0].asnumpy()
assert ('polys' in preds) or ('polygons' in preds), 'Only support detection'
@@ -98,7 +111,6 @@ class EvalSaveCallback(Callback):
Args:
network (nn.Cell): network (without loss)
loader (Dataset): dataloader
-saving_config (dict):
"""

def __init__(self,
@@ -112,6 +124,8 @@ def __init__(self,
batch_size=20,
ckpt_save_dir='./',
main_indicator='hmean',
+num_columns_to_net=1, # TODO: parse eval cfg for short?
+num_columns_of_labels=None,
val_interval=1,
val_start_epoch=1,
log_interval=1,
@@ -126,7 +140,7 @@ def __init__(self,
self.log_interval = log_interval
self.batch_size = batch_size
if self.loader_eval is not None:
-self.net_evaluator = Evaluator(network, loader, loss_fn, postprocessor, metrics)
+self.net_evaluator = Evaluator(network, loader, loss_fn, postprocessor, metrics, num_columns_to_net=num_columns_to_net, num_columns_of_labels=num_columns_of_labels)
self.main_indicator = main_indicator
self.best_perf = -1e8
else:
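
Tying the pieces together for the recognition eval config above (output_columns ['image', 'text_padded', 'text_length', 'img_path'] with num_columns_to_net=1 and num_columns_of_labels=2), each batch is sliced as in Evaluator.eval; a sketch with stand-in values:

data = ('image_tensor', 'text_padded', 'text_length', 'img_path')  # stand-ins for one batch
num_inputs, num_labels = 1, 2

inputs = data[:num_inputs]  # ('image_tensor',)
gt = data[num_inputs:num_inputs + num_labels]  # ('text_padded', 'text_length')
meta_info = data[num_inputs + num_labels:]  # ('img_path',), extra info rather than a label

data_info = {'labels': gt, 'img_shape': None, 'meta_info': meta_info}  # img_shape elided here
# preds = postprocessor(net_preds, **data_info)  # as in Evaluator.eval above
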
4 changes: 2 additions & 2 deletions tests/st/test_train_eval_dummy.py
@@ -97,6 +97,6 @@ def test_train_eval(task, val_while_train, gradient_accumulation_steps, clip_gra

if __name__ == '__main__':
#test_train_eval('rec', True, 1, True, 'filter_norm_and_bias')
-test_train_eval('det', True, 1, False, None)
-#test_train_eval('rec', True, 1, True, 'filter_norm_and_bias')
+#test_train_eval('det', True, 1, False, None)
+test_train_eval('rec', True, 1, True, 'filter_norm_and_bias')
#test_train_eval('rec', True, 2, False, None)
10 changes: 9 additions & 1 deletion tools/eval.py
@@ -63,7 +63,15 @@ def main(cfg):
# postprocess network prediction
metric = build_metric(cfg.metric)

-net_evaluator = Evaluator(network, loader_eval, None, postprocessor, [metric])
+net_evaluator = Evaluator(
+network,
+loader_eval,
+loss_fn=None,
+postprocessor=postprocessor,
+metrics=[metric],
+num_columns_to_net=cfg.eval.dataset.get('num_columns_to_net', 1),
+num_columns_of_labels=cfg.eval.dataset.get('num_columns_of_labels', None),
+)

# log
print('='*40)
6 changes: 5 additions & 1 deletion tools/train.py
@@ -81,7 +81,9 @@ def main(cfg):
cfg.eval.loader,
num_shards=device_num,
shard_id=rank_id,
-is_train=False)
+is_train=False,
+refine_batch_size=True,
+)

# create model
network = build_model(cfg.model)
@@ -133,6 +135,8 @@ def main(cfg):
batch_size=cfg.train.loader.batch_size,
ckpt_save_dir=cfg.train.ckpt_save_dir,
main_indicator=cfg.metric.main_indicator,
+num_columns_to_net=cfg.eval.dataset.get('num_columns_to_net', 1),
+num_columns_of_labels=cfg.eval.dataset.get('num_columns_of_labels', None),
val_interval=cfg.system.get('val_interval', 1),
val_start_epoch=cfg.system.get('val_start_epoch', 1),
log_interval=cfg.system.get('log_interval', 100)