Skip to content

Commit

Permalink
Merge branch 'master' into req/mecab-ko
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Apr 19, 2024
2 parents e1efdbc + d03ca5e commit dfc21e0
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
12 changes: 5 additions & 7 deletions src/torchmetrics/functional/text/chrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def _get_total_ngrams(n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]])
"""Get total sum of n-grams over n-grams w.r.t n."""
total_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0))
for n in n_grams_counts:
total_n_grams[n] = tensor(sum(n_grams_counts[n].values()))
total_n_grams[n] = sum(n_grams_counts[n].values()).detach().clone() # type: ignore
return total_n_grams

char_n_grams_counts, word_n_grams_counts = _char_and_word_ngrams_counts(
Expand Down Expand Up @@ -216,12 +216,10 @@ def _get_ngram_matches(
"""
matching_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0))
for n in hyp_n_grams_counts:
matching_n_grams[n] = tensor(
sum(
torch.min(ref_n_grams_counts[n][n_gram], hyp_n_grams_counts[n][n_gram])
for n_gram in hyp_n_grams_counts[n]
)
)
min_n_grams = [
torch.min(ref_n_grams_counts[n][n_gram], hyp_n_grams_counts[n][n_gram]) for n_gram in hyp_n_grams_counts[n]
]
matching_n_grams[n] = sum(min_n_grams).detach().clone() # type: ignore
return matching_n_grams


Expand Down
2 changes: 2 additions & 0 deletions tests/unittests/text/test_sacre_bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def test_bleu_score_class(self, ddp, preds, targets, tokenize, lowercase):
"""Test class implementation of metric."""
if _should_skip_tokenizer(tokenize):
pytest.skip(reason="`ko-mecab` tokenizer requires `mecab-ko` package to be installed")
if tokenize == "flores200":
pytest.skip("flores200 tests are flaky") # TODO: figure out why

metric_args = {"tokenize": tokenize, "lowercase": lowercase}
original_sacrebleu = partial(_reference_sacre_bleu, tokenize=tokenize, lowercase=lowercase)
Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/text/test_ter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class TestTER(TextTester):
"""Test class for `TranslationEditRate` metric."""

@pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
def test_chrf_score_class(self, ddp, preds, targets, normalize, no_punctuation, asian_support, lowercase):
def test_ter_class(self, ddp, preds, targets, normalize, no_punctuation, asian_support, lowercase):
"""Test class implementation of metric."""
metric_args = {
"normalize": normalize,
Expand Down Expand Up @@ -113,7 +113,7 @@ def test_ter_score_functional(self, preds, targets, normalize, no_punctuation, a
metric_args=metric_args,
)

def test_chrf_score_differentiability(self, preds, targets, normalize, no_punctuation, asian_support, lowercase):
def test_ter_differentiability(self, preds, targets, normalize, no_punctuation, asian_support, lowercase):
"""Test the differentiability of the metric, according to its `is_differentiable` attribute."""
metric_args = {
"normalize": normalize,
Expand Down

0 comments on commit dfc21e0

Please sign in to comment.