From ef683206ed9cb5ee3788a279b7a96f55d7b04e87 Mon Sep 17 00:00:00 2001 From: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Date: Mon, 8 Aug 2022 11:37:46 +0800 Subject: [PATCH] [Api] vis hook and data flow api (#1185) * vis hook and data flow api * fix comment * add TODO for merging and rewriting after MultiDatasetWrapper --- configs/_base_/default_runtime.py | 9 +- .../drrg/drrg_r50_fpn_unet_1200e_ctw1500.py | 10 +- .../fcenet/fcenet_r50_fpn_1500e_icdar2015.py | 10 +- .../fcenet_r50dcnv2_fpn_1500e_ctw1500.py | 10 +- .../mask_rcnn_r50_fpn_160e_ctw1500.py | 10 +- .../mask_rcnn_r50_fpn_160e_icdar2015.py | 10 +- .../mask_rcnn_r50_fpn_160e_icdar2017.py | 10 +- .../panet_r18_fpem_ffm_600e_icdar2015.py | 10 +- .../psenet/psenet_r50_fpnf_600e_icdar2015.py | 10 +- .../textsnake_r50_fpn_unet_1200e_ctw1500.py | 10 +- configs/textrecog/abinet/base.py | 7 +- .../textrecog/crnn/crnn_academic_dataset.py | 6 +- .../master/master_r31_12e_ST_MJ_SA.py | 6 +- .../textrecog/master/master_toy_dataset.py | 6 +- .../nrtr/nrtr_modality_transform_academic.py | 6 +- .../nrtr_modality_transform_toy_dataset.py | 6 +- .../nrtr/nrtr_r31_1by16_1by8_academic.py | 6 +- .../nrtr/nrtr_r31_1by8_1by4_academic.py | 6 +- .../robustscanner_r31_academic.py | 6 +- .../sar/sar_r31_parallel_decoder_academic.py | 6 +- .../sar_r31_sequential_decoder_academic.py | 6 +- configs/textrecog/satrn/satrn_academic.py | 6 +- mmocr/engine/__init__.py | 1 + mmocr/engine/hooks/__init__.py | 4 + mmocr/engine/hooks/visualization_hook.py | 132 +++++++++++++++ mmocr/evaluation/metrics/f_metric.py | 8 +- mmocr/evaluation/metrics/hmean_iou_metric.py | 32 +--- mmocr/evaluation/metrics/recog_metric.py | 18 +- mmocr/visualization/textdet_visualizer.py | 25 +-- mmocr/visualization/textrecog_visualizer.py | 31 ++-- .../test_hooks/test_visualization_hook.py | 79 +++++++++ .../test_metrics/test_f_metric.py | 109 +++++------- .../test_metrics/test_hmean_iou_metric.py | 67 +++----- .../test_metrics/test_recog_metric.py | 157 +++++++----------- .../test_textdet_visualizer.py | 34 ++-- .../test_textrecog_visualizer.py | 31 ++-- 36 files changed, 549 insertions(+), 351 deletions(-) create mode 100644 mmocr/engine/hooks/__init__.py create mode 100644 mmocr/engine/hooks/visualization_hook.py create mode 100644 tests/test_engine/test_hooks/test_visualization_hook.py diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py index 906586fe4..5298502cc 100644 --- a/configs/_base_/default_runtime.py +++ b/configs/_base_/default_runtime.py @@ -6,7 +6,14 @@ param_scheduler=dict(type='ParamSchedulerHook'), checkpoint=dict(type='CheckpointHook', interval=1), sampler_seed=dict(type='DistSamplerSeedHook'), - sync_buffer=dict(type='SyncBuffersHook')) + visualization=dict( + type='VisualizationHook', + interval=1, + enable=False, + show=False, + draw_gt=False, + draw_pred=False), +) env_cfg = dict( cudnn_benchmark=True, diff --git a/configs/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py b/configs/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py index 8461d53f6..73a71614e 100644 --- a/configs/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py +++ b/configs/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py @@ -69,10 +69,16 @@ file_client_args=file_client_args, color_type='color_ignore_orientation'), dict(type='Resize', scale=(1024, 640), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict( + type='LoadOCRAnnotations', + with_polygon=True, + with_bbox=True, + 
with_label=True), dict( type='PackTextDetInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor')) ] train_dataloader = dict( diff --git a/configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py b/configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py index b3428b037..858d9a427 100644 --- a/configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py +++ b/configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py @@ -68,10 +68,16 @@ file_client_args=file_client_args, color_type='color_ignore_orientation'), dict(type='Resize', scale=(2260, 2260), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict( + type='LoadOCRAnnotations', + with_polygon=True, + with_bbox=True, + with_label=True), dict( type='PackTextDetInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor')) ] train_dataloader = dict( diff --git a/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py b/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py index d0ed170af..ccdf0da94 100644 --- a/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py +++ b/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py @@ -73,10 +73,16 @@ file_client_args=file_client_args, color_type='color_ignore_orientation'), dict(type='Resize', scale=(1080, 736), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict( + type='LoadOCRAnnotations', + with_polygon=True, + with_bbox=True, + with_label=True), dict( type='PackTextDetInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor')) ] train_dataloader = dict( diff --git a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py index a2783259e..16f7d825f 100644 --- a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py +++ b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py @@ -50,10 +50,16 @@ file_client_args=file_client_args, color_type='color_ignore_orientation'), dict(type='Resize', scale=(1600, 1600), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict( + type='LoadOCRAnnotations', + with_polygon=True, + with_bbox=True, + with_label=True), dict( type='PackTextDetInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor')) ] train_dataloader = dict( diff --git a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py index 8f53d52ad..824738a13 100644 --- a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py +++ b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py @@ -50,10 +50,16 @@ file_client_args=file_client_args, color_type='color_ignore_orientation'), dict(type='Resize', scale=(1920, 1920), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict( + type='LoadOCRAnnotations', + with_polygon=True, + with_bbox=True, + with_label=True), dict( type='PackTextDetInputs', - meta_keys=('img_path', 
'ori_shape', 'img_shape', 'scale_factor', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor')) ] train_dataloader = dict( diff --git a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py index eac0219d0..b36ec5e9f 100644 --- a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py +++ b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py @@ -49,10 +49,16 @@ file_client_args=file_client_args, color_type='color_ignore_orientation'), dict(type='Resize', scale=(1920, 1920), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict( + type='LoadOCRAnnotations', + with_polygon=True, + with_bbox=True, + with_label=True), dict( type='PackTextDetInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor')) ] train_dataloader = dict( diff --git a/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py b/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py index 6a8a3f550..5155434c8 100644 --- a/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py +++ b/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py @@ -49,10 +49,16 @@ scale_divisor=1, ratio_range=(1.0, 1.0), aspect_ratio_range=(1.0, 1.0)), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict( + type='LoadOCRAnnotations', + with_polygon=True, + with_bbox=True, + with_label=True), dict( type='PackTextDetInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor')) ] train_dataloader = dict( diff --git a/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py index ac6a13f8e..dde387cc3 100644 --- a/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py +++ b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py @@ -46,10 +46,16 @@ file_client_args=file_client_args, color_type='color_ignore_orientation'), dict(type='Resize', scale=(2240, 2240), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict( + type='LoadOCRAnnotations', + with_polygon=True, + with_bbox=True, + with_label=True), dict( type='PackTextDetInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor')) ] train_dataloader = dict( diff --git a/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py b/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py index 24dbd9378..d8b19ed53 100644 --- a/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py +++ b/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py @@ -63,10 +63,16 @@ file_client_args=file_client_args, color_type='color_ignore_orientation'), dict(type='Resize', scale=(1333, 736), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict( + type='LoadOCRAnnotations', + with_polygon=True, + with_bbox=True, + with_label=True), dict( type='PackTextDetInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 
'scale_factor')) ] train_dataloader = dict( diff --git a/configs/textrecog/abinet/base.py b/configs/textrecog/abinet/base.py index a00d25889..c6290813c 100644 --- a/configs/textrecog/abinet/base.py +++ b/configs/textrecog/abinet/base.py @@ -76,12 +76,13 @@ ] test_pipeline = [ dict(type='LoadImageFromFile', file_client_args=file_client_args), - dict(type='LoadOCRAnnotations', with_text=True), dict(type='Resize', scale=(128, 32)), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='LoadOCRAnnotations', with_text=True), dict( type='PackTextRecogInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) ] train_dataloader = dict( diff --git a/configs/textrecog/crnn/crnn_academic_dataset.py b/configs/textrecog/crnn/crnn_academic_dataset.py index 35f7c93c9..d2fc10d4e 100644 --- a/configs/textrecog/crnn/crnn_academic_dataset.py +++ b/configs/textrecog/crnn/crnn_academic_dataset.py @@ -38,10 +38,12 @@ min_width=32, max_width=None, width_divisor=16), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='LoadOCRAnnotations', with_text=True), dict( type='PackTextRecogInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) ] train_dataloader = dict( diff --git a/configs/textrecog/master/master_r31_12e_ST_MJ_SA.py b/configs/textrecog/master/master_r31_12e_ST_MJ_SA.py index 4446d7b6f..e33f7041e 100644 --- a/configs/textrecog/master/master_r31_12e_ST_MJ_SA.py +++ b/configs/textrecog/master/master_r31_12e_ST_MJ_SA.py @@ -39,10 +39,12 @@ max_width=160, width_divisor=16), dict(type='PadToWidth', width=160), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='LoadOCRAnnotations', with_text=True), dict( type='PackTextRecogInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) ] train_dataloader = dict( diff --git a/configs/textrecog/master/master_toy_dataset.py b/configs/textrecog/master/master_toy_dataset.py index 47668f6c6..95c258579 100644 --- a/configs/textrecog/master/master_toy_dataset.py +++ b/configs/textrecog/master/master_toy_dataset.py @@ -35,10 +35,12 @@ max_width=160, width_divisor=16), dict(type='PadToWidth', width=160), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='LoadOCRAnnotations', with_text=True), dict( type='PackTextRecogInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) ] train_dataloader = dict( diff --git a/configs/textrecog/nrtr/nrtr_modality_transform_academic.py b/configs/textrecog/nrtr/nrtr_modality_transform_academic.py index 5c1811431..0f87c3177 100644 --- a/configs/textrecog/nrtr/nrtr_modality_transform_academic.py +++ b/configs/textrecog/nrtr/nrtr_modality_transform_academic.py @@ -42,10 +42,12 @@ max_width=160, width_divisor=16), dict(type='PadToWidth', width=160), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='LoadOCRAnnotations', with_text=True), dict( type='PackTextRecogInputs', - meta_keys=('img_path', 
'ori_shape', 'img_shape', 'valid_ratio', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) ] train_dataloader = dict( diff --git a/configs/textrecog/nrtr/nrtr_modality_transform_toy_dataset.py b/configs/textrecog/nrtr/nrtr_modality_transform_toy_dataset.py index 2c146718c..b4106b2b4 100644 --- a/configs/textrecog/nrtr/nrtr_modality_transform_toy_dataset.py +++ b/configs/textrecog/nrtr/nrtr_modality_transform_toy_dataset.py @@ -34,10 +34,12 @@ max_width=160, width_divisor=16), dict(type='PadToWidth', width=160), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='LoadOCRAnnotations', with_text=True), dict( type='PackTextRecogInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) ] train_dataloader = dict( diff --git a/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py b/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py index 53a93c971..ed76d29c9 100644 --- a/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py +++ b/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py @@ -42,10 +42,12 @@ max_width=160, width_divisor=16), dict(type='PadToWidth', width=160), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='LoadOCRAnnotations', with_text=True), dict( type='PackTextRecogInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) ] train_dataloader = dict( diff --git a/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py b/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py index e66f29275..67c80804b 100644 --- a/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py +++ b/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py @@ -44,10 +44,12 @@ max_width=160, width_divisor=16), dict(type='PadToWidth', width=160), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='LoadOCRAnnotations', with_text=True), dict( type='PackTextRecogInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) ] train_dataloader = dict( diff --git a/configs/textrecog/robust_scanner/robustscanner_r31_academic.py b/configs/textrecog/robust_scanner/robustscanner_r31_academic.py index b330c8c8c..6560d81b3 100644 --- a/configs/textrecog/robust_scanner/robustscanner_r31_academic.py +++ b/configs/textrecog/robust_scanner/robustscanner_r31_academic.py @@ -36,10 +36,12 @@ max_width=160, width_divisor=4), dict(type='PadToWidth', width=160), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='LoadOCRAnnotations', with_text=True), dict( type='PackTextRecogInputs', - meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio', - 'instances')) + meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) ] # dataset settings diff --git a/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py b/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py index a46d62b52..6f4f89bbf 100644 --- a/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py +++ b/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py @@ -38,10 +38,12 @@ max_width=160, width_divisor=4), dict(type='PadToWidth', 
width=160),
+    # add loading annotation after ``Resize`` because ground truth
+    # does not need to do resize data transform
+    dict(type='LoadOCRAnnotations', with_text=True),
     dict(
         type='PackTextRecogInputs',
-        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio',
-                   'instances'))
+        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
 ]
 
 # dataset settings
diff --git a/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py b/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py
index b5323b30d..69cc76723 100644
--- a/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py
+++ b/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py
@@ -38,10 +38,12 @@
         max_width=160,
         width_divisor=4),
     dict(type='PadToWidth', width=160),
+    # add loading annotation after ``Resize`` because ground truth
+    # does not need to do resize data transform
+    dict(type='LoadOCRAnnotations', with_text=True),
     dict(
         type='PackTextRecogInputs',
-        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio',
-                   'instances'))
+        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
 ]
 
 # dataset settings
diff --git a/configs/textrecog/satrn/satrn_academic.py b/configs/textrecog/satrn/satrn_academic.py
index 243ea0a91..30a95b1bf 100644
--- a/configs/textrecog/satrn/satrn_academic.py
+++ b/configs/textrecog/satrn/satrn_academic.py
@@ -58,10 +58,12 @@
 test_pipeline = [
     dict(type='LoadImageFromFile', file_client_args=file_client_args),
     dict(type='Resize', scale=(100, 32), keep_ratio=False),
+    # add loading annotation after ``Resize`` because ground truth
+    # does not need to do resize data transform
+    dict(type='LoadOCRAnnotations', with_text=True),
     dict(
         type='PackTextRecogInputs',
-        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio',
-                   'instances'))
+        meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
 ]
 
 train_dataloader = dict(
diff --git a/mmocr/engine/__init__.py b/mmocr/engine/__init__.py
index 03ee0e85d..c2db2ce5f 100644
--- a/mmocr/engine/__init__.py
+++ b/mmocr/engine/__init__.py
@@ -1,2 +1,3 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .hooks import *  # NOQA
 from .runner import *  # NOQA
diff --git a/mmocr/engine/hooks/__init__.py b/mmocr/engine/hooks/__init__.py
new file mode 100644
index 000000000..62d8c9e56
--- /dev/null
+++ b/mmocr/engine/hooks/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .visualization_hook import VisualizationHook
+
+__all__ = ['VisualizationHook']
diff --git a/mmocr/engine/hooks/visualization_hook.py b/mmocr/engine/hooks/visualization_hook.py
new file mode 100644
index 000000000..5c183d1f8
--- /dev/null
+++ b/mmocr/engine/hooks/visualization_hook.py
@@ -0,0 +1,132 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import Sequence, Union
+
+import mmcv
+from mmengine.hooks import Hook
+from mmengine.runner import Runner
+from mmengine.visualization import Visualizer
+
+from mmocr.registry import HOOKS
+from mmocr.structures import TextDetDataSample, TextRecogDataSample
+
+
+# TODO: Files with the same name will be overwritten for multiple datasets
+@HOOKS.register_module()
+class VisualizationHook(Hook):
+    """Visualization hook for text detection and recognition. Used to
+    visualize prediction results during validation and testing.
+
+    Args:
+        enable (bool): Whether to enable this hook. Defaults to False.
+        interval (int): The interval of visualization. Defaults to 50.
+        score_thr (float): The threshold to visualize the bboxes
+            and masks.
It's only useful for text detection. Defaults to 0.3.
+        show (bool): Whether to display the drawn image. Defaults to False.
+        draw_pred (bool): Whether to draw predictions. Defaults to False.
+        draw_gt (bool): Whether to draw ground truth. Defaults to False.
+        wait_time (float): The interval of show in seconds. Defaults to 0.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details.
+            Defaults to ``dict(backend='disk')``.
+    """
+
+    def __init__(
+        self,
+        enable: bool = False,
+        interval: int = 50,
+        score_thr: float = 0.3,
+        show: bool = False,
+        draw_pred: bool = False,
+        draw_gt: bool = False,
+        wait_time: float = 0.,
+        file_client_args: dict = dict(backend='disk')
+    ) -> None:
+        self._visualizer: Visualizer = Visualizer.get_current_instance()
+        self.interval = interval
+        self.score_thr = score_thr
+        self.show = show
+        self.draw_pred = draw_pred
+        self.draw_gt = draw_gt
+        self.wait_time = wait_time
+        self.file_client_args = file_client_args.copy()
+        self.file_client = None
+        self.enable = enable
+
+    # TODO: after MultiDatasetWrapper is supported, rewrite this hook and
+    # try to merge after_val_iter with after_test_iter
+    def after_val_iter(self, runner: Runner, batch_idx: int,
+                       data_batch: Sequence[dict],
+                       outputs: Sequence[Union[TextDetDataSample,
+                                               TextRecogDataSample]]) -> None:
+        """Run after every ``self.interval`` validation iterations.
+
+        Args:
+            runner (:obj:`Runner`): The runner of the validation process.
+            batch_idx (int): The index of the current batch in the val loop.
+            data_batch (Sequence[dict]): Data from dataloader.
+            outputs (Sequence[:obj:`TextDetDataSample` or
+                :obj:`TextRecogDataSample`]): Outputs from model.
+        """
+        # TODO: data_batch does not include annotation information
+        if self.enable is False:
+            return
+
+        if self.file_client is None:
+            self.file_client = mmcv.FileClient(**self.file_client_args)
+
+        # There is no guarantee that the same batch of images
+        # is visualized for each evaluation.
+        total_curr_iter = runner.iter + batch_idx
+
+        # Visualize all outputs of the batch every ``self.interval`` iters
+        if total_curr_iter % self.interval == 0:
+            for output in outputs:
+                img_path = output.img_path
+                img_bytes = self.file_client.get(img_path)
+                img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
+                self._visualizer.add_datasample(
+                    osp.splitext(osp.basename(img_path))[0],
+                    img,
+                    data_sample=output,
+                    draw_gt=self.draw_gt,
+                    draw_pred=self.draw_pred,
+                    show=self.show,
+                    wait_time=self.wait_time,
+                    pred_score_thr=self.score_thr,
+                    step=total_curr_iter)
+
+    def after_test_iter(self, runner: Runner, batch_idx: int,
+                        data_batch: Sequence[dict],
+                        outputs: Sequence[Union[TextDetDataSample,
+                                                TextRecogDataSample]]) -> None:
+        """Run after every testing iteration.
+
+        Args:
+            runner (:obj:`Runner`): The runner of the testing process.
+            batch_idx (int): The index of the current batch in the test loop.
+            data_batch (Sequence[dict]): Data from dataloader.
+            outputs (Sequence[:obj:`TextDetDataSample` or
+                :obj:`TextRecogDataSample`]): Outputs from model.
+ """ + + if self.enable is False: + return + + if self.file_client is None: + self.file_client = mmcv.FileClient(**self.file_client_args) + + for output in outputs: + img_path = output.img_path + img_bytes = self.file_client.get(img_path) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + + self._visualizer.add_datasample( + osp.splitext(osp.basename(img_path))[0], + img, + data_sample=output, + show=self.show, + draw_gt=self.draw_gt, + draw_pred=self.draw_pred, + wait_time=self.wait_time, + pred_score_thr=self.score_thr, + step=batch_idx) diff --git a/mmocr/evaluation/metrics/f_metric.py b/mmocr/evaluation/metrics/f_metric.py index d2ac2f44d..28d52ed19 100644 --- a/mmocr/evaluation/metrics/f_metric.py +++ b/mmocr/evaluation/metrics/f_metric.py @@ -96,9 +96,11 @@ def process(self, data_batch: Sequence[Dict], data_batch (Sequence[Dict]): A batch of gts. predictions (Sequence[Dict]): A batch of outputs from the model. """ - for gt, pred in zip(data_batch, predictions): - pred_labels = pred.get('pred_instances').get(self.key).cpu() - gt_labels = gt.get('data_sample').get('gt_instances').get(self.key) + for data_samples in predictions: + pred_labels = data_samples.get('pred_instances').get( + self.key).cpu() + gt_labels = data_samples.get('gt_instances').get(self.key).cpu() + result = dict( pred_labels=pred_labels.flatten(), gt_labels=gt_labels.flatten()) diff --git a/mmocr/evaluation/metrics/hmean_iou_metric.py b/mmocr/evaluation/metrics/hmean_iou_metric.py index 42b36893e..2816c598c 100644 --- a/mmocr/evaluation/metrics/hmean_iou_metric.py +++ b/mmocr/evaluation/metrics/hmean_iou_metric.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Optional, Sequence import numpy as np import torch @@ -85,17 +85,20 @@ def process(self, data_batch: Sequence[Dict], predictions (Sequence[Dict]): A batch of outputs from the model. """ - for pred, gt in zip(predictions, data_batch): + for data_sample in predictions: - pred_instances = pred.get('pred_instances') + pred_instances = data_sample.get('pred_instances') pred_polygons = pred_instances.get('polygons') pred_scores = pred_instances.get('scores') if isinstance(pred_scores, torch.Tensor): pred_scores = pred_scores.cpu().numpy() pred_scores = np.array(pred_scores, dtype=np.float32) - gt_polys, gt_ignore_flags = self._polys_from_ann( - gt['data_sample']['instances']) + gt_instances = data_sample.get('gt_instances') + gt_polys = gt_instances.get('polygons') + gt_ignore_flags = gt_instances.get('ignored') + if isinstance(gt_ignore_flags, torch.Tensor): + gt_ignore_flags = gt_ignore_flags.cpu().numpy() gt_polys = polys2shapely(gt_polys) pred_polys = polys2shapely(pred_polygons) @@ -220,22 +223,3 @@ def _filter_preds(self, pred_polys: List[Polygon], gt_polys: List[Polygon], def _true_indexes(self, array: np.ndarray) -> np.ndarray: """Get indexes of True elements from a 1D boolean array.""" return np.where(array)[0] - - def _polys_from_ann(self, ann: Dict) -> Tuple[List, List]: - """Get GT polygons from annotations. - - Args: - ann (dict): The ground-truth annotation. - - Returns: - tuple[list[np.array], np.array]: Returns a tuple - ``(polys, gt_ignore_flags)``, where ``polys`` is the ground-truth - polygon instances and ``gt_ignore_flags`` represents whether the - corresponding instance should be ignored. 
- """ - polys = [] - gt_ignore_flags = [] - for instance in ann: - gt_ignore_flags.append(instance['ignore']) - polys.append(np.array(instance['polygon'], dtype=np.float32)) - return polys, np.array(gt_ignore_flags, dtype=bool) diff --git a/mmocr/evaluation/metrics/recog_metric.py b/mmocr/evaluation/metrics/recog_metric.py index a206a2c33..1767f4ed5 100644 --- a/mmocr/evaluation/metrics/recog_metric.py +++ b/mmocr/evaluation/metrics/recog_metric.py @@ -60,12 +60,12 @@ def process(self, data_batch: Sequence[Dict], data_batch (Sequence[Dict]): A batch of gts. predictions (Sequence[Dict]): A batch of outputs from the model. """ - for gt, pred in zip(data_batch, predictions): + for data_sample in predictions: match_num = 0 match_ignore_case_num = 0 match_ignore_case_symbol_num = 0 - pred_text = pred.get('pred_text').get('item') - gt_text = gt.get('data_sample').get('instances')[0].get('text') + pred_text = data_sample.get('pred_text').get('item') + gt_text = data_sample.get('gt_text').get('item') if 'ignore_case' in self.mode or 'ignore_case_symbol' in self.mode: pred_text_lower = pred_text.lower() gt_text_lower = gt_text.lower() @@ -158,9 +158,9 @@ def process(self, data_batch: Sequence[Dict], data_batch (Sequence[Dict]): A batch of gts. predictions (Sequence[Dict]): A batch of outputs from the model. """ - for gt, pred in zip(data_batch, predictions): - pred_text = pred.get('pred_text').get('item') - gt_text = gt.get('data_sample').get('instances')[0].get('text') + for data_sample in predictions: + pred_text = data_sample.get('pred_text').get('item') + gt_text = data_sample.get('gt_text').get('item') gt_text_lower = gt_text.lower() pred_text_lower = pred_text.lower() gt_text_lower_ignore = self.valid_symbol.sub('', gt_text_lower) @@ -258,9 +258,9 @@ def process(self, data_batch: Sequence[Dict], data_batch (Sequence[Dict]): A batch of gts. predictions (Sequence[Dict]): A batch of outputs from the model. """ - for gt, pred in zip(data_batch, predictions): - pred_text = pred.get('pred_text').get('item') - gt_text = gt.get('data_sample').get('instances')[0].get('text') + for data_sample in predictions: + pred_text = data_sample.get('pred_text').get('item') + gt_text = data_sample.get('gt_text').get('item') gt_text_lower = gt_text.lower() pred_text_lower = pred_text.lower() gt_text_lower_ignore = self.valid_symbol.sub('', gt_text_lower) diff --git a/mmocr/visualization/textdet_visualizer.py b/mmocr/visualization/textdet_visualizer.py index 604583750..29d5bc2c5 100644 --- a/mmocr/visualization/textdet_visualizer.py +++ b/mmocr/visualization/textdet_visualizer.py @@ -65,8 +65,7 @@ def __init__(self, def add_datasample(self, name: str, image: np.ndarray, - gt_sample: Optional['TextDetDataSample'] = None, - pred_sample: Optional['TextDetDataSample'] = None, + data_sample: Optional['TextDetDataSample'] = None, draw_gt: bool = True, draw_pred: bool = True, show: bool = False, @@ -88,10 +87,9 @@ def add_datasample(self, Args: name (str): The image identifier. image (np.ndarray): The image to draw. - gt_sample (:obj:`TextDetDataSample`, optional): GT - TextDetDataSample. Defaults to None. - pred_sample (:obj:`TextDetDataSample`, optional): Predicted - TextDetDataSample. Defaults to None. + data_sample (:obj:`TextDetDataSample`, optional): + TextDetDataSample which contains gt and prediction. Defaults + to None. draw_gt (bool): Whether to draw GT TextDetDataSample. Defaults to True. draw_pred (bool): Whether to draw Predicted TextDetDataSample. 
@@ -106,8 +104,9 @@ def add_datasample(self, gt_img_data = None pred_img_data = None - if draw_gt and gt_sample is not None and 'gt_instances' in gt_sample: - gt_instances = gt_sample.gt_instances + if (draw_gt and data_sample is not None + and 'gt_instances' in data_sample): + gt_instances = data_sample.gt_instances self.set_image(image) @@ -132,9 +131,9 @@ def add_datasample(self, gt_img_data = self.get_image() - if draw_pred and pred_sample is not None \ - and 'pred_instances' in pred_sample: - pred_instances = pred_sample.pred_instances + if draw_pred and data_sample is not None \ + and 'pred_instances' in data_sample: + pred_instances = data_sample.pred_instances pred_instances = pred_instances[ pred_instances.scores > pred_score_thr].cpu() @@ -166,8 +165,10 @@ def add_datasample(self, drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1) elif gt_img_data is not None: drawn_img = gt_img_data - else: + elif pred_img_data is not None: drawn_img = pred_img_data + else: + drawn_img = image if show: self.show(drawn_img, win_name=name, wait_time=wait_time) diff --git a/mmocr/visualization/textrecog_visualizer.py b/mmocr/visualization/textrecog_visualizer.py index cb17d3018..004ffda75 100644 --- a/mmocr/visualization/textrecog_visualizer.py +++ b/mmocr/visualization/textrecog_visualizer.py @@ -49,12 +49,12 @@ def __init__(self, def add_datasample(self, name: str, image: np.ndarray, - gt_sample: Optional['TextRecogDataSample'] = None, - pred_sample: Optional['TextRecogDataSample'] = None, + data_sample: Optional['TextRecogDataSample'] = None, draw_gt: bool = True, draw_pred: bool = True, show: bool = False, wait_time: int = 0, + pred_score_thr: float = None, out_file: Optional[str] = None, step=0) -> None: """Visualize datasample and save to all backends. @@ -71,10 +71,9 @@ def add_datasample(self, Args: name (str): The image title. Defaults to 'image'. image (np.ndarray): The image to draw. - gt_sample (:obj:`TextRecogDataSample`, optional): GT - TextRecogDataSample. Defaults to None. - pred_sample (:obj:`TextRecogDataSample`, optional): Predicted - TextRecogDataSample. Defaults to None. + data_sample (:obj:`TextRecogDataSample`, optional): + TextRecogDataSample which contains gt and prediction. + Defaults to None. draw_gt (bool): Whether to draw GT TextRecogDataSample. Defaults to True. draw_pred (bool): Whether to draw Predicted TextRecogDataSample. @@ -83,6 +82,8 @@ def add_datasample(self, wait_time (float): The interval of show (s). Defaults to 0. out_file (str): Path to output file. Defaults to None. step (int): Global step value to record. Defaults to 0. + pred_score_thr (float): Threshold of prediction score. It's not + used in this function. Defaults to None. 
""" gt_img_data = None pred_img_data = None @@ -93,11 +94,11 @@ def add_datasample(self, if image.ndim == 2: image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) - if draw_gt and gt_sample is not None and 'gt_text' in gt_sample: - gt_text = gt_sample.gt_text.item + if draw_gt and data_sample is not None and 'gt_text' in data_sample: + gt_text = data_sample.gt_text.item empty_img = np.full_like(image, 255) self.set_image(empty_img) - font_size = 0.5 * resize_width / len(gt_text) + font_size = 0.5 * resize_width / (len(gt_text) + 1) self.draw_texts( gt_text, np.array([resize_width / 2, resize_height / 2]), @@ -108,12 +109,12 @@ def add_datasample(self, gt_text_image = self.get_image() gt_img_data = np.concatenate((image, gt_text_image), axis=0) - if (draw_pred and pred_sample is not None - and 'pred_text' in pred_sample): - pred_text = pred_sample.pred_text.item + if (draw_pred and data_sample is not None + and 'pred_text' in data_sample): + pred_text = data_sample.pred_text.item empty_img = np.full_like(image, 255) self.set_image(empty_img) - font_size = 0.5 * resize_width / len(pred_text) + font_size = 0.5 * resize_width / (len(pred_text) + 1) self.draw_texts( pred_text, np.array([resize_width / 2, resize_height / 2]), @@ -128,8 +129,10 @@ def add_datasample(self, drawn_img = np.concatenate((gt_img_data, pred_text_image), axis=0) elif gt_img_data is not None: drawn_img = gt_img_data - else: + elif pred_img_data is not None: drawn_img = pred_img_data + else: + drawn_img = image if show: self.show(drawn_img, win_name=name, wait_time=wait_time) diff --git a/tests/test_engine/test_hooks/test_visualization_hook.py b/tests/test_engine/test_hooks/test_visualization_hook.py new file mode 100644 index 000000000..e932855c0 --- /dev/null +++ b/tests/test_engine/test_hooks/test_visualization_hook.py @@ -0,0 +1,79 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os.path as osp +import shutil +import time +from unittest import TestCase +from unittest.mock import Mock + +import torch +from mmengine.data import InstanceData + +from mmocr.engine.hooks import VisualizationHook +from mmocr.structures import TextDetDataSample +from mmocr.visualization import TextDetLocalVisualizer + + +def _rand_bboxes(num_boxes, h, w): + cx, cy, bw, bh = torch.rand(num_boxes, 4).T + + tl_x = ((cx * w) - (w * bw / 2)).clamp(0, w).unsqueeze(0) + tl_y = ((cy * h) - (h * bh / 2)).clamp(0, h).unsqueeze(0) + br_x = ((cx * w) + (w * bw / 2)).clamp(0, w).unsqueeze(0) + br_y = ((cy * h) + (h * bh / 2)).clamp(0, h).unsqueeze(0) + + bboxes = torch.cat([tl_x, tl_y, br_x, br_y], dim=0).T + return bboxes + + +class TestVisualizationHook(TestCase): + + def setUp(self) -> None: + + data_sample = TextDetDataSample() + data_sample.set_metainfo({ + 'img_path': + osp.join( + osp.dirname(__file__), + '../../data/det_toy_dataset/imgs/test/img_1.jpg') + }) + + pred_instances = InstanceData() + pred_instances.bboxes = _rand_bboxes(5, 10, 12) + pred_instances.labels = torch.randint(0, 2, (5, )) + pred_instances.scores = torch.rand((5, )) + + data_sample.pred_instances = pred_instances + self.outputs = [data_sample] * 2 + self.data_batch = None + + def test_after_val_iter(self): + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + TextDetLocalVisualizer.get_instance( + 'visualizer_val', + vis_backends=[dict(type='LocalVisBackend', img_save_dir='')], + save_dir=timestamp) + runner = Mock() + runner.iter = 1 + hook = VisualizationHook(enable=True, interval=1) + self.assertFalse(osp.exists(timestamp)) + hook.after_val_iter(runner, 1, self.data_batch, self.outputs) + self.assertTrue(osp.exists(timestamp)) + shutil.rmtree(timestamp) + + def test_after_test_iter(self): + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + TextDetLocalVisualizer.get_instance( + 'visualizer_test', + vis_backends=[dict(type='LocalVisBackend', img_save_dir='')], + save_dir=timestamp) + runner = Mock() + runner.iter = 1 + + hook = VisualizationHook(enable=False) + hook.after_test_iter(runner, 1, self.data_batch, self.outputs) + self.assertFalse(osp.exists(timestamp)) + + hook = VisualizationHook(enable=True) + hook.after_test_iter(runner, 1, self.data_batch, self.outputs) + self.assertTrue(osp.exists(timestamp)) + shutil.rmtree(timestamp) diff --git a/tests/test_evaluation/test_metrics/test_f_metric.py b/tests/test_evaluation/test_metrics/test_f_metric.py index cac19075b..7d91673e9 100644 --- a/tests/test_evaluation/test_metrics/test_f_metric.py +++ b/tests/test_evaluation/test_metrics/test_f_metric.py @@ -28,102 +28,79 @@ def test_init(self): def test_macro_f1(self): mode = 'macro' - gts_cases = [ + preds_cases = [ [ - dict( - data_sample=KIEDataSample( - gt_instances=InstanceData( - labels=torch.LongTensor([0, 1, 4])))) + KIEDataSample( + pred_instances=InstanceData( + labels=torch.LongTensor([0, 1, 2])), + gt_instances=InstanceData( + labels=torch.LongTensor([0, 1, 4]))) ], [ - dict( - data_sample=KIEDataSample( - gt_instances=InstanceData( - labels=torch.LongTensor([0, 1])))), - dict( - data_sample=KIEDataSample( - gt_instances=InstanceData( - labels=torch.LongTensor([4])))) - ], + KIEDataSample( + gt_instances=InstanceData(labels=torch.LongTensor([0, 1])), + pred_instances=InstanceData( + labels=torch.LongTensor([0, 1]))), + KIEDataSample( + gt_instances=InstanceData(labels=torch.LongTensor([4])), + pred_instances=InstanceData(labels=torch.LongTensor([2]))) + ] ] - preds_cases = [[ - 
KIEDataSample( - pred_instances=InstanceData( - labels=torch.LongTensor([0, 1, 2]))) - ], - [ - KIEDataSample( - pred_instances=InstanceData( - labels=torch.LongTensor([0, 1]))), - KIEDataSample( - pred_instances=InstanceData( - labels=torch.LongTensor([2]))) - ]] # num_classes < the maximum label index metric = F1Metric(num_classes=3, ignored_classes=[1]) - metric.process(gts_cases[0], preds_cases[0]) + metric.process(None, preds_cases[0]) with self.assertRaises(AssertionError): metric.evaluate(size=1) - for gts, preds in zip(gts_cases, preds_cases): + for preds in preds_cases: metric = F1Metric(num_classes=5, mode=mode) - metric.process(gts, preds) - result = metric.evaluate(size=len(gts)) + metric.process(None, preds) + result = metric.evaluate(size=len(preds)) self.assertAlmostEqual(result['kie/macro_f1'], 0.4) # Test ignored_classes metric = F1Metric(num_classes=5, ignored_classes=[1], mode=mode) - metric.process(gts, preds) - result = metric.evaluate(size=len(gts)) + metric.process(None, preds) + result = metric.evaluate(size=len(preds)) self.assertAlmostEqual(result['kie/macro_f1'], 0.25) # Test cared_classes metric = F1Metric( num_classes=5, cared_classes=[0, 2, 3, 4], mode=mode) - metric.process(gts, preds) - result = metric.evaluate(size=len(gts)) + metric.process(None, preds) + result = metric.evaluate(size=len(preds)) self.assertAlmostEqual(result['kie/macro_f1'], 0.25) def test_micro_f1(self): mode = 'micro' - gts_cases = [[ - dict( - data_sample=KIEDataSample( - gt_instances=InstanceData( - labels=torch.LongTensor([0, 1, 0, 1, 2])))) - ], - [ - dict( - data_sample=KIEDataSample( - gt_instances=InstanceData( - labels=torch.LongTensor([0, 1, 2])))), - dict( - data_sample=KIEDataSample( - gt_instances=InstanceData( - labels=torch.LongTensor([0, 1])))) - ]] preds_cases = [[ KIEDataSample( + gt_instances=InstanceData( + labels=torch.LongTensor([0, 1, 0, 1, 2])), pred_instances=InstanceData( labels=torch.LongTensor([0, 1, 2, 2, 0]))) ], [ KIEDataSample( + gt_instances=InstanceData( + labels=torch.LongTensor([0, 1, 2])), pred_instances=InstanceData( labels=torch.LongTensor([0, 1, 0]))), KIEDataSample( + gt_instances=InstanceData( + labels=torch.LongTensor([0, 1])), pred_instances=InstanceData( labels=torch.LongTensor([2, 2]))) ]] # num_classes < the maximum label index metric = F1Metric(num_classes=1, ignored_classes=[0], mode=mode) - metric.process(gts_cases[0], preds_cases[0]) + metric.process(None, preds_cases[0]) with self.assertRaises(AssertionError): metric.evaluate(size=1) - for gts, preds in zip(gts_cases, preds_cases): + for preds in preds_cases: # class 0: tp: 1, fp: 1, fn: 1 # class 1: tp: 1, fp: 1, fn: 0 # class 2: tp: 0, fp: 1, fn: 2 @@ -131,13 +108,13 @@ def test_micro_f1(self): # f1: 0.4 metric = F1Metric(num_classes=3, mode=mode) - metric.process(gts, preds) - result = metric.evaluate(size=len(gts)) + metric.process(None, preds) + result = metric.evaluate(size=len(preds)) self.assertAlmostEqual(result['kie/micro_f1'], 0.4, delta=0.01) metric = F1Metric(num_classes=5, mode=mode) - metric.process(gts, preds) - result = metric.evaluate(size=len(gts)) + metric.process(None, preds) + result = metric.evaluate(size=len(preds)) self.assertAlmostEqual(result['kie/micro_f1'], 0.4, delta=0.01) # class 0: tp: 1, fp: 1, fn: 1 @@ -146,26 +123,22 @@ def test_micro_f1(self): # f1: 0.285 metric = F1Metric(num_classes=5, ignored_classes=[1], mode=mode) - metric.process(gts, preds) - result = metric.evaluate(size=len(gts)) + metric.process(None, preds) + result = 
metric.evaluate(size=len(preds)) self.assertAlmostEqual(result['kie/micro_f1'], 0.285, delta=0.001) metric = F1Metric( num_classes=5, cared_classes=[0, 2, 3, 4], mode=mode) - metric.process(gts, preds) - result = metric.evaluate(size=len(gts)) + metric.process(None, preds) + result = metric.evaluate(size=len(preds)) self.assertAlmostEqual(result['kie/micro_f1'], 0.285, delta=0.001) def test_arguments(self): mode = ['micro', 'macro'] - gts = [ - dict( - data_sample=KIEDataSample( - gt_instances=InstanceData( - test_labels=torch.LongTensor([0, 1, 0, 1, 2])))) - ] preds = [ KIEDataSample( + gt_instances=InstanceData( + test_labels=torch.LongTensor([0, 1, 0, 1, 2])), pred_instances=InstanceData( test_labels=torch.LongTensor([0, 1, 2, 2, 0]))) ] @@ -178,7 +151,7 @@ def test_arguments(self): # macro_f1: metric = F1Metric(num_classes=3, mode=mode, key='test_labels') - metric.process(gts, preds) + metric.process(None, preds) result = metric.evaluate(size=1) self.assertAlmostEqual(result['kie/micro_f1'], 0.4, delta=0.01) self.assertAlmostEqual(result['kie/macro_f1'], 0.39, delta=0.01) diff --git a/tests/test_evaluation/test_metrics/test_hmean_iou_metric.py b/tests/test_evaluation/test_metrics/test_hmean_iou_metric.py index 2ed232943..af9f7dd82 100644 --- a/tests/test_evaluation/test_metrics/test_hmean_iou_metric.py +++ b/tests/test_evaluation/test_metrics/test_hmean_iou_metric.py @@ -26,60 +26,47 @@ def setUp(self): pred_d is ignored in the recall computation since it overlaps gt_d_ignored and the precision > ignore_precision_thr. """ - # prepare gt - self.gt = [{ - 'data_sample': { - 'instances': [{ - 'polygon': [0, 0, 1, 0, 1, 1, 0, 1], - 'ignore': False - }, { - 'polygon': [2, 0, 3, 0, 3, 1, 2, 1], - 'ignore': False - }, { - 'polygon': [10, 0, 11, 0, 11, 1, 10, 1], - 'ignore': False - }, { - 'polygon': [1, 0, 2, 0, 2, 1, 1, 1], - 'ignore': True - }] - } - }, { - 'data_sample': { - 'instances': [{ - 'polygon': [0, 0, 1, 0, 1, 1, 0, 1], - 'ignore': False - }], - } - }] - - # prepare pred - pred_data_sample = TextDetDataSample() - pred_data_sample.pred_instances = InstanceData() - pred_data_sample.pred_instances.polygons = [ + data_sample = TextDetDataSample() + gt_instances = InstanceData() + gt_instances.polygons = [ + torch.FloatTensor([0, 0, 1, 0, 1, 1, 0, 1]), + torch.FloatTensor([2, 0, 3, 0, 3, 1, 2, 1]), + torch.FloatTensor([10, 0, 11, 0, 11, 1, 10, 1]), + torch.FloatTensor([1, 0, 2, 0, 2, 1, 1, 1]), + ] + gt_instances.ignored = np.bool_([False, False, False, True]) + pred_instances = InstanceData() + pred_instances.polygons = [ torch.FloatTensor([0, 0, 1, 0, 1, 1, 0, 1]), torch.FloatTensor([2, 0.1, 3, 0.1, 3, 1.1, 2, 1.1]), torch.FloatTensor([1, 1, 2, 1, 2, 2, 1, 2]), torch.FloatTensor([1, -0.5, 2, -0.5, 2, 0.5, 1, 0.5]), ] - pred_data_sample.pred_instances.scores = torch.FloatTensor( - [1, 1, 1, 0.001]) - predictions = [pred_data_sample.to_dict()] - - pred_data_sample = TextDetDataSample() - pred_data_sample.pred_instances = InstanceData() - pred_data_sample.pred_instances.polygons = [ + pred_instances.scores = torch.FloatTensor([1, 1, 1, 0.001]) + data_sample.gt_instances = gt_instances + data_sample.pred_instances = pred_instances + predictions = [data_sample.to_dict()] + + data_sample = TextDetDataSample() + gt_instances = InstanceData() + gt_instances.polygons = [torch.FloatTensor([0, 0, 1, 0, 1, 1, 0, 1])] + gt_instances.ignored = np.bool_([False]) + pred_instances = InstanceData() + pred_instances.polygons = [ torch.FloatTensor([0, 0, 1, 0, 1, 1, 0, 1]), torch.FloatTensor([0, 0, 1, 
0, 1, 1, 0, 1]) ] - pred_data_sample.pred_instances.scores = torch.FloatTensor([1, 0.95]) - predictions.append(pred_data_sample.to_dict()) + pred_instances.scores = torch.FloatTensor([1, 0.95]) + data_sample.gt_instances = gt_instances + data_sample.pred_instances = pred_instances + predictions.append(data_sample.to_dict()) self.predictions = predictions def test_hmean_iou(self): metric = HmeanIOUMetric(prefix='mmocr') - metric.process(self.gt, self.predictions) + metric.process(None, self.predictions) eval_results = metric.evaluate(size=2) precision = 3 / 5 diff --git a/tests/test_evaluation/test_metrics/test_recog_metric.py b/tests/test_evaluation/test_metrics/test_recog_metric.py index a902c68d8..533cdcfae 100644 --- a/tests/test_evaluation/test_metrics/test_recog_metric.py +++ b/tests/test_evaluation/test_metrics/test_recog_metric.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy import unittest from mmengine.data import LabelData @@ -11,70 +10,56 @@ class TestWordMetric(unittest.TestCase): def setUp(self): - # prepare gt hello HELLO $HELLO$ - gt1 = { - 'data_sample': { - 'height': 32, - 'width': 100, - 'instances': [{ - 'text': 'hello' - }] - } - } - gt2 = { - 'data_sample': { - 'height': 32, - 'width': 100, - 'instances': [{ - 'text': 'HELLO' - }] - } - } - gt3 = { - 'data_sample': { - 'height': 32, - 'width': 100, - 'instances': [{ - 'text': '$HELLO$' - }] - } - } - self.gt = [gt1, gt2, gt3] - # prepare pred - pred_data_sample = TextRecogDataSample() + + self.pred = [] + data_sample = TextRecogDataSample() pred_text = LabelData() pred_text.item = 'hello' - pred_data_sample.pred_text = pred_text - - self.pred = [ - pred_data_sample, - copy.deepcopy(pred_data_sample), - copy.deepcopy(pred_data_sample), - ] + data_sample.pred_text = pred_text + gt_text = LabelData() + gt_text.item = 'hello' + data_sample.gt_text = gt_text + self.pred.append(data_sample) + data_sample = TextRecogDataSample() + pred_text = LabelData() + pred_text.item = 'hello' + data_sample.pred_text = pred_text + gt_text = LabelData() + gt_text.item = 'HELLO' + data_sample.gt_text = gt_text + self.pred.append(data_sample) + data_sample = TextRecogDataSample() + pred_text = LabelData() + pred_text.item = 'hello' + data_sample.pred_text = pred_text + gt_text = LabelData() + gt_text.item = '$HELLO$' + data_sample.gt_text = gt_text + self.pred.append(data_sample) def test_word_acc_metric(self): metric = WordMetric(mode='exact') - metric.process(self.gt, self.pred) + metric.process(None, self.pred) eval_res = metric.evaluate(size=3) self.assertAlmostEqual(eval_res['recog/word_acc'], 1. / 3, 4) def test_word_acc_ignore_case_metric(self): metric = WordMetric(mode='ignore_case') - metric.process(self.gt, self.pred) + metric.process(None, self.pred) eval_res = metric.evaluate(size=3) self.assertAlmostEqual(eval_res['recog/word_acc_ignore_case'], 2. / 3, 4) def test_word_acc_ignore_case_symbol_metric(self): metric = WordMetric(mode='ignore_case_symbol') - metric.process(self.gt, self.pred) + metric.process(None, self.pred) eval_res = metric.evaluate(size=3) self.assertEqual(eval_res['recog/word_acc_ignore_case_symbol'], 1.0) def test_all_metric(self): metric = WordMetric( mode=['exact', 'ignore_case', 'ignore_case_symbol']) - metric.process(self.gt, self.pred) + metric.process(None, self.pred) eval_res = metric.evaluate(size=3) self.assertAlmostEqual(eval_res['recog/word_acc'], 1. / 3, 4) self.assertAlmostEqual(eval_res['recog/word_acc_ignore_case'], 2. 
/ 3, @@ -85,42 +70,27 @@ def test_all_metric(self): class TestCharMetric(unittest.TestCase): def setUp(self): - # prepare gt - gt1 = { - 'data_sample': { - 'height': 32, - 'width': 100, - 'instances': [{ - 'text': 'hello' - }] - } - } - gt2 = { - 'data_sample': { - 'height': 32, - 'width': 100, - 'instances': [{ - 'text': 'HELLO' - }] - } - } - self.gt = [gt1, gt2] - # prepare pred - pred_data_sample1 = TextRecogDataSample() + self.pred = [] + data_sample = TextRecogDataSample() pred_text = LabelData() pred_text.item = 'helL' - pred_data_sample1.pred_text = pred_text - - pred_data_sample2 = TextRecogDataSample() + data_sample.pred_text = pred_text + gt_text = LabelData() + gt_text.item = 'hello' + data_sample.gt_text = gt_text + self.pred.append(data_sample) + data_sample = TextRecogDataSample() pred_text = LabelData() pred_text.item = 'HEL' - pred_data_sample2.pred_text = pred_text - - self.pred = [pred_data_sample1, pred_data_sample2] + data_sample.pred_text = pred_text + gt_text = LabelData() + gt_text.item = 'HELLO' + data_sample.gt_text = gt_text + self.pred.append(data_sample) def test_char_recall_precision_metric(self): metric = CharMetric() - metric.process(self.gt, self.pred) + metric.process(None, self.pred) eval_res = metric.evaluate(size=2) self.assertEqual(eval_res['recog/char_recall'], 0.7) self.assertEqual(eval_res['recog/char_precision'], 1) @@ -129,41 +99,26 @@ def test_char_recall_precision_metric(self): class TestOneMinusNED(unittest.TestCase): def setUp(self): - # prepare gt - gt1 = { - 'data_sample': { - 'height': 32, - 'width': 100, - 'instances': [{ - 'text': 'hello' - }] - } - } - gt2 = { - 'data_sample': { - 'height': 32, - 'width': 100, - 'instances': [{ - 'text': 'HELLO' - }] - } - } - self.gt = [gt1, gt2] - # prepare pred - pred_data_sample1 = TextRecogDataSample() + self.pred = [] + data_sample = TextRecogDataSample() pred_text = LabelData() pred_text.item = 'pred_helL' - pred_data_sample1.pred_text = pred_text - - pred_data_sample2 = TextRecogDataSample() + data_sample.pred_text = pred_text + gt_text = LabelData() + gt_text.item = 'hello' + data_sample.gt_text = gt_text + self.pred.append(data_sample) + data_sample = TextRecogDataSample() pred_text = LabelData() pred_text.item = 'HEL' - pred_data_sample2.pred_text = pred_text - - self.pred = [pred_data_sample1, pred_data_sample2] + data_sample.pred_text = pred_text + gt_text = LabelData() + gt_text.item = 'HELLO' + data_sample.gt_text = gt_text + self.pred.append(data_sample) def test_one_minus_ned_metric(self): metric = OneMinusNEDMetric() - metric.process(self.gt, self.pred) + metric.process(None, self.pred) eval_res = metric.evaluate(size=2) self.assertEqual(eval_res['recog/1-N.E.D'], 0.4875) diff --git a/tests/test_visualization/test_textdet_visualizer.py b/tests/test_visualization/test_textdet_visualizer.py index f21b59ae1..b311a56ea 100644 --- a/tests/test_visualization/test_textdet_visualizer.py +++ b/tests/test_visualization/test_textdet_visualizer.py @@ -20,25 +20,22 @@ def setUp(self): self.image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8') # gt_instances - gt_det_data_sample = TextDetDataSample() + data_sample = TextDetDataSample() gt_instances_data = dict( bboxes=self._rand_bboxes(5, h, w), polygons=self._rand_polys(5, h, w), labels=torch.zeros(5, )) gt_instances = InstanceData(**gt_instances_data) - gt_det_data_sample.gt_instances = gt_instances - self.gt_det_data_sample = gt_det_data_sample + data_sample.gt_instances = gt_instances - # pred_instances - pred_det_data_sample = 
TextDetDataSample() pred_instances_data = dict( bboxes=self._rand_bboxes(5, h, w), polygons=self._rand_polys(5, h, w), labels=torch.zeros(5, ), scores=torch.rand((5, ))) pred_instances = InstanceData(**pred_instances_data) - pred_det_data_sample.pred_instances = pred_instances - self.pred_det_data_sample = pred_det_data_sample + data_sample.pred_instances = pred_instances + self.data_sample = data_sample def test_text_det_local_visualizer(self): for with_poly in [True, False]: @@ -68,32 +65,30 @@ def _rand_polys(self, num_bboxes, h, w): def _test_add_datasample(self, vis_cfg): image = self.image h, w, c = image.shape - gt_det_data_sample = self.gt_det_data_sample - pred_det_data_sample = self.pred_det_data_sample det_local_visualizer = TextDetLocalVisualizer(**vis_cfg) - det_local_visualizer.add_datasample('image', image, gt_det_data_sample) + det_local_visualizer.add_datasample('image', image, self.data_sample) with tempfile.TemporaryDirectory() as tmp_dir: # test out out_file = osp.join(tmp_dir, 'out_file.jpg') det_local_visualizer.add_datasample( - 'image', image, gt_det_data_sample, out_file=out_file) + 'image', + image, + self.data_sample, + out_file=out_file, + draw_gt=False, + draw_pred=False) self._assert_image_and_shape(out_file, (h, w, c)) det_local_visualizer.add_datasample( - 'image', - image, - gt_det_data_sample, - pred_det_data_sample, - out_file=out_file) + 'image', image, self.data_sample, out_file=out_file) self._assert_image_and_shape(out_file, (h, w * 2, c)) det_local_visualizer.add_datasample( 'image', image, - gt_det_data_sample, - pred_det_data_sample, + self.data_sample, draw_gt=False, out_file=out_file) self._assert_image_and_shape(out_file, (h, w, c)) @@ -101,8 +96,7 @@ def _test_add_datasample(self, vis_cfg): det_local_visualizer.add_datasample( 'image', image, - gt_det_data_sample, - pred_det_data_sample, + self.data_sample, draw_pred=False, out_file=out_file) self._assert_image_and_shape(out_file, (h, w, c)) diff --git a/tests/test_visualization/test_textrecog_visualizer.py b/tests/test_visualization/test_textrecog_visualizer.py index 01e35bb6c..81080507f 100644 --- a/tests/test_visualization/test_textrecog_visualizer.py +++ b/tests/test_visualization/test_textrecog_visualizer.py @@ -18,21 +18,19 @@ def test_add_datasample(self): image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8') # test gt_text - gt_recog_data_sample = TextRecogDataSample() + data_sample = TextRecogDataSample() img_meta = dict(img_shape=(12, 10, 3)) gt_text = LabelData(metainfo=img_meta) gt_text.item = 'mmocr' - gt_recog_data_sample.gt_text = gt_text + data_sample.gt_text = gt_text recog_local_visualizer = TextRecogLocalVisualizer() - recog_local_visualizer.add_datasample('image', image, - gt_recog_data_sample) + recog_local_visualizer.add_datasample('image', image, data_sample) # test gt_text and pred_text - pred_recog_data_sample = TextRecogDataSample() pred_text = LabelData(metainfo=img_meta) pred_text.item = 'MMOCR' - pred_recog_data_sample.pred_text = pred_text + data_sample.pred_text = pred_text with tempfile.TemporaryDirectory() as tmp_dir: # test out @@ -40,26 +38,27 @@ def test_add_datasample(self): # draw_gt = True + gt_sample recog_local_visualizer.add_datasample( - 'image', image, gt_recog_data_sample, out_file=out_file) + 'image', + image, + data_sample, + out_file=out_file, + draw_gt=True, + draw_pred=False) self._assert_image_and_shape(out_file, (h * 2, w, 3)) # draw_gt = True + gt_sample + pred_sample recog_local_visualizer.add_datasample( 'image', image, - 
gt_recog_data_sample, - pred_recog_data_sample, - out_file=out_file) + data_sample, + out_file=out_file, + draw_gt=True, + draw_pred=True) self._assert_image_and_shape(out_file, (h * 3, w, 3)) # draw_gt = False + gt_sample + pred_sample recog_local_visualizer.add_datasample( - 'image', - image, - gt_recog_data_sample, - pred_recog_data_sample, - draw_gt=False, - out_file=out_file) + 'image', image, data_sample, draw_gt=False, out_file=out_file) self._assert_image_and_shape(out_file, (h * 2, w, 3)) def _assert_image_and_shape(self, out_file, out_shape):
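
The central data-flow change in this patch is that every metric and visualizer
now reads both the ground truth and the prediction from a single data sample:
the test pipelines load annotations via ``LoadOCRAnnotations``, the model packs
them next to its predictions, and ``process()`` therefore no longer needs
``data_batch``. A minimal sketch of the new convention for a recognition
metric, mirroring the updated tests above (the literal strings and the ``size``
value are illustrative only):

    from mmengine.data import LabelData

    from mmocr.evaluation.metrics import WordMetric
    from mmocr.structures import TextRecogDataSample

    # Pack the ground truth and the prediction into one data sample, as the
    # updated pipelines now do via ``LoadOCRAnnotations``.
    data_sample = TextRecogDataSample()
    gt_text = LabelData()
    gt_text.item = 'HELLO'
    data_sample.gt_text = gt_text
    pred_text = LabelData()
    pred_text.item = 'hello'
    data_sample.pred_text = pred_text

    # ``data_batch`` is no longer read by the metric, so the updated tests
    # simply pass ``None`` for it.
    metric = WordMetric(mode='ignore_case')
    metric.process(None, [data_sample])
    print(metric.evaluate(size=1))  # {'recog/word_acc_ignore_case': 1.0}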