[Feature] Add ROUGE to mmeval #72

Merged · 4 commits · Jan 30, 2023
4 changes: 2 additions & 2 deletions .pre-commit-config-zh-cn.yaml

@@ -4,8 +4,8 @@ repos:
     rev: 5.0.4
     hooks:
       - id: flake8
-  - repo: https://gitee.com/openmmlab/mirrors-isort
-    rev: 5.10.1
+  - repo: https://gitee.com/zhouzaida/mirrors-isort
+    rev: 5.12.1
     hooks:
       - id: isort
   - repo: https://gitee.com/openmmlab/mirrors-yapf
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml

@@ -4,8 +4,8 @@ repos:
     rev: 5.0.4
     hooks:
       - id: flake8
-  - repo: https://github.com/PyCQA/isort
-    rev: 5.10.1
+  - repo: https://github.com/zhouzaida/isort
+    rev: 5.12.1
     hooks:
       - id: isort
   - repo: https://github.com/pre-commit/mirrors-yapf
1 change: 1 addition & 0 deletions docs/en/api/metrics.rst

@@ -47,3 +47,4 @@ Metrics
    MattingMSE
    ConnectivityError
    DOTAMeanAP
+   ROUGE
1 change: 1 addition & 0 deletions docs/zh_cn/api/metrics.rst

@@ -47,3 +47,4 @@ Metrics
    MattingMSE
    ConnectivityError
    DOTAMeanAP
+   ROUGE
3 changes: 2 additions & 1 deletion mmeval/metrics/__init__.py

@@ -19,6 +19,7 @@
 from .pck_accuracy import JhmdbPCKAccuracy, MpiiPCKAccuracy, PCKAccuracy
 from .proposal_recall import ProposalRecall
 from .psnr import PSNR
+from .rouge import ROUGE
 from .sad import SAD
 from .single_label import SingleLabelMetric
 from .snr import SNR

@@ -31,5 +32,5 @@
     'PCKAccuracy', 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'ProposalRecall',
     'PSNR', 'MAE', 'MSE', 'SSIM', 'SNR', 'MultiLabelMetric',
     'AveragePrecision', 'AVAMeanAP', 'BLEU', 'DOTAMeanAP', 'SAD',
-    'GradientError', 'MattingMSE', 'ConnectivityError'
+    'GradientError', 'MattingMSE', 'ConnectivityError', 'ROUGE'
 ]
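With the export in place, ROUGE can be imported from the top-level package like any other metric. Below is a minimal usage sketch, assuming ROUGE follows the same BaseMetric interface as BLEU (calling the metric instance runs add() and compute() together); the exact keys of the result dict are an assumption, not shown in this diff:

from mmeval import ROUGE

predictions = ['the cat sat on the mat']
references = [['a cat sat on the mat']]

rouge = ROUGE()  # assumed default configuration
results = rouge(predictions, references)  # add() then compute(), via BaseMetric.__call__
print(results)  # expected: a dict of ROUGE scores, e.g. {'rouge1_fmeasure': ...} (keys assumed)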
68 changes: 28 additions & 40 deletions mmeval/metrics/bleu.py

@@ -3,40 +3,10 @@
 # <https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/bleu.py>`_.
 import numpy as np
 from collections import Counter
-from typing import List, Optional, Sequence, Tuple
+from typing import Callable, List, Optional, Sequence, Tuple, Union
 
 from mmeval import BaseMetric
-
-
-def get_n_gram(token: Sequence[str], n_gram: int) -> Counter:
-    """A function get n_gram of sentences.
-
-    Args:
-        token (Sequence[str]): A series of tokens about sentences.
-        n_gram (int): The maximum number of words contained in a phrase
-            when calculating word fragments. Defaults to 4.
-
-    Returns:
-        Counter: The n_gram contained in sentences with Counter format.
-    """
-    counter: Counter = Counter()
-    for i in range(1, n_gram + 1):
-        for j in range(len(token) - i + 1):
-            key = tuple(token[j:(i + j)])
-            counter[key] += 1
-    return counter
-
-
-def tokenizer_fn(sentence: str) -> List[str]:
-    """This function is used to segment a sentence.
-
-    Args:
-        sentence (str): A sentence.
-
-    Returns:
-        List[str]: A list of tokens after word segmentation.
-    """
-    return sentence.split()
+from mmeval.metrics.utils import get_n_gram, get_tokenizer, infer_language
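Both deleted helpers move to mmeval.metrics.utils rather than disappearing. For reference, a standalone sketch equivalent to the deleted get_n_gram body, with a worked call showing the Counter it produces:

from collections import Counter
from typing import Sequence


def get_n_gram(token: Sequence[str], n_gram: int) -> Counter:
    # Count every phrase of length 1..n_gram, keyed by its token tuple.
    counter: Counter = Counter()
    for i in range(1, n_gram + 1):
        for j in range(len(token) - i + 1):
            counter[tuple(token[j:(i + j)])] += 1
    return counter


print(get_n_gram(['the', 'cat', 'sat'], 2))
# Counter({('the',): 1, ('cat',): 1, ('sat',): 1,
#          ('the', 'cat'): 1, ('cat', 'sat'): 1})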


 def _get_brevity_penalty(pred_len: np.array,

@@ -67,9 +37,12 @@ class BLEU(BaseMetric):
         n_gram (int): The maximum number of words contained in a phrase
             when calculating word fragments. Defaults to 4.
         smooth (bool): Whether or not to apply to smooth. Defaults to False.
-        ngram_weights(Sequence[float], optional): Weights used
+        ngram_weights (Sequence[float], optional): Weights used
             for unigrams, bigrams, etc. to calculate BLEU score.
             If not provided, uniform weights are used. Defaults to None.
+        tokenizer_fn (Union[Callable, str, None]): A user's own tokenizer function.
+            Defaults to None.
+            New in version 0.3.0.

Review comment on lines +43 to +44 (@zhouzaida, Collaborator, Jan 14, 2023):

Suggested change:
-        tokenizer_fn (Union[Callable, str, None]): A user's own tokenizer function.
-            Defaults to None.
+        tokenizer_fn (Callable or str, optional): A user's own tokenizer function.
+            Defaults to None.
+            New in version 0.3.0.

         **kwargs: Keyword parameters passed to :class:`BaseMetric`.
 
     Examples:
@@ -93,6 +66,7 @@ def __init__(self,
                  n_gram: int = 4,
                  smooth: bool = False,
                  ngram_weights: Optional[Sequence[float]] = None,
+                 tokenizer_fn: Union[Callable, str, None] = None,
                  **kwargs) -> None:
         super().__init__(**kwargs)
         self.n_gram = n_gram
@@ -105,21 +79,35 @@
             ngram_weights = [1.0 / n_gram] * n_gram
         self.ngram_weights = ngram_weights
 
+        # Select tokenizer according to the entered value.
+        self.tokenizer_fn = None
+        if callable(tokenizer_fn):
+            self.tokenizer_fn = tokenizer_fn
+        elif isinstance(tokenizer_fn, str):
+            self.tokenizer_fn = get_tokenizer(tokenizer_fn)
+            if self.tokenizer_fn is None:
+                raise ValueError('Right now, `tokenizer_fn` only supports '
+                                 "pre-defined 'en' or 'cn'.")
+        else:
+            assert tokenizer_fn is None, \
+                f'`tokenizer_fn` supports Callable, str or None, but not `{type(tokenizer_fn)}`'  # noqa: E501
+
     def add(self, predictions: Sequence[str], references: Sequence[Sequence[str]]) -> None:  # type: ignore # yapf: disable # noqa: E501
         """Add the intermediate results to ``self._results``.
 
         Args:
-            predictions (Sequence[str]): An iterable of machine
-                translated corpus.
-            references (Sequence[Sequence[str]]): An iterable of
-                iterables of reference corpus.
+            predictions (Sequence[str]): An iterable of predicted sentences.
+            references (Sequence[Sequence[str]]): An iterable of
+                referenced sentences.
         """
-
+        if self.tokenizer_fn is None:
+            language = infer_language(predictions[0])
+            self.tokenizer_fn = get_tokenizer(language)
         references_token: Sequence[Sequence[Sequence[str]]] = [
-            [tokenizer_fn(line) for line in r] for r in references
+            [self.tokenizer_fn(line) for line in r] for r in references
         ]
         predictions_token: Sequence[Sequence[str]] = [
-            tokenizer_fn(line) for line in predictions
+            self.tokenizer_fn(line) for line in predictions
         ]
         for prediction, references in zip(predictions_token, references_token):
             pred_len = len(prediction)