[Feature] add bleu metric and its test file #66

Merged (16 commits) on Dec 19, 2022
Changes from 2 commits
1 change: 1 addition & 0 deletions docs/en/api/metrics.rst
@@ -41,3 +41,4 @@ Metrics
PSNR
MAE
MSE
Bleu
1 change: 1 addition & 0 deletions docs/zh_cn/api/metrics.rst
@@ -41,3 +41,4 @@ Metrics
PSNR
MAE
MSE
Bleu
3 changes: 2 additions & 1 deletion mmeval/metrics/__init__.py
@@ -2,6 +2,7 @@

from .accuracy import Accuracy
from .ava_map import AVAMeanAP
from .bleu import Bleu
from .coco_detection import COCODetectionMetric
from .end_point_error import EndPointError
from .f_metric import F1Metric
@@ -24,5 +25,5 @@
'F1Metric', 'HmeanIoU', 'SingleLabelMetric', 'COCODetectionMetric',
'PCKAccuracy', 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'ProposalRecall',
'PSNR', 'MAE', 'MSE', 'SSIM', 'SNR', 'MultiLabelMetric',
'AveragePrecision', 'AVAMeanAP'
'AveragePrecision', 'AVAMeanAP', 'Bleu'
]
193 changes: 193 additions & 0 deletions mmeval/metrics/bleu.py
@@ -0,0 +1,193 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from collections import Counter
from typing import List, Optional, Sequence, Tuple

from mmeval import BaseMetric


def get_n_gram(token: Sequence[str], n_gram: int) -> Counter:
"""A function get n_gram of sentences.

Args:
token (Sequence[str]): A series of tokens about sentences.
n_gram (int):The maximum number of words contained in a phrase
when calculating word fragments. Defaults to 4.
ice-tong marked this conversation as resolved.
Show resolved Hide resolved

Returns:
Counter: The n_gram contained in sentences with Counter format.
"""
counter: Counter = Counter()
for i in range(1, n_gram + 1):
for j in range(len(token) - i + 1):
key = tuple(token[j:(i + j)])
counter[key] += 1
return counter
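
# Illustrative check (not part of the original patch), assuming n_gram=2:
# get_n_gram(['the', 'cat', 'sat'], 2)
# -> Counter({('the',): 1, ('cat',): 1, ('sat',): 1,
#             ('the', 'cat'): 1, ('cat', 'sat'): 1})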


def tokenizer_fn(sentence: str) -> Sequence[str]:
"""This function is used to segment a sentence.

Args:
sentence(str): A sentence.
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved

Returns:
Sequence[str]:A Sequence of tokens after word segmentation.
ice-tong marked this conversation as resolved.
Show resolved Hide resolved
"""
return sentence.split()
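
# Illustration only (not part of the original patch):
# tokenizer_fn('the cat is on the mat')
# -> ['the', 'cat', 'is', 'on', 'the', 'mat']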


def _get_brevity_penalty(pred_len: np.array,
references_len: np.array) -> np.array:
"""This function is used to calculate penalty factor.

Args:
pred_len(np.array): number of grams in the predicted sentence.
references_len(np.array): number of grams in the references.
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved

Returns:
np.array: penalty factor.
"""
if pred_len > references_len:
return np.array(1.)
return np.array(np.exp(1 - references_len / pred_len))
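
# Quick sanity check (illustration only, not part of the original patch):
# _get_brevity_penalty(np.array(7), np.array(6)) -> 1.0
# _get_brevity_penalty(np.array(4), np.array(6)) -> exp(1 - 6/4) ~= 0.6065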


class Bleu(BaseMetric):
"""Bilingual Evaluation Understudy metric.

This metric evaluates the quality of machine translation. The closer a
machine translation is to a professional human translation, the higher
the score will be.

Args:
n_gram (int): The maximum length of the n-grams used when computing
the score. Defaults to 4.
smooth (bool): Whether to apply smoothing. Defaults to False.
ngram_weights (Sequence[float], optional): Weights of the unigram,
bigram, etc. precisions used to compute the BLEU score. If not
provided, uniform weights are used. Defaults to None.
**kwargs: Keyword parameters passed to :class:`BaseMetric`.

Examples:

>>> predictions = ['the cat is on the mat', 'There is a big tree near the park here']  # type: ignore # noqa: E501
>>> references = [['a cat is on the mat'], ['A big tree is growing near the park here']]  # type: ignore # noqa: E501
>>> bleu = Bleu()
>>> bleu(predictions, references)
{'bleu': ...}

Calculate BLEU with smoothing:

>>> predictions = ['the cat is on the mat', 'There is a big tree near the park here']  # type: ignore # noqa: E501
>>> references = [['a cat is on the mat'], ['A big tree is growing near the park here']]  # type: ignore # noqa: E501
>>> bleu = Bleu(smooth=True)
>>> bleu(predictions, references)
{'bleu': ...}
"""

def __init__(self,
n_gram: int = 4,
smooth: bool = False,
ngram_weights: Optional[Sequence[float]] = None,
**kwargs) -> None:
super().__init__(**kwargs)
self.n_gram = n_gram
self.smooth = smooth
if ngram_weights is not None and len(ngram_weights) != n_gram:
raise ValueError(
'List of weights has a different length than `n_gram`: '
f'{len(ngram_weights)} != {n_gram}')
self.ngram_weights = ngram_weights if ngram_weights is not None else [
1.0 / n_gram
] * n_gram
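# Illustration (not in the original patch): with the default n_gram=4 and
# ngram_weights=None, the weights above become [0.25, 0.25, 0.25, 0.25],
# i.e. uniform weighting of the 1- to 4-gram precisions.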

def add(self, predictions: Sequence[str], references: Sequence[Sequence[str]]) -> None: # type: ignore # yapf: disable # noqa: E501
"""Add the intermediate results to ``self._results``.

Args:
predictions (Sequence[str]): An iterable of predicted sentences
(the machine translations).
references (Sequence[Sequence[str]]): An iterable of iterables of
reference sentences, one group of references per prediction.
"""

references_token: Sequence[Sequence[Sequence[str]]] = [[
tokenizer_fn(line) if line is not None else [] for line in r
] for r in references]
predictions_token: Sequence[Sequence[str]] = [
tokenizer_fn(line) if line else [] for line in predictions
]
pred_len = 0
references_len = 0
precision_matches = np.zeros(self.n_gram)
precision_total = np.zeros(self.n_gram)
for prediction, references in zip(predictions_token, references_token):
pred_len += len(prediction)
references_len_list = [len(reference) for reference in references]
references_len_diff = [
abs(len(prediction) - length) for length in references_len_list
]
# In the multi sentence reference, the one whose length is closest
# to the predicted sentence is selected to record the length.
min_index = references_len_diff.index(min(references_len_diff))
references_len += references_len_list[min_index]

pred_counter: Counter = get_n_gram(prediction, self.n_gram)
reference_counter: Counter = Counter()
for reference in references:
# Take the union (element-wise maximum) of the n-gram counts over
# all references.
reference_counter |= get_n_gram(reference, self.n_gram)
# Clip the prediction's n-gram counts by intersecting (element-wise
# minimum) with the reference counts.
counter_clip = pred_counter & reference_counter

for ngram in counter_clip:
precision_matches[len(ngram) - 1] += counter_clip[ngram]
for ngram in pred_counter:
precision_total[len(ngram) - 1] += pred_counter[ngram]
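# The batch-level statistics are packed into a tuple of (total prediction
# length, matched reference length, per-order clipped matches, per-order
# n-gram totals); ``compute_metric`` later sums these over all samples.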
result = (pred_len, references_len, precision_matches, precision_total)
self._results.append(result)

def compute_metric(
self, results: List[Tuple[int, int, np.ndarray,
np.ndarray]]) -> dict:
"""Compute the bleu metric.

This method would be invoked in ``BaseMetric.compute`` after
distributed synchronization.

Args:
results (List[Tuple[int, int, np.ndarray, np.ndarray]]):
A list of tuples of intermediate statistics, each containing
pred_len, references_len, precision_matches and
precision_total.
This list has already been synced across all ranks.

Returns:
Dict[str, float]: The computed bleu score.
"""
pred_len = 0
references_len = 0
precision_matches = np.zeros(self.n_gram)
precision_total = np.zeros(self.n_gram)
for result in results:
pred_len += result[0]
references_len += result[1]
precision_matches += result[2]
precision_total += result[3]

if min(precision_matches) == 0.0:
return {'bleu': 0.0}
if self.smooth:
precision_score = np.add(precision_matches, np.ones(
self.n_gram)) / np.add(precision_total, np.ones(self.n_gram))
precision_score[0] = precision_matches[0] / precision_total[0]
else:
precision_score = precision_matches / precision_total

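# The final score follows the standard BLEU definition:
# BLEU = BP * exp(sum_i w_i * log(p_i)),
# where BP is the brevity penalty, w_i are the n-gram weights and p_i the
# (optionally smoothed) modified n-gram precisions.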
precision_score = np.array(
self.ngram_weights) * np.log(precision_score)
brevity_penalty = _get_brevity_penalty(pred_len, references_len)
bleu = brevity_penalty * np.exp(np.sum(precision_score))
result = {'bleu': round(float(bleu), 6)}
return result
6 changes: 3 additions & 3 deletions mmeval/metrics/end_point_error.py
@@ -150,12 +150,12 @@ def end_point_error_map(
return epe.mean().cpu().numpy(), int(val.sum())

@dispatch # noqa: F811
def end_point_error_map(
def end_point_error_map( # noqa: F811
self,
prediction: 'oneflow.Tensor',
label: 'oneflow.Tensor',
valid_mask: Optional['oneflow.Tensor'] = None
) -> Tuple[np.ndarray, int]:
valid_mask: Optional['oneflow.Tensor'] = None) -> Tuple[np.ndarray,
int]:
"""Calculate end point error map.

Args:
52 changes: 52 additions & 0 deletions tests/test_metrics/test_bleu.py
@@ -0,0 +1,52 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest

from mmeval.metrics import Bleu


def test_bleu():
predictions = [
'the cat is on the mat',
'There is a big tree near the park here',
'The sun rises from the northeast with sunshine',
'I was late for work today for the rainy',
'My name is Barry',
]
references = [['a cat is on the mat'],
['A big tree is growing near the park here'],
['A fierce sun rises in the northeast with sunshine'],
['I went to work too late today for the rainy'],
['I am Barry']]

bleu = Bleu()
for i in range(len(predictions)):
bleu.add([predictions[i]], [references[i]])
bleu_results = bleu.compute()
assert isinstance(bleu_results, dict)
np.testing.assert_almost_equal(bleu_results['bleu'], 0.400674)

bleu = Bleu(smooth=True)
bleu_results = bleu(predictions, references)
assert isinstance(bleu_results, dict)
np.testing.assert_almost_equal(bleu_results['bleu'], 0.425048)


@pytest.mark.parametrize('n_gram', [1, 2, 3, 4])
def test_bleu_ngram(n_gram):
predictions = [
'the cat is on the mat',
'There is a big tree near the park here',
'The sun rises from the northeast with sunshine',
'I was late for work today for the rainy',
'My name is Barry',
]
references = [['a cat is on the mat'],
['A big tree is growing near the park here'],
['A fierce sun rises in the northeast with sunshine'],
['I went to work too late today for the rainy'],
['I am Barry']]

bleu = Bleu(n_gram=n_gram)
bleu_results = bleu(predictions, references)
assert isinstance(bleu_results, dict)