Skip to content

Commit

Permalink
add rouge metric
Browse files Browse the repository at this point in the history
  • Loading branch information
go-with-me000 committed Jan 3, 2023
1 parent 04d04ce commit 961270f
Show file tree
Hide file tree
Showing 8 changed files with 633 additions and 23 deletions.
1 change: 1 addition & 0 deletions docs/en/api/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ Metrics
MAE
MSE
BLEU
ROUGE
1 change: 1 addition & 0 deletions docs/zh_cn/api/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ Metrics
MAE
MSE
BLEU
ROUGE
3 changes: 2 additions & 1 deletion mmeval/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .pck_accuracy import JhmdbPCKAccuracy, MpiiPCKAccuracy, PCKAccuracy
from .proposal_recall import ProposalRecall
from .psnr import PSNR
from .rouge import ROUGE
from .single_label import SingleLabelMetric
from .snr import SNR
from .ssim import SSIM
Expand All @@ -25,5 +26,5 @@
'F1Metric', 'HmeanIoU', 'SingleLabelMetric', 'COCODetectionMetric',
'PCKAccuracy', 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'ProposalRecall',
'PSNR', 'MAE', 'MSE', 'SSIM', 'SNR', 'MultiLabelMetric',
'AveragePrecision', 'AVAMeanAP', 'BLEU'
'AveragePrecision', 'AVAMeanAP', 'BLEU', 'ROUGE'
]
81 changes: 61 additions & 20 deletions mmeval/metrics/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# <https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/bleu.py>`_.
import numpy as np
from collections import Counter
from typing import List, Optional, Sequence, Tuple
from typing import Callable, List, Optional, Sequence, Tuple, Union

from mmeval import BaseMetric

Expand All @@ -27,16 +27,14 @@ def get_n_gram(token: Sequence[str], n_gram: int) -> Counter:
return counter


def tokenizer_fn(sentence: str) -> List[str]:
    """Tokenize a sentence by splitting on runs of whitespace.

    Args:
        sentence (str): The sentence to segment.

    Returns:
        List[str]: The whitespace-separated tokens of ``sentence``.
    """
    tokens: List[str] = sentence.split()
    return tokens
def _get_tokenizer(lang):
"""A function to choose tokenizer."""
if lang == 'en':
return str.split
elif lang in ('cn', 'zh'):
return list
else:
return None


def _get_brevity_penalty(pred_len: np.array,
Expand Down Expand Up @@ -67,9 +65,11 @@ class BLEU(BaseMetric):
n_gram (int): The maximum number of words contained in a phrase
when calculating word fragments. Defaults to 4.
smooth (bool): Whether or not to apply to smooth. Defaults to False.
ngram_weights(Sequence[float], optional): Weights used
ngram_weights (Sequence[float], optional): Weights used
for unigrams, bigrams, etc. to calculate BLEU score.
If not provided, uniform weights are used. Defaults to None.
tokenizer (Union[Callable, str, None]): A user's own tokenizer function.
Defaults to None.
**kwargs: Keyword parameters passed to :class:`BaseMetric`.
Examples:
Expand All @@ -93,6 +93,7 @@ def __init__(self,
n_gram: int = 4,
smooth: bool = False,
ngram_weights: Optional[Sequence[float]] = None,
tokenizer_fn: Union[Callable, str, None] = None,
**kwargs) -> None:
super().__init__(**kwargs)
self.n_gram = n_gram
Expand All @@ -105,21 +106,61 @@ def __init__(self,
ngram_weights = [1.0 / n_gram] * n_gram
self.ngram_weights = ngram_weights

def add(self, predictions: Sequence[str], references: Sequence[Sequence[str]]) -> None: # type: ignore # yapf: disable # noqa: E501
# Select tokenizer according to the entered value.
self.tokenizer_fn = None
if callable(tokenizer_fn):
self.tokenizer_fn = tokenizer_fn
elif isinstance(tokenizer_fn, str):
self.tokenizer_fn = _get_tokenizer(tokenizer_fn)
if self.tokenizer_fn is None:
raise ValueError(
"Right now, `tokenizer_fn` only supports pre-defined 'en' or 'cn'." # noqa: E501
)
else:
assert tokenizer_fn is None, \
f'`tokenizer_fn` supports Callable, str or None, but not `{type(tokenizer_fn)}`' # noqa: E501

def add(self, predictions: Union[str, Sequence[str]], references: Union[str, Sequence[str], Sequence[Sequence[str]]]) -> None: # type: ignore # yapf: disable # noqa: E501
"""Add the intermediate results to ``self._results``.
Args:
predictions (Sequence[str]): An iterable of machine
translated corpus.
references (Sequence[Sequence[str]]): An iterable of
iterables of reference corpus.
predictions (Union[Sequence[str] | str]): An iterable of
predicted sentences or a single predicted sentence.
references (Union[Sequence[str] | str]): An iterable of
referenced sentences or an iterable of target sentences
or a single target sentence.
"""

if isinstance(references, list) and all(
isinstance(reference, str) for reference in references):
if isinstance(predictions, str):
references = [references]
else:
if len(predictions) == 1:
references = [references]
else:
references = [[reference] for reference in references]

if isinstance(predictions, str):
predictions = [predictions]

if isinstance(references, str):
references = [[references]]
assert len(predictions) == len(
references
), 'The number of predictions and references must be equal'

if self.tokenizer_fn is None:
lang = 'en'
for _char in predictions[0]:
if '\u4e00' <= _char <= '\u9fa5':
lang = 'cn'
break
self.tokenizer_fn = _get_tokenizer(lang)
references_token: Sequence[Sequence[Sequence[str]]] = [
[tokenizer_fn(line) for line in r] for r in references
[self.tokenizer_fn(line) for line in r] for r in references
]
predictions_token: Sequence[Sequence[str]] = [
tokenizer_fn(line) for line in predictions
self.tokenizer_fn(line) for line in predictions
]
for prediction, references in zip(predictions_token, references_token):
pred_len = len(prediction)
Expand Down
Loading

0 comments on commit 961270f

Please sign in to comment.