Add lowercase super parameter
go-with-me000 committed Jan 29, 2023
1 parent d7846cd commit 07e20c6
Showing 11 changed files with 552 additions and 89 deletions.
1 change: 1 addition & 0 deletions docs/en/api/metrics.rst
@@ -47,3 +47,4 @@ Metrics
MattingMSE
ConnectivityError
DOTAMeanAP
ROUGE
1 change: 1 addition & 0 deletions docs/zh_cn/api/metrics.rst
@@ -47,3 +47,4 @@ Metrics
MattingMSE
ConnectivityError
DOTAMeanAP
ROUGE
3 changes: 2 additions & 1 deletion mmeval/metrics/__init__.py
@@ -19,6 +19,7 @@
from .pck_accuracy import JhmdbPCKAccuracy, MpiiPCKAccuracy, PCKAccuracy
from .proposal_recall import ProposalRecall
from .psnr import PSNR
from .rouge import ROUGE
from .sad import SAD
from .single_label import SingleLabelMetric
from .snr import SNR
@@ -31,5 +32,5 @@
    'PCKAccuracy', 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'ProposalRecall',
    'PSNR', 'MAE', 'MSE', 'SSIM', 'SNR', 'MultiLabelMetric',
    'AveragePrecision', 'AVAMeanAP', 'BLEU', 'DOTAMeanAP', 'SAD',
    'GradientError', 'MattingMSE', 'ConnectivityError'
    'GradientError', 'MattingMSE', 'ConnectivityError', 'ROUGE'
]
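
With the export above in place, ROUGE can be imported directly from mmeval. A minimal usage sketch in the doctest style the repository's docstrings already use; the default constructor and the exact result keys are assumptions here, since the diff for mmeval/metrics/rouge.py is not loaded in this view:

>>> from mmeval import ROUGE
>>> predictions = ['the cat is on the mat']
>>> references = [['a cat is on the mat']]
>>> rouge = ROUGE()  # assumed default constructor
>>> rouge(predictions, references)  # doctest: +SKIP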
68 changes: 28 additions & 40 deletions mmeval/metrics/bleu.py
@@ -3,40 +3,10 @@
# <https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/bleu.py>`_.
import numpy as np
from collections import Counter
from typing import List, Optional, Sequence, Tuple
from typing import Callable, List, Optional, Sequence, Tuple, Union

from mmeval import BaseMetric


def get_n_gram(token: Sequence[str], n_gram: int) -> Counter:
    """Compute the n-grams of a sequence of tokens.
    Args:
        token (Sequence[str]): The tokens of a sentence.
        n_gram (int): The maximum length of the n-grams to count.
    Returns:
        Counter: A Counter mapping each n-gram tuple to its frequency.
    """
    counter: Counter = Counter()
    for i in range(1, n_gram + 1):
        for j in range(len(token) - i + 1):
            key = tuple(token[j:(i + j)])
            counter[key] += 1
    return counter


def tokenizer_fn(sentence: str) -> List[str]:
    """Segment a sentence into a list of tokens.
    Args:
        sentence (str): A sentence.
    Returns:
        List[str]: A list of tokens after word segmentation.
    """
    return sentence.split()
from mmeval.metrics.utils import get_n_gram, get_tokenizer, infer_language


def _get_brevity_penalty(pred_len: np.array,
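
The body of this helper is folded in this view. For reference, a sketch of the standard BLEU brevity penalty (Papineni et al., 2002) that it presumably implements; this is the textbook formula, not a copy of the folded code:

def brevity_penalty_sketch(pred_len: int, ref_len: int) -> float:
    # No penalty when the candidate is at least as long as the reference;
    # otherwise exp(1 - ref_len / pred_len) penalizes short candidates.
    if pred_len > ref_len:
        return 1.0
    return float(np.exp(1 - ref_len / pred_len))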
@@ -67,9 +37,12 @@ class BLEU(BaseMetric):
        n_gram (int): The maximum number of words contained in a phrase
            when calculating word fragments. Defaults to 4.
        smooth (bool): Whether to apply smoothing. Defaults to False.
        ngram_weights(Sequence[float], optional): Weights used
        ngram_weights (Sequence[float], optional): Weights used
            for unigrams, bigrams, etc. to calculate BLEU score.
            If not provided, uniform weights are used. Defaults to None.
        tokenizer_fn (Union[Callable, str, None]): A user's own tokenizer
            function, or the name of a pre-defined tokenizer ('en' or 'cn').
            If None, the tokenizer is inferred from the language of the
            first prediction passed to ``add()``. Defaults to None.
            New in version 0.3.0.
        **kwargs: Keyword parameters passed to :class:`BaseMetric`.
Examples:
@@ -93,6 +66,7 @@ def __init__(self,
                 n_gram: int = 4,
                 smooth: bool = False,
                 ngram_weights: Optional[Sequence[float]] = None,
                 tokenizer_fn: Union[Callable, str, None] = None,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.n_gram = n_gram
@@ -105,21 +79,35 @@ def __init__(self,
            ngram_weights = [1.0 / n_gram] * n_gram
        self.ngram_weights = ngram_weights

        # Select the tokenizer according to the given value.
        self.tokenizer_fn = None
        if callable(tokenizer_fn):
            self.tokenizer_fn = tokenizer_fn
        elif isinstance(tokenizer_fn, str):
            self.tokenizer_fn = get_tokenizer(tokenizer_fn)
            if self.tokenizer_fn is None:
                raise ValueError('Right now, `tokenizer_fn` only supports '
                                 "pre-defined 'en' or 'cn'.")
        else:
            assert tokenizer_fn is None, \
                f'`tokenizer_fn` supports Callable, str or None, but not `{type(tokenizer_fn)}`'  # noqa: E501

    def add(self, predictions: Sequence[str], references: Sequence[Sequence[str]]) -> None:  # type: ignore # yapf: disable # noqa: E501
        """Add the intermediate results to ``self._results``.
        Args:
            predictions (Sequence[str]): An iterable of machine
                translated corpus.
            references (Sequence[Sequence[str]]): An iterable of
                iterables of reference corpus.
            predictions (Sequence[str]): An iterable of predicted sentences.
            references (Sequence[Sequence[str]]): An iterable of
                iterables of reference sentences.
        """

        if self.tokenizer_fn is None:
            language = infer_language(predictions[0])
            self.tokenizer_fn = get_tokenizer(language)
        references_token: Sequence[Sequence[Sequence[str]]] = [
            [tokenizer_fn(line) for line in r] for r in references
            [self.tokenizer_fn(line) for line in r] for r in references
        ]
        predictions_token: Sequence[Sequence[str]] = [
            tokenizer_fn(line) for line in predictions
            self.tokenizer_fn(line) for line in predictions
        ]
        for prediction, references in zip(predictions_token, references_token):
            pred_len = len(prediction)
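
Taken together, the new parameter accepts three kinds of values, and add() falls back to language inference when it is left as None. A short usage sketch in the doctest style of the class docstring; the callable form of the metric follows the BaseMetric convention, and the 'de' failure assumes get_tokenizer returns None for unknown names, as the guard in __init__ implies:

>>> from mmeval import BLEU
>>> bleu = BLEU(tokenizer_fn='en')       # a pre-defined tokenizer, by name
>>> bleu = BLEU(tokenizer_fn=str.split)  # any callable is used as-is
>>> bleu = BLEU()  # None: inferred from the first prediction in add()
>>> bleu(['the cat is on the mat'], [['a cat is on the mat']])  # doctest: +SKIP
>>> BLEU(tokenizer_fn='de')  # doctest: +SKIP
Traceback (most recent call last):
    ...
ValueError: Right now, `tokenizer_fn` only supports pre-defined 'en' or 'cn'.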
(Diffs for the remaining 7 changed files, including mmeval/metrics/rouge.py, were not loaded in this view.)
