Skip to content

Commit

Permalink
Merge pull request #4 from naist-nlp/unified-metric
Browse files Browse the repository at this point in the history
Support multiple scoring methods of UnifiedMetric
  • Loading branch information
de9uch1 authored Jul 8, 2024
2 parents 4448dfc + 9309919 commit fe32c95
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 49 deletions.
11 changes: 10 additions & 1 deletion mbrs/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import torch

from mbrs.metrics import MetricCOMET, MetricCOMETQE
from mbrs.metrics import MetricCOMET, MetricCOMETQE, MetricXCOMET


@pytest.fixture(scope="session")
Expand All @@ -11,3 +12,11 @@ def metric_comet():
@pytest.fixture(scope="session")
def metric_cometqe():
return MetricCOMETQE(MetricCOMETQE.Config())


@pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA is not available on this machine."
)
@pytest.fixture(scope="session")
def metric_xcomet():
return MetricXCOMET(MetricXCOMET.Config())
2 changes: 1 addition & 1 deletion mbrs/decoders/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def decode(
DecoderRerank.Output: The n-best hypotheses.
"""
with timer.measure("rerank"):
scores = self.metric.scores(hypotheses, source)
scores = self.metric.scores(hypotheses, source=source)
topk_scores, topk_indices = self.metric.topk(scores, k=nbest)
return self.Output(
idx=topk_indices,
Expand Down
14 changes: 7 additions & 7 deletions mbrs/metrics/cometqe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import torch
from comet import download_model, load_from_checkpoint
from transformers import BatchEncoding

from mbrs import utils

from . import MetricReferenceless, register

Expand Down Expand Up @@ -38,11 +39,11 @@ def __init__(self, cfg: MetricCOMETQE.Config):
param.requires_grad = False

if not cfg.cpu and torch.cuda.is_available():
self.scorer = self.scorer.cuda()
if cfg.fp16:
self.scorer = self.scorer.half()
elif cfg.bf16:
self.scorer = self.scorer.bfloat16()
self.scorer = self.scorer.cuda()

@property
def device(self) -> torch.device:
Expand Down Expand Up @@ -74,9 +75,8 @@ def scores(self, hypotheses: list[str], source: str) -> torch.Tensor:
data = [{"src": source, "mt": hyp} for hyp in hypotheses]
scores = []
for i in range(0, len(data), self.cfg.batch_size):
batch = BatchEncoding(
self.scorer.prepare_for_inference(data[i : i + self.cfg.batch_size])[0]
).to(self.scorer.device)
model_output = self.scorer.predict_step((batch,))
batch = self.scorer.prepare_for_inference(data[i : i + self.cfg.batch_size])
batch = utils.to_device(batch, self.device)
model_output = self.scorer.predict_step(batch)
scores.append(model_output.scores)
return torch.cat(scores)
return torch.cat(scores).view(len(hypotheses))
53 changes: 23 additions & 30 deletions mbrs/metrics/xcomet.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,18 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Optional
from typing import Optional

import torch
from comet import download_model, load_from_checkpoint
from comet.models import XCOMETMetric
from torch import Tensor

from mbrs import timer
from mbrs import timer, utils

from . import Metric, register


def to_device(sample: Any, device: torch.device):
def _to_device(x):
if torch.is_tensor(x):
return x.to(device=device, non_blocking=True)
elif isinstance(x, dict):
return {key: _to_device(value) for key, value in x.items()}
elif isinstance(x, list):
return [_to_device(x) for x in x]
elif isinstance(x, tuple):
return tuple(_to_device(x) for x in x)
elif isinstance(x, set):
return {_to_device(x) for x in x}
else:
return x

return _to_device(sample)


@register("xcomet")
class MetricXCOMET(Metric):
"""XCOMET metric class."""
Expand Down Expand Up @@ -74,41 +56,52 @@ def device(self) -> torch.device:
return self.scorer.device

def score(
self, hypothesis: str, reference: str, source: Optional[str] = None
self,
hypothesis: str,
reference: Optional[str] = None,
source: Optional[str] = None,
) -> float:
"""Calculate the score of the given hypothesis.
Args:
hypothesis (str): A hypothesis.
reference (str): A reference.
reference (str, optional): A reference.
source (str, optional): A source.
Returns:
float: The score of the given hypothesis.
"""
inputs = {"mt": hypothesis, "ref": reference}
inputs = {"mt": hypothesis}
if reference is not None:
inputs["ref"] = reference
if source is not None:
inputs["src"] = source

batch = self.scorer.prepare_for_inference([inputs])
batch = to_device(batch, self.device)
batch = utils.to_device(batch, self.device)
model_output = self.scorer.predict_step(batch)
return model_output.scores.item()

def scores(
self, hypotheses: list[str], references: list[str], source: Optional[str] = None
self,
hypotheses: list[str],
references: Optional[list[str]] = None,
source: Optional[str] = None,
) -> Tensor:
"""Calculate the scores of the given hypothesis.
Args:
hypotheses (str): N hypotheses.
references (str): N references.
hypotheses (list[str]): N hypotheses.
references (list[str], optional): N references.
source (str, optional): A source.
Returns:
Tensor: The N scores of the given hypotheses.
"""
inputs = [{"mt": hyp, "ref": ref} for hyp, ref in zip(hypotheses, references)]
inputs = [{"mt": hyp} for hyp in hypotheses]
if references is not None:
for d, ref in zip(inputs, references):
d["ref"] = ref
if source is not None:
for d in inputs:
d["src"] = source
Expand All @@ -120,7 +113,7 @@ def scores(
batch = self.scorer.prepare_for_inference(
inputs[i : i + self.cfg.batch_size]
)
batch = to_device(batch, self.device)
batch = utils.to_device(batch, self.device)
model_output = self.scorer.predict_step(batch)
scores.append(model_output.scores)
return torch.cat(scores).view(len(hypotheses))
Expand Down Expand Up @@ -151,7 +144,7 @@ def pairwise_scores(
batch = self.scorer.prepare_for_inference(
data[i : i + self.cfg.batch_size]
)
batch = to_device(batch, self.device)
batch = utils.to_device(batch, self.device)
model_output = self.scorer.predict_step(batch)
scores.append(model_output.scores)
return torch.cat(scores).view(len(hypotheses), len(references))
75 changes: 65 additions & 10 deletions mbrs/metrics/xcomet_test.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,72 @@
import pytest
import torch

from .xcomet import to_device
from .xcomet import MetricXCOMET

SOURCE = "これはテストです"
HYPOTHESES = [
"this is a test",
"another test",
"this is a fest",
"Producția de zahăr primă va fi exprimată în ceea ce privește zahărul alb;",
]
REFERENCES = [
"ref",
"this is a test",
"producţia de zahăr brut se exprimă în zahăr alb;",
]
SCORES = torch.Tensor(
[
[0.97671, 1.00000, 0.49054],
[0.94399, 0.99120, 0.43007],
[0.71786, 0.71210, 0.30775],
[0.21788, 0.22079, 0.61004],
]
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available on this machine.")
def test_to_device():
device = torch.device("cuda:0")

for x in [1, 1.0, "a", True]:
assert to_device(x, device) == x
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA is not available on this machine."
)
class TestMetricXCOMET:
def test_score(self, metric_xcomet: MetricXCOMET):
for i, hyp in enumerate(HYPOTHESES):
for j, ref in enumerate(REFERENCES):
assert torch.isclose(
SCORES[i, j],
torch.tensor(metric_xcomet.score(hyp, ref, SOURCE)),
atol=0.0005 / 100,
)

assert to_device(torch.ones(1), device).device == device
assert to_device({"a": torch.ones(1)}, device)["a"].device == device
assert to_device([torch.ones(1)], device)[0].device == device
assert to_device((torch.ones(1),), device)[0].device == device
def test_scores(self, metric_xcomet: MetricXCOMET):
hyps = ["another test", "this is a test", "this is an test"]
refs = ["another test", "this is a fest", "this is a test"]
src = SOURCE

torch.testing.assert_close(
metric_xcomet.scores(hyps, refs, src).cpu().float(),
torch.FloatTensor([1.00000, 0.90545, 1.00000]),
atol=0.0005 / 100,
rtol=1e-6,
)
torch.testing.assert_close(
metric_xcomet.scores(hyps, source=src).cpu().float(),
torch.FloatTensor([0.99120, 0.99120, 0.99120]),
atol=0.0005 / 100,
rtol=1e-6,
)
torch.testing.assert_close(
metric_xcomet.scores(hyps, references=refs).cpu().float(),
torch.FloatTensor([1.00000, 0.77420, 1.00000]),
atol=0.0005 / 100,
rtol=1e-6,
)

def test_expected_scores(self, metric_xcomet: MetricXCOMET):
expected_scores = metric_xcomet.expected_scores(HYPOTHESES, REFERENCES, SOURCE)
torch.testing.assert_close(
expected_scores,
SCORES.mean(dim=1).to(metric_xcomet.device),
atol=0.0005 / 100,
rtol=1e-6,
)
21 changes: 21 additions & 0 deletions mbrs/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Any

import torch


def to_device(sample: Any, device: torch.device):
def _to_device(x):
if torch.is_tensor(x):
return x.to(device=device, non_blocking=True)
elif isinstance(x, dict):
return {key: _to_device(value) for key, value in x.items()}
elif isinstance(x, list):
return [_to_device(x) for x in x]
elif isinstance(x, tuple):
return tuple(_to_device(x) for x in x)
elif isinstance(x, set):
return {_to_device(x) for x in x}
else:
return x

return _to_device(sample)
19 changes: 19 additions & 0 deletions mbrs/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest
import torch

from . import utils


@pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA is not available on this machine."
)
def test_to_device():
device = torch.device("cuda:0")

for x in [1, 1.0, "a", True]:
assert utils.to_device(x, device) == x

assert utils.to_device(torch.ones(1), device).device == device
assert utils.to_device({"a": torch.ones(1)}, device)["a"].device == device
assert utils.to_device([torch.ones(1)], device)[0].device == device
assert utils.to_device((torch.ones(1),), device)[0].device == device

0 comments on commit fe32c95

Please sign in to comment.