Fix metric bugs in distributed mode and add unit test #103
Merged
merged 2 commits on Mar 28, 2023
15 changes: 9 additions & 6 deletions mindocr/metrics/builder.py
@@ -1,22 +1,23 @@
from .det_metrics import *
from .rec_metrics import *
from . import det_metrics
from . import rec_metrics

supported_metrics = det_metrics.__all__ + rec_metrics.__all__

# TODO: support multiple metrics
def build_metric(config):
def build_metric(config, device_num=1, **kwargs):
"""
Create the metric function.

Args:
config (dict): configuration for the metric, including the metric `name` and the kwargs specific to each metric.
- name (str): metric function name, exactly the same as one of the supported metric class names

device_num (int): number of devices. If device_num > 1, the metric is computed in distributed mode, i.e., intermediate
variables (e.g., num_correct, TP) are aggregated across all devices with the `ops.AllReduce` op so that the metric is computed correctly on the dispatched data.

Return:
nn.Metric

Example:
>>> # Create a RecMetric module for text recognition
>>> from mindocr.metrics import build_metric
@@ -28,8 +29,10 @@ def build_metric(config):

mn = config.pop('name')
if mn in supported_metrics:
device_num = 1 if device_num is None else device_num
config.update({'device_num': device_num})
metric = eval(mn)(**config)
else:
raise ValueError(f'Invalid metric name {mn}, support metrics are {supported_metrics}')

return metric
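
To see the new argument in context, here is a minimal usage sketch (hypothetical config values; the keys must match the chosen metric's constructor kwargs):

# Minimal sketch, assuming an 8-device job; 'print_flag' is one of
# RecMetric's kwargs shown in the rec_metrics.py diff below.
from mindocr.metrics import build_metric

config = {'name': 'RecMetric', 'print_flag': False}
metric = build_metric(config, device_num=8)  # enables AllReduce aggregation in eval()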
33 changes: 21 additions & 12 deletions mindocr/metrics/det_metrics.py
@@ -77,16 +77,12 @@ def __call__(self, gt: List[dict], preds: List[np.ndarray]):


class DetMetric(nn.Metric):
def __init__(self, **kwargs):
def __init__(self, device_num=1, **kwargs):
super().__init__()
self._evaluator = DetectionIoUEvaluator()
self._gt_labels, self._det_labels = [], []
try:
self.device_num = get_group_size()
self.all_reduce = ops.AllReduce()
except (ValueError, RuntimeError):
self.device_num = 1
self.all_reduce = None
self.device_num = device_num
self.all_reduce = None if device_num == 1 else ops.AllReduce()

def clear(self):
self._gt_labels, self._det_labels = [], []
@@ -97,10 +93,14 @@ def update(self, *inputs):

Args:
inputs (tuple): contains two elements: preds, gts
preds (list): prediction output by postprocess in the form of [[(box, score)]]
gt (tuple): ground truth, order defined by output_columns in eval dataloader
preds (list[tuple]): text detection prediction as a list of tuple (polygon, confidence),
where polygon is in shape [num_boxes, 4, 2], confidence is in shape [num_boxes]
gts (tuple): ground truth - (polygons, ignore_tags), where polygons are in shape [num_images, num_boxes, 4, 2],
ignore_tags are in shape [num_images, num_boxes], which can be defined by output_columns in yaml
"""
preds, gts = inputs
polys, ignore = gts[0].asnumpy().astype(np.float32), gts[1].asnumpy()

for sample_id in range(len(polys)):
@@ -119,6 +119,8 @@ def cal_matrix(self, det_lst, gt_lst):
fn = np.sum((gt_lst == 1) * (det_lst == 0))
fp = np.sum((gt_lst == 0) * (det_lst == 1))
return tp, fp, fn



def eval(self):
"""
@@ -139,11 +141,18 @@ def eval(self):
fp = float(self.all_reduce(Tensor(fp, ms.float32)).asnumpy())
fn = float(self.all_reduce(Tensor(fn, ms.float32)).asnumpy())

Collaborator:
Would it be better, instead of _safe_divide, to do it this way?

recall, precision, f_score = 0., 0., 0.
if tp > 0:
    recall = tp / (tp + fn)
    ...

Collaborator (Author):

This is also intended to be extensible to other metric computations that need to avoid zero division.

recall = tp / (tp + fn)
precision = tp / (tp + fp)
f_score = 2 * recall * precision / (recall + precision)
recall = _safe_divide(tp, (tp + fn))
precision = _safe_divide(tp, (tp + fp))
f_score = _safe_divide(2 * recall * precision, (recall + precision))
return {
'recall': recall,
'precision': precision,
'f-score': f_score
}


def _safe_divide(numerator, denominator, val_if_zero_divide=0.):
if denominator == 0:
return val_if_zero_divide
else:
return numerator / denominator
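
As a quick sanity check of the guard (a hypothetical edge case, not from the PR's test suite): if a device sees no boxes at all, tp = fp = fn = 0 and all three ratios would otherwise raise ZeroDivisionError:

# Hypothetical edge case: empty predictions and ground truth.
tp, fp, fn = 0.0, 0.0, 0.0

recall = _safe_divide(tp, tp + fn)     # 0.0 instead of ZeroDivisionError
precision = _safe_divide(tp, tp + fp)  # 0.0
f_score = _safe_divide(2 * recall * precision, recall + precision)  # 0.0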
37 changes: 21 additions & 16 deletions mindocr/metrics/rec_metrics.py
@@ -31,20 +31,17 @@ def __init__(self,
filter_ood=True,
lower=True,
print_flag=False,
device_num=1,
**kwargs):
super().__init__()
self.clear()
self.ignore_space = ignore_space
self.filter_ood = filter_ood
self.lower = lower
self.print_flag = print_flag

try:
self.device_num = get_group_size()
self.all_reduce = ops.AllReduce()
except (ValueError, RuntimeError):
self.device_num = 1
self.all_reduce = None

self.device_num = device_num
self.all_reduce = None if device_num == 1 else ops.AllReduce()

# TODO: use parsed dictionary object
if character_dict_path is None:
@@ -57,9 +54,9 @@ def __init__(self,
self.dict.append(c)

def clear(self):
self._correct_num = 0
self._total_num = 0
self.norm_edit_dis = 0.0
self._correct_num = ms.Tensor(0, dtype=ms.int32)
self._total_num = ms.Tensor(0, dtype=ms.float32) # avoid int divisor
self._norm_edit_dis = ms.Tensor(0., dtype=ms.float32)

def update(self, *inputs):
"""
@@ -120,7 +117,7 @@ def update(self, *inputs):
print(pred, " :: ", label)

edit_distance = Levenshtein.normalized_distance(pred, label)
self.norm_edit_dis += edit_distance
self._norm_edit_dis += edit_distance
if pred == label:
self._correct_num += 1

@@ -137,13 +134,21 @@ def eval(self):
'Accuracy cannot be calculated, because the number of samples is 0.')
print('correct num: ', self._correct_num,
', total num: ', self._total_num)
sequence_accurancy = self._correct_num / self._total_num
norm_edit_distance = 1 - self.norm_edit_dis / self._total_num

if self.all_reduce:
sequence_accurancy = float(self.all_reduce_fun(Tensor(sequence_accurancy, ms.float32)).asnumpy())
norm_edit_distance = float(self.all_reduce_fun(Tensor(norm_edit_distance, ms.float32)).asnumpy())
# sum over all devices
correct_num = self.all_reduce(self._correct_num)
norm_edit_dis = self.all_reduce(self._norm_edit_dis)
total_num = self.all_reduce(self._total_num)
else:
correct_num = self._correct_num
norm_edit_dis = self._norm_edit_dis
total_num = self._total_num

sequence_accuracy = float((correct_num / total_num).asnumpy())
norm_edit_distance = float((1 - norm_edit_dis / total_num).asnumpy())

return {'acc': sequence_accurancy / self.device_num, 'norm_edit_distance': norm_edit_distance / self.device_num}
return {'acc': sequence_accuracy, 'norm_edit_distance': norm_edit_distance}

if __name__ == '__main__':
gt = ['ba xla la! ', 'ba ']
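
The core of the fix is the order of operations: all-reduce the raw counters first, then divide once. A small single-process sketch (plain NumPy standing in for ops.AllReduce, with made-up per-device counts) shows why averaging per-device accuracies, as the old code effectively did, gives the wrong answer when devices see different numbers of samples:

import numpy as np

# Hypothetical counters: device 0 evaluated 8 samples, device 1 only 2.
correct = np.array([6.0, 1.0])  # correct predictions per device
total = np.array([8.0, 2.0])    # samples per device

# Old behavior: reduce per-device accuracies, then divide by device_num.
avg_of_accs = np.sum(correct / total) / 2  # (0.75 + 0.5) / 2 = 0.625, wrong

# New behavior: reduce the counters, then divide once.
global_acc = correct.sum() / total.sum()   # 7 / 10 = 0.7, correct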
3 changes: 2 additions & 1 deletion mindocr/postprocess/det_postprocess.py
@@ -37,7 +37,8 @@ def __call__(self, pred):
# FIXME: dest_size is supposed to be the original image shape (pred.shape -> batch['shape'])
dest_size = np.array(pred.shape[:0:-1]) # w, h order
scale = dest_size / np.array(pred.shape[:0:-1])


# FIXME: output as dict, keep consistent return format to recognition
return [self._extract_preds(pr, segm, scale, dest_size) for pr, segm in zip(pred, segmentation)]

def _extract_preds(self, pred, bitmap, scale, dest_size):
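
Per the FIXME above, dest_size should come from the original image shape rather than from pred itself (as written, scale is identically 1). A sketch of the intended computation, assuming the original (h, w) is available from the batch as the FIXME suggests:

# Hypothetical fix following the FIXME: batch['shape'] is assumed to
# carry the original image size; the actual key may differ.
orig_h, orig_w = batch['shape']
dest_size = np.array([orig_w, orig_h])           # w, h order
scale = dest_size / np.array(pred.shape[:0:-1])  # map boxes back to original size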
5 changes: 0 additions & 5 deletions mindocr/utils/callbacks.py
@@ -45,9 +45,6 @@ def eval(self, dataloader, num_columns_to_net=1, num_keys_of_labels=None):
for m in self.metrics:
m.clear()

# debug
# for param in self.net.get_parameters():
# print(param.name, param.value().sum())
for i, data in tqdm(enumerate(iterator), total=dataloader.get_dataset_size()):
# start = time.time()
# TODO: if network input is not just an image.
Expand All @@ -66,8 +63,6 @@ def eval(self, dataloader, num_columns_to_net=1, num_keys_of_labels=None):
if self.postprocessor is not None:
preds = self.postprocessor(net_preds) # {'polygons':, 'scores':} for text det

# print('GT polys:', gt[0])

# metric internal update
for m in self.metrics:
m.update(preds, gt)
78 changes: 52 additions & 26 deletions tests/ut/test_metrics.py
@@ -2,38 +2,64 @@
sys.path.append('.')

import numpy as np

from mindocr.data.det_dataset import DetDataset
from mindocr.postprocess.det_postprocess import DBPostprocess
import mindspore as ms
from mindocr.metrics.det_metrics import DetMetric
from mindocr.metrics.rec_metrics import RecMetric


def test_det_metric():
# TODO: gen by DetDataset
data = np.load('./det_db_label_samples.npz')
polys, bmap, _, _, texts, ignore_tags = data['polys'], data['shrink_map'], data['threshold_map'], data['threshold_mask'], data['texts'], data['ignore_tags']
polys = np.array([polys])
bmap = np.array([bmap])
ignore_tags = np.array([ignore_tags])
print('GT polys: ', polys)
print('ignore_tags', ignore_tags)
print('texts', texts)

proc = DBPostprocess(thresh=0.3,
box_thresh=0.55,
max_candidates=1000,
unclip_ratio=1.5,
region_type='quad',
dest='binary',
score_mode='fast')
preds = proc({'binary': bmap})

m = DetMetric()
m.update(preds, (polys, ignore_tags))
pred_polys = [
[
[[0, 0], [0, 10], [10, 10], [10, 0]],
[[10, 10], [10, 20], [20, 20], [20, 10]],
[[20, 20], [20, 30], [30, 30], [30, 20]],
],
]
pred_polys = np.array(pred_polys, dtype=np.float32)
confs = np.array([[1.0, 0.8, 0.9]])
num_images = pred_polys.shape[0]
num_boxes = pred_polys.shape[1]
print(num_images, num_boxes)
preds = [(pred_polys[i], confs[i]) for i in range(num_images)]

res = m.eval()
print(res)
gt_polys = [
[
[[0, 0], [0, 9], [9, 9], [9, 0]],
[[10, 10], [-10, -20], [-20, -20], [-20, -10]],
[[20, 20], [20, 30], [30, 30], [30, 20]],
],
]
gt_polys = ms.Tensor(np.array(gt_polys, dtype=np.float32))
ignore_tags = ms.Tensor([[False, False, True]])
gts = (gt_polys, ignore_tags)

m = DetMetric()
m.update(preds, gts)

perf = m.eval()
print(perf)

# check correctness
assert perf['recall'] == 0.5
assert perf['precision'] == 0.5
assert perf['f-score'] == 0.5


def test_rec_metric():
gt = ['ba la la! ', 'ba ']
gt_len = [len('ba xla la!'), len('ba')]
pred = ['baxlala', 'ba']

m = RecMetric()
m.update({'texts': pred}, (gt, gt_len))
perf = m.eval()
print(perf)

# check correctness
assert perf['acc'] == 0.5
assert abs(perf['norm_edit_distance'] - 0.92857) < 1e-4


if __name__ == '__main__':
test_det_metric()
# test_rec_metric()
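
The detection assertions can be checked by hand (my reading of the fixtures above, not output from the PR): pred box 1 overlaps gt box 1 with IoU 81/100 = 0.81, pred box 2 matches nothing, and pred box 3 hits only the ignored gt, so it is excluded from both counts:

# Hand check of the asserted values, based on the fixtures above.
tp, fp, fn = 1.0, 1.0, 1.0  # pred 2 is a false positive; gt 2 is missed

recall = tp / (tp + fn)                                  # 0.5
precision = tp / (tp + fp)                               # 0.5
f_score = 2 * recall * precision / (recall + precision)  # 0.5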
6 changes: 3 additions & 3 deletions tools/train.py
@@ -101,13 +101,13 @@ def main(cfg):
scale_sense=loss_scale_manager,
drop_overflow_update=cfg.system.drop_overflow_update,
)
# postprocess, metric
# build postprocess and metric
postprocessor = None
metric = None
if cfg.system.val_while_train:
postprocessor = build_postprocess(cfg.postprocess)
# postprocess network prediction
metric = build_metric(cfg.metric)
postprocessor = build_postprocess(cfg.postprocess)
metric = build_metric(cfg.metric, device_num=device_num)

# build callbacks
eval_cb = EvalSaveCallback(
Expand Down