diff --git a/mmeval/metrics/char_recall_precision.py b/mmeval/metrics/char_recall_precision.py index 25c6cf4e..e7cf780b 100644 --- a/mmeval/metrics/char_recall_precision.py +++ b/mmeval/metrics/char_recall_precision.py @@ -14,14 +14,14 @@ class CharRecallPrecision(BaseMetric): - unchanged: Do not change prediction texts and labels. - upper: Convert prediction texts and labels into uppercase - characters. + characters. - lower: Convert prediction texts and labels into lowercase - characters. + characters. Usually, it only works for English characters. Defaults to 'unchanged'. invalid_symbol (str): A regular expression to filter out invalid or - not cared characters. Defaults to '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'. + not cared characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]'. **kwargs: Keyword parameters passed to :class:`BaseMetric`. Examples: @@ -36,21 +36,21 @@ class CharRecallPrecision(BaseMetric): def __init__(self, letter_case: str = 'unchanged', - invalid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]', + invalid_symbol: str = '[^A-Za-z0-9\u4e00-\u9fa5]', **kwargs): super().__init__(**kwargs) assert letter_case in ['unchanged', 'upper', 'lower'] self.letter_case = letter_case self.invalid_symbol = re.compile(invalid_symbol) - def add(self, predictions: Sequence[str], labels: Sequence[str]) -> None: # type: ignore # yapf: disable # noqa: E501 + def add(self, predictions: Sequence[str], groundtruths: Sequence[str]) -> None: # type: ignore # yapf: disable # noqa: E501 """Process one batch of data and predictions. Args: predictions (list[str]): The prediction texts. - labels (list[str]): The ground truth texts. + groundtruths (list[str]): The ground truth texts. """ - for pred, label in zip(predictions, labels): + for pred, label in zip(predictions, groundtruths): if self.letter_case in ['upper', 'lower']: pred = getattr(pred, self.letter_case)() label = getattr(label, self.letter_case)() @@ -79,10 +79,10 @@ def compute_metric(self, results: Sequence[Tuple[int, int, int]]) -> Dict: true_positive_sum += true_positive char_recall = true_positive_sum / max(gt_sum, 1.0) char_precision = true_positive_sum / max(pred_sum, 1.0) - eval_res = {} - eval_res['recall'] = char_recall - eval_res['precision'] = char_precision - return eval_res + metric_results = {} + metric_results['recall'] = char_recall + metric_results['precision'] = char_precision + return metric_results def _cal_true_positive_char(self, pred: str, gt: str) -> int: """Calculate correct character number in prediction.