specification details
go-with-me000 committed Jan 16, 2023
1 parent 5d3cacc commit 107cb9f
Showing 6 changed files with 79 additions and 78 deletions.
3 changes: 2 additions & 1 deletion mmeval/metrics/bleu.py
@@ -40,8 +40,9 @@ class BLEU(BaseMetric):
ngram_weights (Sequence[float], optional): Weights used
for unigrams, bigrams, etc. to calculate BLEU score.
If not provided, uniform weights are used. Defaults to None.
tokenizer_fn (Union[Callable, str, None]): A user's own tokenizer function.
tokenizer_fn (Callable or str, optional): A user's own tokenizer function.
Defaults to None.
New in version 0.3.0.
**kwargs: Keyword parameters passed to :class:`BaseMetric`.
Examples:
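For context, a minimal usage sketch of the ``tokenizer_fn`` parameter documented above, assuming the same interface the tests in this commit exercise (``BLEU()``, ``add()``, ``compute()``); the whitespace tokenizer and the example sentences are illustrative only:

    from mmeval.metrics import BLEU

    # Hypothetical user-supplied tokenizer passed via ``tokenizer_fn``
    # (the parameter documented above, new in version 0.3.0).
    def whitespace_tokenizer(text):
        return text.split()

    predictions = ['the cat sat on the mat']
    references = [['a cat was sitting on the mat']]

    bleu = BLEU(tokenizer_fn=whitespace_tokenizer)
    bleu.add(predictions, references)
    print(bleu.compute())  # a dict containing a 'bleu' score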
32 changes: 16 additions & 16 deletions mmeval/metrics/rouge.py
@@ -33,8 +33,7 @@ def _get_results(matches: int, pred_len: int,
precision = matches / pred_len
recall = matches / reference_len
if precision == recall == 0.0:
return dict(
precision=float(0.0), recall=float(0.0), fmeasure=float(0.0))
return dict(precision=0., recall=0., fmeasure=0.)

fmeasure = 2 * precision * recall / (precision + recall)
return dict(
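As a quick sanity check of the precision/recall/F-measure arithmetic in this hunk (a worked example, not part of the diff): with matches = 3, pred_len = 5 and reference_len = 6, precision = 3/5 = 0.6, recall = 3/6 = 0.5, and fmeasure = 2 * 0.6 * 0.5 / (0.6 + 0.5) ≈ 0.545.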
@@ -55,9 +54,8 @@ def _rougeL_score(pred: Sequence[str],
Dict[str, float]: Calculate the score of rougeL.
"""
pred_len, reference_len = len(pred), len(reference)
if 0 in (pred_len, reference_len):
return dict(
precision=float(0.0), recall=float(0.0), fmeasure=float(0.0))
if pred_len == 0 or reference_len == 0:
return dict(precision=0., recall=0., fmeasure=0.)
lcs = 0
matches = SequenceMatcher(None, pred, reference).get_matching_blocks()
for match in matches:
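The rest of _rougeL_score is collapsed in this view; the sketch below shows the ROUGE-L idea it relies on, assuming the hidden lines sum the sizes of the matching blocks into ``lcs`` and then derive precision and recall from it (an assumption based on the visible code, not the exact committed implementation):

    from difflib import SequenceMatcher

    pred = ['the', 'cat', 'sat', 'on', 'the', 'mat']
    reference = ['the', 'cat', 'is', 'on', 'the', 'mat']

    # get_matching_blocks() yields non-overlapping matching runs; summing
    # their sizes gives the common-subsequence length used here for ROUGE-L.
    matches = SequenceMatcher(None, pred, reference).get_matching_blocks()
    lcs = sum(match.size for match in matches)                # 5
    precision = lcs / len(pred)                               # 5/6
    recall = lcs / len(reference)                             # 5/6
    fmeasure = 2 * precision * recall / (precision + recall)  # ~0.833
    print(lcs, round(precision, 3), round(recall, 3), round(fmeasure, 3))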
@@ -72,25 +70,27 @@ def _rougeN_score(pred: Sequence[str], reference: Sequence[str],
Args:
pred (Sequence[str]): A predicted sentence.
reference (Sequence[str]): A referenced sentence.
n_gram (int): The number of words contained in a phrase
when calculating word fragments.
Returns:
Dict[str, float]: Calculate the score of rougeN.
"""

def _create_ngrams(tokens: Sequence[str], n: int) -> Counter:
ngrams: Counter = Counter()
for ngram in (tuple(tokens[i:i + n])
for i in range(len(tokens) - n + 1)):
for i in range(len(tokens) - n + 1):
ngram = tuple(tokens[i:i + n])
ngrams[ngram] += 1
print(ngrams)
return ngrams

pred_ngarms = _create_ngrams(pred, n_gram)
reference_ngarms = _create_ngrams(reference, n_gram)
pred_len = sum(pred_ngarms.values())
reference_len = sum(reference_ngarms.values())
if 0 in (pred_len, reference_len):
return dict(
precision=float(0.0), recall=float(0.0), fmeasure=float(0.0))
if pred_len == 0 or reference_len == 0:
return dict(precision=0., recall=0., fmeasure=0.)

# Take the intersection of n_gram of prediction and reference.
hits = sum(
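The expression that computes ``hits`` is truncated above; the following sketch shows the clipped n-gram overlap it presumably implements (clipping each predicted n-gram count by its reference count is the standard ROUGE-N convention, but the exact committed expression is not visible here):

    from collections import Counter

    pred_ngrams = Counter({('the', 'cat'): 1, ('cat', 'sat'): 1})
    reference_ngrams = Counter({('the', 'cat'): 1, ('cat', 'is'): 1})

    # Count each predicted n-gram at most as often as it occurs in the reference.
    hits = sum(
        min(count, reference_ngrams[ngram])
        for ngram, count in pred_ngrams.items())
    print(hits)  # 1, since only ('the', 'cat') overlaps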
@@ -107,18 +107,18 @@ class ROUGE(BaseMetric):
automatic summarization, question and answer generation, etc.
Args:
rouge_keys (Uinon[List, Tule, int, str]): A list of rouge types to calculate.
rouge_keys (List or Tuple or int or str): A list of rouge types to calculate.
Keys that are allowed are ``L``, and ``1`` through ``9``.
Defaults to ``(1, 2, 'L')``.
use_stemmer (bool): Use Porter stemmer to strip word
suffixes to improve matching. Defaults to False.
normalizer (Callable, optional): A user's own normalizer function.
If this is ``None``, replacing any non-alpha-numeric characters
with spaces is default. Defaults to None.
tokenizer (Union[Callable, str, None]): A user's own tokenizer function.
tokenizer (Callable or str, optional): A user's own tokenizer function.
Defaults to None.
accumulate (str): Useful in case of multi-reference rouge score.
``avg`` takes the avg of all references with respect to predictions
``avg`` takes the average of all references with respect to predictions
``best`` takes the best fmeasure score obtained between prediction
and multiple corresponding references.
Defaults to ``best``.
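For reference, a minimal usage sketch of the ROUGE metric with the parameters documented above, mirroring the updated test at the bottom of this commit (the import path is assumed to parallel BLEU's, and the sentences are illustrative):

    from mmeval.metrics import ROUGE

    predictions = ['the cat sat on the mat']
    references = [['a cat was sitting on the mat']]

    # Defaults written out explicitly: ROUGE-1, ROUGE-2 and ROUGE-L, keeping
    # the best fmeasure across multiple references per prediction.
    rouge = ROUGE(rouge_keys=(1, 2, 'L'), accumulate='best')
    rouge.add(predictions, references)
    results = rouge.compute()
    print(results['rougeL_fmeasure'])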
@@ -190,7 +190,7 @@ def add(self, predictions: Sequence[str], references: Sequence[Sequence[str]]) -
Args:
predictions (Sequence[str]): An iterable of predicted sentences.
references (Sequence[Sequence[str]): An iterable of
referenced sentences.
referenced sentences.
"""
# If the tokenizer is None, check the first sentence
# to determine which language the tokenizer is used.
@@ -241,8 +241,7 @@ def compute_metric(self, results: List[Any]) -> dict:
distributed synchronization.
Args:
results (List):
A list that consisting the list of correct numbers.
results (List): A list that consists of the list of correct numbers.
This list has already been synced across all ranks.
Returns:
@@ -274,6 +273,7 @@ def _normalize_and_tokenize(self, text: str) -> Sequence[str]:
Args:
text (str): An input sentence.
Returns:
Sequence[str]: The tokens after normalizer and tokenizer.
"""
3 changes: 2 additions & 1 deletion mmeval/metrics/utils/ngram_process.py
@@ -9,7 +9,7 @@ def get_n_gram(token: Sequence[str], n_gram: int) -> Counter:
Args:
token (Sequence[str]): A series of tokens about sentences.
n_gram (int): The maximum number of words contained in a phrase
when calculating word fragments. Defaults to 4.
when calculating word fragments.
Returns:
Counter: The n_gram contained in sentences with Counter format.
@@ -25,6 +25,7 @@ def get_n_gram(token: Sequence[str], n_gram: int) -> Counter:

def inference_language(text: str) -> str:
"""Determine the type of language.
Args:
text (str): Input for language judgment.
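The body of inference_language is collapsed in this view. Purely as an illustration of the kind of check it could perform (a hypothetical sketch, not the committed implementation, and the 'cn'/'en' return values are assumptions), a CJK code-point test is one way to tell Chinese input apart from English:

    def inference_language_sketch(text: str) -> str:
        # Hypothetical: any CJK Unified Ideograph marks the text as Chinese.
        for char in text:
            if '\u4e00' <= char <= '\u9fff':
                return 'cn'
        return 'en'

    print(inference_language_sketch('猫坐在垫子上'))         # 'cn'
    print(inference_language_sketch('a cat is on the mat'))  # 'en'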
64 changes: 7 additions & 57 deletions tests/test_metrics/test_bleu.py
@@ -1,52 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
from collections import Counter

from mmeval.metrics import BLEU
from mmeval.metrics.bleu import get_n_gram


@pytest.mark.parametrize('n_gram', [2, 4])
def test_get_n_gram(n_gram):
token = ['a', 'cat', 'is', 'on', 'the', 'mat']
result = get_n_gram(token, n_gram)
if n_gram == 2:
counter = Counter({
('a', ): 1,
('cat', ): 1,
('is', ): 1,
('on', ): 1,
('the', ): 1,
('mat', ): 1,
('a', 'cat'): 1,
('cat', 'is'): 1,
('is', 'on'): 1,
('on', 'the'): 1,
('the', 'mat'): 1
})
else:
counter = Counter({
('a', ): 1,
('cat', ): 1,
('is', ): 1,
('on', ): 1,
('the', ): 1,
('mat', ): 1,
('a', 'cat'): 1,
('cat', 'is'): 1,
('is', 'on'): 1,
('on', 'the'): 1,
('the', 'mat'): 1,
('a', 'cat', 'is'): 1,
('cat', 'is', 'on'): 1,
('is', 'on', 'the'): 1,
('on', 'the', 'mat'): 1,
('a', 'cat', 'is', 'on'): 1,
('cat', 'is', 'on', 'the'): 1,
('is', 'on', 'the', 'mat'): 1
})
assert result == counter


def test_bleu():
@@ -75,6 +31,13 @@ def test_bleu():
assert isinstance(bleu_results, dict)
np.testing.assert_almost_equal(bleu_results['bleu'], 0.4250477)

predictions = ['猫坐在垫子上', '公园旁边有棵树']
references = [['猫在那边的垫子'], ['一棵树长在公园旁边']]
metric = BLEU()
metric.add(predictions, references)
bleu_results = metric.compute()
np.testing.assert_almost_equal(bleu_results['bleu'], 0.2576968)


@pytest.mark.parametrize('n_gram', [1, 2, 3, 4])
def test_bleu_ngram(n_gram):
@@ -94,16 +57,3 @@ def test_bleu_ngram(n_gram):
bleu = BLEU(n_gram=n_gram)
bleu_results = bleu(predictions, references)
assert isinstance(bleu_results, dict)


@pytest.mark.parametrize('dataset', [
{
'predictions': ['猫坐在垫子上', '公园旁边有棵树'],
'references': [['猫在那边的垫子'], ['一棵树长在公园旁边']]
},
])
def test_input(dataset):
metric = BLEU()
metric.add(dataset['predictions'], dataset['references'])
result = metric.compute()
print(result)
8 changes: 5 additions & 3 deletions tests/test_metrics/test_rouge.py
@@ -10,11 +10,13 @@
'references': [['猫在那边的垫子'], ['一棵树长在公园旁边']]
},
])
def test_input(dataset):
def test_chinese(dataset):
metric = ROUGE()
metric.add(dataset['predictions'], dataset['references'])
result = metric.compute()
print(result)
results = metric.compute()
print(results)
np.testing.assert_almost_equal(results['rouge2_fmeasure'], 0.3766233)
np.testing.assert_almost_equal(results['rougeL_fmeasure'], 0.5576923)


@pytest.mark.parametrize('rouge_keys', [2, 'L', [2, 'L']])
47 changes: 47 additions & 0 deletions tests/test_metrics/test_utils/test_ngram_process.py
@@ -0,0 +1,47 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
from collections import Counter

from mmeval.metrics.bleu import get_n_gram


@pytest.mark.parametrize('n_gram', [2, 4])
def test_get_n_gram(n_gram):
token = ['a', 'cat', 'is', 'on', 'the', 'mat']
result = get_n_gram(token, n_gram)
if n_gram == 2:
counter = Counter({
('a', ): 1,
('cat', ): 1,
('is', ): 1,
('on', ): 1,
('the', ): 1,
('mat', ): 1,
('a', 'cat'): 1,
('cat', 'is'): 1,
('is', 'on'): 1,
('on', 'the'): 1,
('the', 'mat'): 1
})
else:
counter = Counter({
('a', ): 1,
('cat', ): 1,
('is', ): 1,
('on', ): 1,
('the', ): 1,
('mat', ): 1,
('a', 'cat'): 1,
('cat', 'is'): 1,
('is', 'on'): 1,
('on', 'the'): 1,
('the', 'mat'): 1,
('a', 'cat', 'is'): 1,
('cat', 'is', 'on'): 1,
('is', 'on', 'the'): 1,
('on', 'the', 'mat'): 1,
('a', 'cat', 'is', 'on'): 1,
('cat', 'is', 'on', 'the'): 1,
('is', 'on', 'the', 'mat'): 1
})
assert result == counter
