metrics: add BLEU (Lightning-AI#2535)
* metrics: added bleu score and test bleu
* metrics: fixed type hints in bleu
* bleu score moved to metrics/functional/nlp.py
* refactor with torch.Tensor
* Update test_sequence.py
* refactor as Borda requests and nltk==3.2
* locked nltk==3.3
* nltk>=3.3, parametrized smooth argument for test
* fix bleu_score example
* added class BLEUScore metrics and test
* update CHANGELOG
* refactor with torchtext
* torchtext changed to optional import
* fix E501 line too long
* add else: in optional import
* remove pragma: no-cover
* constants changed to CAPITALS
* remove class in tests
* List -> Sequence, conda -> pip, cast with tensor
* add torchtext in test.txt
* remove torchtext from test.txt
* bump torchtext to 0.5.0
* Apply suggestions from code review
* ignore bleu score in doctest, renamed to nlp.py
* back to implementation with torch
* remove --ignore in CI test, proper reference format
* apply justus comment

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Showing 11 changed files with 287 additions and 36 deletions.
In the conda environment spec, `nltk>=3.3` is added as a dependency:

```diff
@@ -30,6 +30,7 @@ dependencies:
   - twine==1.13.0
   - pillow<7.0.0
   - scikit-image
+  - nltk>=3.3

   # Optional
   - scipy>=0.13.3
```
In the functional metrics `__init__`, `rmsle` gains a trailing comma and the new `bleu_score` is re-exported:

```diff
@@ -25,5 +25,6 @@
     mse,
     psnr,
     rmse,
-    rmsle
+    rmsle,
 )
+from pytorch_lightning.metrics.functional.nlp import bleu_score
```
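With this re-export, `bleu_score` should also be importable straight from the functional namespace. A minimal check, reusing the docstring example from the new module below:

```python
# quick sanity check of the re-export added above
from pytorch_lightning.metrics.functional import bleu_score

translate_corpus = ['the cat is on the mat'.split()]
reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
print(bleu_score(translate_corpus, reference_corpus))  # tensor(0.7598), per the docstring example
```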
New file `pytorch_lightning/metrics/functional/nlp.py` (`@@ -0,0 +1,92 @@`; the path matches the import added above):

```python
# referenced from
# Library Name: torchtext
# Authors: torchtext authors and @sluks
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from typing import Sequence, List
from collections import Counter

import torch


def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
    """Count how many times each n-gram (for every order from 1 to n_gram) appears in a list of tokens.

    Args:
        ngram_input_list: A list of tokens from the translated or reference text
        n_gram: gram value ranging from 1 to 4

    Return:
        ngram_counter: a collections.Counter of n-gram counts, keyed by token tuples
    """
    ngram_counter = Counter()

    for i in range(1, n_gram + 1):
        for j in range(len(ngram_input_list) - i + 1):
            ngram_key = tuple(ngram_input_list[j : i + j])
            ngram_counter[ngram_key] += 1

    return ngram_counter


def bleu_score(
    translate_corpus: Sequence[str], reference_corpus: Sequence[str], n_gram: int = 4, smooth: bool = False
) -> torch.Tensor:
    """Calculate BLEU score of machine-translated text with one or more references.

    Args:
        translate_corpus: An iterable of machine-translated corpus
        reference_corpus: An iterable of iterables of reference corpus
        n_gram: Gram value ranging from 1 to 4 (default 4)
        smooth: Whether to apply smoothing (Lin et al. 2004)

    Return:
        A Tensor with the BLEU score

    Example:
        >>> translate_corpus = ['the cat is on the mat'.split()]
        >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
        >>> bleu_score(translate_corpus, reference_corpus)
        tensor(0.7598)
    """
    assert len(translate_corpus) == len(reference_corpus)
    numerator = torch.zeros(n_gram)
    denominator = torch.zeros(n_gram)
    precision_scores = torch.zeros(n_gram)
    c = 0.0
    r = 0.0
    for (translation, references) in zip(translate_corpus, reference_corpus):
        c += len(translation)
        # pick the reference whose length is closest to the translation's
        ref_len_list = [len(ref) for ref in references]
        ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
        r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
        translation_counter = _count_ngram(translation, n_gram)
        reference_counter = Counter()
        for ref in references:
            reference_counter |= _count_ngram(ref, n_gram)

        # clip each translation n-gram count to the maximum seen in any reference
        ngram_counter_clip = translation_counter & reference_counter
        for counter_clip in ngram_counter_clip:
            numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]

        for counter in translation_counter:
            denominator[len(counter) - 1] += translation_counter[counter]

    trans_len = torch.tensor(c)
    ref_len = torch.tensor(r)
    if min(numerator) == 0.0:
        return torch.tensor(0.0)

    if smooth:
        precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram))
    else:
        precision_scores = numerator / denominator
    log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores)
    geometric_mean = torch.exp(torch.sum(log_precision_scores))
    brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len))
    bleu = brevity_penalty * geometric_mean

    return bleu
```
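The implementation follows the standard corpus-BLEU recipe: clipped n-gram precisions, a uniformly weighted geometric mean in log space, and a brevity penalty when the translation is shorter than its closest-length reference. The clipping relies on `Counter & Counter` keeping the per-key minimum. A minimal sketch of that step, using `_count_ngram` from the file above on made-up sentences:

```python
from pytorch_lightning.metrics.functional.nlp import _count_ngram

# toy pair, purely for illustration
translation = "the the the cat".split()
reference = "the cat sat".split()

trans_counts = _count_ngram(translation, n_gram=2)  # unigrams and bigrams
ref_counts = _count_ngram(reference, n_gram=2)

# Counter & Counter keeps the minimum count per key, which is exactly the
# "clipped" count BLEU uses: 'the' occurs 3x in the translation but only
# 1x in the reference, so its clipped count is 1.
clipped = trans_counts & ref_counts
print(trans_counts[("the",)])  # 3
print(clipped[("the",)])       # 1
```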
New file with the `BLEUScore` metric class wrapping the functional implementation (`@@ -0,0 +1,46 @@`):

```python
import torch

from pytorch_lightning.metrics.functional.nlp import bleu_score
from pytorch_lightning.metrics.metric import Metric


class BLEUScore(Metric):
    """
    Calculate BLEU score of machine-translated text with one or more references.

    Example:
        >>> translate_corpus = ['the cat is on the mat'.split()]
        >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
        >>> metric = BLEUScore()
        >>> metric(translate_corpus, reference_corpus)
        tensor(0.7598)
    """

    def __init__(self, n_gram: int = 4, smooth: bool = False):
        """
        Args:
            n_gram: Gram value ranging from 1 to 4 (default 4)
            smooth: Whether to apply smoothing (Lin et al. 2004)
        """
        super().__init__(name="bleu")
        self.n_gram = n_gram
        self.smooth = smooth

    def forward(self, translate_corpus: list, reference_corpus: list) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            translate_corpus: An iterable of machine-translated corpus
            reference_corpus: An iterable of iterables of reference corpus

        Return:
            torch.Tensor: BLEU score
        """
        return bleu_score(
            translate_corpus=translate_corpus,
            reference_corpus=reference_corpus,
            n_gram=self.n_gram,
            smooth=self.smooth,
        ).to(self.device, self.dtype)
```
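A hedged usage sketch of the new module, mirroring the functional example; the import path is assumed from the commit message ("added class BLEUScore") and is not shown in this diff:

```python
from pytorch_lightning.metrics.nlp import BLEUScore  # module path assumed, not shown in this diff

metric = BLEUScore(n_gram=4, smooth=False)
translate_corpus = ['the cat is on the mat'.split()]
reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
score = metric(translate_corpus, reference_corpus)
print(score)  # tensor(0.7598), per the docstring example above
```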
In the extra requirements, the `torchtext` pin is visibly unchanged; since the hunk counts match (`-11,4 +11,4`), this is most likely a whitespace-only change such as adding a trailing newline:

```diff
@@ -11,4 +11,4 @@ horovod>=0.19.1
 omegaconf>=2.0.0
 # scipy>=0.13.3
 scikit-learn>=0.20.0
-torchtext>=0.3.1
+torchtext>=0.3.1
```
In the test requirements, `nltk>=3.3` is added (the tests below compare against NLTK):

```diff
@@ -12,3 +12,4 @@ black==19.10b0
 pre-commit>=1.0

 cloudpickle>=1.2
+nltk>=3.3
```
New test module (`@@ -0,0 +1,66 @@`) validating the implementation against NLTK's `sentence_bleu` and `corpus_bleu`:

```python
import pytest
import torch
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu

from pytorch_lightning.metrics.functional.nlp import bleu_score

# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.sentence_bleu
HYPOTHESIS1 = tuple(
    "It is a guide to action which ensures that the military always obeys the commands of the party".split()
)
REFERENCE1 = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split())
REFERENCE2 = tuple(
    "It is a guiding principle which makes the military forces always being under the command of the Party".split()
)
REFERENCE3 = tuple("It is the practical guide for the army always to heed the directions of the party".split())


# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split()
HYP2 = "he read the book because he was interested in world history".split()

REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split()
REF1B = "It is a guiding principle which makes the military force always being under the command of the Party".split()
REF1C = "It is the practical guide for the army always to heed the directions of the party".split()
REF2A = "he was interested in world history because he read the book".split()

LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]]
HYPOTHESES = [HYP1, HYP2]

# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction
smooth_func = SmoothingFunction().method2


@pytest.mark.parametrize(
    ["weights", "n_gram", "smooth_func", "smooth"],
    [
        pytest.param([1], 1, None, False),
        pytest.param([0.5, 0.5], 2, smooth_func, True),
        pytest.param([0.333333, 0.333333, 0.333333], 3, None, False),
        pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True),
    ],
)
def test_bleu_score(weights, n_gram, smooth_func, smooth):
    nltk_output = sentence_bleu(
        [REFERENCE1, REFERENCE2, REFERENCE3], HYPOTHESIS1, weights=weights, smoothing_function=smooth_func
    )
    pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth)
    assert torch.allclose(pl_output, torch.tensor(nltk_output))

    nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
    pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth)
    assert torch.allclose(pl_output, torch.tensor(nltk_output))


def test_bleu_empty():
    hyp = [[]]
    ref = [[[]]]
    assert bleu_score(hyp, ref) == torch.tensor(0.0)


def test_no_4_gram():
    hyps = [["My", "full", "pytorch-lightning"]]
    refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
    assert bleu_score(hyps, refs) == torch.tensor(0.0)
```
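`test_no_4_gram` exercises the zero short-circuit in `bleu_score`: a three-token hypothesis contains no 4-grams, so the 4-gram numerator stays zero and the function returns `tensor(0.)` rather than taking `log(0)` in the geometric mean. A small sketch of the same edge case, assuming the function from this commit; the `n_gram=3` value is my own arithmetic, not taken from the tests:

```python
import torch
from pytorch_lightning.metrics.functional.nlp import bleu_score

hyps = [["My", "full", "pytorch-lightning"]]
refs = [[["My", "full", "pytorch-lightning", "test"]]]

# a 3-token hypothesis has no 4-grams -> min(numerator) == 0 -> BLEU = 0
assert bleu_score(hyps, refs, n_gram=4) == torch.tensor(0.0)

# with n_gram=3 every 1-/2-/3-gram matches the reference exactly, so only the
# brevity penalty remains: exp(1 - 4/3), roughly 0.7165
print(bleu_score(hyps, refs, n_gram=3))  # tensor(0.7165)
```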