[Feature] add bleu metric and its test file #66

Merged: 16 commits, Dec 19, 2022
4 changes: 2 additions & 2 deletions .circleci/config.yml
@@ -59,7 +59,7 @@ jobs:
- run:
name: Install OneFlow via pip
command: |
- pip install -f https://release.oneflow.info oneflow==0.8.0+cpu
+ pip install -f https://release.oneflow.info oneflow==0.8.0+cpu "numpy<1.24.0"
- run:
name: Install mmeval and dependencies
command: |
@@ -135,7 +135,7 @@ jobs:
python -V
pip install torch==1.7.1+cu110 -f https://download.pytorch.org/whl/torch_stable.html
pip install paddlepaddle-gpu==2.3.2.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
- pip install --pre oneflow -f https://staging.oneflow.info/branch/master/cu112
+ pip install --pre oneflow -f https://staging.oneflow.info/branch/master/cu112 "numpy<1.24.0"
- run:
name: Install mmeval and dependencies
command: |
1 change: 1 addition & 0 deletions docs/en/api/metrics.rst
@@ -41,3 +41,4 @@ Metrics
PSNR
MAE
MSE
BLEU
1 change: 1 addition & 0 deletions docs/zh_cn/api/metrics.rst
@@ -41,3 +41,4 @@ Metrics
PSNR
MAE
MSE
BLEU
3 changes: 2 additions & 1 deletion mmeval/metrics/__init__.py
@@ -2,6 +2,7 @@

from .accuracy import Accuracy
from .ava_map import AVAMeanAP
from .bleu import BLEU
from .coco_detection import COCODetectionMetric
from .end_point_error import EndPointError
from .f_metric import F1Metric
@@ -24,5 +25,5 @@
'F1Metric', 'HmeanIoU', 'SingleLabelMetric', 'COCODetectionMetric',
'PCKAccuracy', 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'ProposalRecall',
'PSNR', 'MAE', 'MSE', 'SSIM', 'SNR', 'MultiLabelMetric',
- 'AveragePrecision', 'AVAMeanAP'
+ 'AveragePrecision', 'AVAMeanAP', 'BLEU'
]
192 changes: 192 additions & 0 deletions mmeval/metrics/bleu.py
@@ -0,0 +1,192 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This class is modified from `torchmetrics
# <https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/bleu.py>`_.
import numpy as np
from collections import Counter
from typing import List, Optional, Sequence, Tuple

from mmeval import BaseMetric


def get_n_gram(token: Sequence[str], n_gram: int) -> Counter:
"""A function get n_gram of sentences.

Args:
token (Sequence[str]): A series of tokens about sentences.
n_gram (int): The maximum number of words contained in a phrase
when calculating word fragments. Defaults to 4.

Returns:
Counter: The n_gram contained in sentences with Counter format.
"""
counter: Counter = Counter()
for i in range(1, n_gram + 1):
for j in range(len(token) - i + 1):
key = tuple(token[j:(i + j)])
counter[key] += 1
return counter
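
As a quick illustration of the helper above (a sketch, assuming ``get_n_gram`` is importable from ``mmeval.metrics.bleu``), every overlapping n-gram of order 1 to ``n_gram`` is counted:

>>> from mmeval.metrics.bleu import get_n_gram
>>> get_n_gram(['the', 'cat', 'sat'], 2)
Counter({('the',): 1, ('cat',): 1, ('sat',): 1, ('the', 'cat'): 1, ('cat', 'sat'): 1})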


def tokenizer_fn(sentence: str) -> List[str]:
"""This function is used to segment a sentence.

Args:
sentence (str): A sentence.

Returns:
List[str]: A list of tokens after word segmentation.
"""
return sentence.split()


def _get_brevity_penalty(pred_len: int, references_len: int) -> np.ndarray:
"""Calculate the brevity penalty of the BLEU score.

Args:
pred_len (int): The number of tokens in the predicted sentence.
references_len (int): The number of tokens in the reference closest in length to the prediction.

Returns:
np.ndarray: The brevity penalty factor.
"""
if pred_len > references_len:
return np.array(1.)
return np.array(np.exp(1 - references_len / pred_len))
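
A quick numeric sketch of the penalty with hypothetical token counts: the factor is 1 when the prediction is longer than the closest reference, and decays as exp(1 - references_len / pred_len) otherwise:

>>> from mmeval.metrics.bleu import _get_brevity_penalty
>>> _get_brevity_penalty(7, 5)
array(1.)
>>> round(float(_get_brevity_penalty(5, 7)), 3)  # exp(1 - 7/5)
0.67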


class BLEU(BaseMetric):
"""Bilingual Evaluation Understudy metric.

This metric, proposed in `BLEU: a Method for Automatic Evaluation of Machine Translation
<https://aclanthology.org/P02-1040.pdf>`_, evaluates the quality of machine translation.
The closer a machine translation is to a professional human translation, the higher the score.

Args:
n_gram (int): The maximum n-gram order used when matching word fragments. Defaults to 4.
smooth (bool): Whether to apply smoothing to the n-gram precisions. Defaults to False.
ngram_weights (Sequence[float], optional): Weights used for unigrams, bigrams, etc.
when calculating the BLEU score. If not provided, uniform weights are used.
Defaults to None.
**kwargs: Keyword parameters passed to :class:`BaseMetric`.

Examples:
>>> from mmeval import BLEU
>>> predictions = ['the cat is on the mat', 'There is a big tree near the park here'] # noqa: E501
>>> references = [['a cat is on the mat'], ['A big tree is growing near the park here']] # noqa: E501
>>> bleu = BLEU()
>>> bleu(predictions, references)
{'bleu': 0.5226045319355426}

>>> # Calculate BLEU with smooth:
>>> from mmeval import BLEU
>>> predictions = ['the cat is on the mat', 'There is a big tree near the park here'] # noqa: E501
>>> references = [['a cat is on the mat'], ['A big tree is growing near the park here']] # noqa: E501
>>> bleu = BLEU(smooth=True)
>>> bleu(predictions, references)
{'bleu': 0.566315716093867}
"""

def __init__(self,
n_gram: int = 4,
smooth: bool = False,
ngram_weights: Optional[Sequence[float]] = None,
**kwargs) -> None:
super().__init__(**kwargs)
self.n_gram = n_gram
self.smooth = smooth
if ngram_weights is not None and len(ngram_weights) != n_gram:
raise ValueError(
'The length of ngram_weights is not equal to `n_gram`: '
f'{len(ngram_weights)} != {n_gram}')
if ngram_weights is None:
ngram_weights = [1.0 / n_gram] * n_gram
self.ngram_weights = ngram_weights

def add(self, predictions: Sequence[str], references: Sequence[Sequence[str]]) -> None: # type: ignore # yapf: disable # noqa: E501
"""Add the intermediate results to ``self._results``.

Args:
predictions (Sequence[str]): An iterable of machine
translated corpus.
references (Sequence[Sequence[str]]): An iterable of
iterables of reference corpus.
"""

references_token: Sequence[Sequence[Sequence[str]]] = [
[tokenizer_fn(line) for line in r] for r in references
]
predictions_token: Sequence[Sequence[str]] = [
tokenizer_fn(line) for line in predictions
]
for prediction, refs in zip(predictions_token, references_token):
pred_len = len(prediction)
# Find the reference that is closest in length to the prediction.
references_len = len(
min(refs, key=lambda x: abs(len(x) - pred_len)))

pred_counter: Counter = get_n_gram(prediction, self.n_gram)
reference_counter: Counter = Counter()
for reference in refs:
# Take the union of the n-gram counts over all references.
reference_counter |= get_n_gram(reference, self.n_gram)

# Clip the n-gram counts of the prediction by intersecting them
# with the reference counts.
counter_clip = pred_counter & reference_counter
precision_matches = np.zeros(self.n_gram)
precision_total = np.zeros(self.n_gram)
for ngram in counter_clip:
precision_matches[len(ngram) - 1] += counter_clip[ngram]
for ngram in pred_counter:
precision_total[len(ngram) - 1] += pred_counter[ngram]

result = (pred_len, references_len, precision_matches,
precision_total)
self._results.append(result)
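
The clipped counts above implement BLEU's modified precision: an n-gram from the prediction is credited at most as many times as it occurs in the references. A small sketch with hypothetical sentences:

>>> from mmeval.metrics.bleu import get_n_gram
>>> pred_counter = get_n_gram('the the the the'.split(), 1)
>>> ref_counter = get_n_gram('the cat is on the mat'.split(), 1)
>>> (pred_counter & ref_counter)[('the',)]  # clipped from 4 down to 2
2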

def compute_metric(
self, results: List[Tuple[int, int, np.ndarray,
np.ndarray]]) -> dict:
"""Compute the bleu metric.

This method would be invoked in ``BaseMetric.compute`` after
distributed synchronization.

Args:
results (List[Tuple[int, int, np.ndarray, np.ndarray]]):
A list of per-sample tuples containing pred_len, references_len,
precision_matches and precision_total.
This list has already been synced across all ranks.

Returns:
Dict[str, float]: The computed BLEU score.
"""
pred_len = 0
references_len = 0
precision_matches = np.zeros(self.n_gram)
precision_total = np.zeros(self.n_gram)
for result in results:
pred_len += result[0]
references_len += result[1]
precision_matches += result[2]
precision_total += result[3]

if min(precision_matches) == 0.0:
return {'bleu': 0.0}

if self.smooth:
precision_score = np.add(precision_matches, np.ones(
self.n_gram)) / np.add(precision_total, np.ones(self.n_gram))
precision_score[0] = precision_matches[0] / precision_total[0]
else:
precision_score = precision_matches / precision_total

precision_score = np.array(
self.ngram_weights) * np.log(precision_score)
brevity_penalty = _get_brevity_penalty(pred_len, references_len)
bleu = brevity_penalty * np.exp(np.sum(precision_score))
result = {'bleu': float(bleu)}
return result
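
For reference, a minimal end-to-end sketch mirroring the docstring example above; the metric can also be fed incrementally through ``add`` and finalized with ``compute``, as the tests below do:

>>> from mmeval import BLEU
>>> predictions = ['the cat is on the mat', 'There is a big tree near the park here']
>>> references = [['a cat is on the mat'], ['A big tree is growing near the park here']]
>>> bleu = BLEU()
>>> bleu.add(predictions, references)
>>> round(bleu.compute()['bleu'], 4)
0.5226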
96 changes: 96 additions & 0 deletions tests/test_metrics/test_bleu.py
@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
from collections import Counter

from mmeval.metrics import BLEU
from mmeval.metrics.bleu import get_n_gram


@pytest.mark.parametrize('n_gram', [2, 4])
def test_get_n_gram(n_gram):
token = ['a', 'cat', 'is', 'on', 'the', 'mat']
result = get_n_gram(token, n_gram)
if n_gram == 2:
counter = Counter({
('a', ): 1,
('cat', ): 1,
('is', ): 1,
('on', ): 1,
('the', ): 1,
('mat', ): 1,
('a', 'cat'): 1,
('cat', 'is'): 1,
('is', 'on'): 1,
('on', 'the'): 1,
('the', 'mat'): 1
})
else:
counter = Counter({
('a', ): 1,
('cat', ): 1,
('is', ): 1,
('on', ): 1,
('the', ): 1,
('mat', ): 1,
('a', 'cat'): 1,
('cat', 'is'): 1,
('is', 'on'): 1,
('on', 'the'): 1,
('the', 'mat'): 1,
('a', 'cat', 'is'): 1,
('cat', 'is', 'on'): 1,
('is', 'on', 'the'): 1,
('on', 'the', 'mat'): 1,
('a', 'cat', 'is', 'on'): 1,
('cat', 'is', 'on', 'the'): 1,
('is', 'on', 'the', 'mat'): 1
})
assert result == counter


def test_bleu():
predictions = [
'the cat is on the mat',
'There is a big tree near the park here',
'The sun rises from the northeast with sunshine',
'I was late for work today for the rainy',
'My name is Barry',
]
references = [['a cat is on the mat'],
['A big tree is growing near the park here'],
['A fierce sun rises in the northeast with sunshine'],
['I went to work too late today for the rainy'],
['I am Barry']]

bleu = BLEU()
for i in range(len(predictions)):
bleu.add([predictions[i]], [references[i]])
bleu_results = bleu.compute()
assert isinstance(bleu_results, dict)
np.testing.assert_almost_equal(bleu_results['bleu'], 0.4006741601366701)

bleu = BLEU(smooth=True)
bleu_results = bleu(predictions, references)
assert isinstance(bleu_results, dict)
np.testing.assert_almost_equal(bleu_results['bleu'], 0.42504770796962527)


@pytest.mark.parametrize('n_gram', [1, 2, 3, 4])
def test_bleu_ngram(n_gram):
predictions = [
'the cat is on the mat',
'There is a big tree near the park here',
'The sun rises from the northeast with sunshine',
'I was late for work today for the rainy',
'My name is Barry',
]
references = [['a cat is on the mat'],
['A big tree is growing near the park here'],
['A fierce sun rises in the northeast with sunshine'],
['I went to work too late today for the rainy'],
['I am Barry']]

bleu = BLEU(n_gram=n_gram)
bleu_results = bleu(predictions, references)
assert isinstance(bleu_results, dict)