Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple scoring methods of UnifiedMetric #4

Merged
merged 2 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion mbrs/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import torch

from mbrs.metrics import MetricCOMET, MetricCOMETQE
from mbrs.metrics import MetricCOMET, MetricCOMETQE, MetricXCOMET


@pytest.fixture(scope="session")
Expand All @@ -11,3 +12,11 @@ def metric_comet():
@pytest.fixture(scope="session")
def metric_cometqe():
    """Provide one MetricCOMETQE, built from its default Config, per test session."""
    default_cfg = MetricCOMETQE.Config()
    return MetricCOMETQE(default_cfg)


@pytest.fixture(scope="session")
def metric_xcomet():
    """Provide one MetricXCOMET, built from its default Config, per test session.

    Tests that request this fixture are skipped when CUDA is not available.
    """
    # pytest ignores marks placed on fixture functions, so a skipif decorator
    # here would never fire; the CUDA guard has to run inside the fixture body.
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available on this machine.")
    return MetricXCOMET(MetricXCOMET.Config())
2 changes: 1 addition & 1 deletion mbrs/decoders/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def decode(
DecoderRerank.Output: The n-best hypotheses.
"""
with timer.measure("rerank"):
scores = self.metric.scores(hypotheses, source)
scores = self.metric.scores(hypotheses, source=source)
topk_scores, topk_indices = self.metric.topk(scores, k=nbest)
return self.Output(
idx=topk_indices,
Expand Down
14 changes: 7 additions & 7 deletions mbrs/metrics/cometqe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import torch
from comet import download_model, load_from_checkpoint
from transformers import BatchEncoding

from mbrs import utils

from . import MetricReferenceless, register

Expand Down Expand Up @@ -38,11 +39,11 @@ def __init__(self, cfg: MetricCOMETQE.Config):
param.requires_grad = False

if not cfg.cpu and torch.cuda.is_available():
self.scorer = self.scorer.cuda()
if cfg.fp16:
self.scorer = self.scorer.half()
elif cfg.bf16:
self.scorer = self.scorer.bfloat16()
self.scorer = self.scorer.cuda()

@property
def device(self) -> torch.device:
Expand Down Expand Up @@ -74,9 +75,8 @@ def scores(self, hypotheses: list[str], source: str) -> torch.Tensor:
data = [{"src": source, "mt": hyp} for hyp in hypotheses]
scores = []
for i in range(0, len(data), self.cfg.batch_size):
batch = BatchEncoding(
self.scorer.prepare_for_inference(data[i : i + self.cfg.batch_size])[0]
).to(self.scorer.device)
model_output = self.scorer.predict_step((batch,))
batch = self.scorer.prepare_for_inference(data[i : i + self.cfg.batch_size])
batch = utils.to_device(batch, self.device)
model_output = self.scorer.predict_step(batch)
scores.append(model_output.scores)
return torch.cat(scores)
return torch.cat(scores).view(len(hypotheses))
53 changes: 23 additions & 30 deletions mbrs/metrics/xcomet.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,18 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Optional
from typing import Optional

import torch
from comet import download_model, load_from_checkpoint
from comet.models import XCOMETMetric
from torch import Tensor

from mbrs import timer
from mbrs import timer, utils

from . import Metric, register


def to_device(sample: Any, device: torch.device):
def _to_device(x):
if torch.is_tensor(x):
return x.to(device=device, non_blocking=True)
elif isinstance(x, dict):
return {key: _to_device(value) for key, value in x.items()}
elif isinstance(x, list):
return [_to_device(x) for x in x]
elif isinstance(x, tuple):
return tuple(_to_device(x) for x in x)
elif isinstance(x, set):
return {_to_device(x) for x in x}
else:
return x

return _to_device(sample)


@register("xcomet")
class MetricXCOMET(Metric):
"""XCOMET metric class."""
Expand Down Expand Up @@ -74,41 +56,52 @@ def device(self) -> torch.device:
return self.scorer.device

def score(
self, hypothesis: str, reference: str, source: Optional[str] = None
self,
hypothesis: str,
reference: Optional[str] = None,
source: Optional[str] = None,
) -> float:
"""Calculate the score of the given hypothesis.

Args:
hypothesis (str): A hypothesis.
reference (str): A reference.
reference (str, optional): A reference.
source (str, optional): A source.

Returns:
float: The score of the given hypothesis.
"""
inputs = {"mt": hypothesis, "ref": reference}
inputs = {"mt": hypothesis}
if reference is not None:
inputs["ref"] = reference
if source is not None:
inputs["src"] = source

batch = self.scorer.prepare_for_inference([inputs])
batch = to_device(batch, self.device)
batch = utils.to_device(batch, self.device)
model_output = self.scorer.predict_step(batch)
return model_output.scores.item()

def scores(
self, hypotheses: list[str], references: list[str], source: Optional[str] = None
self,
hypotheses: list[str],
references: Optional[list[str]] = None,
source: Optional[str] = None,
) -> Tensor:
"""Calculate the scores of the given hypothesis.

Args:
hypotheses (str): N hypotheses.
references (str): N references.
hypotheses (list[str]): N hypotheses.
references (list[str], optional): N references.
source (str, optional): A source.

Returns:
Tensor: The N scores of the given hypotheses.
"""
inputs = [{"mt": hyp, "ref": ref} for hyp, ref in zip(hypotheses, references)]
inputs = [{"mt": hyp} for hyp in hypotheses]
if references is not None:
for d, ref in zip(inputs, references):
d["ref"] = ref
if source is not None:
for d in inputs:
d["src"] = source
Expand All @@ -120,7 +113,7 @@ def scores(
batch = self.scorer.prepare_for_inference(
inputs[i : i + self.cfg.batch_size]
)
batch = to_device(batch, self.device)
batch = utils.to_device(batch, self.device)
model_output = self.scorer.predict_step(batch)
scores.append(model_output.scores)
return torch.cat(scores).view(len(hypotheses))
Expand Down Expand Up @@ -151,7 +144,7 @@ def pairwise_scores(
batch = self.scorer.prepare_for_inference(
data[i : i + self.cfg.batch_size]
)
batch = to_device(batch, self.device)
batch = utils.to_device(batch, self.device)
model_output = self.scorer.predict_step(batch)
scores.append(model_output.scores)
return torch.cat(scores).view(len(hypotheses), len(references))
75 changes: 65 additions & 10 deletions mbrs/metrics/xcomet_test.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,72 @@
import pytest
import torch

from .xcomet import to_device
from .xcomet import MetricXCOMET

SOURCE = "これはテストです"
HYPOTHESES = [
"this is a test",
"another test",
"this is a fest",
"Producția de zahăr primă va fi exprimată în ceea ce privește zahărul alb;",
]
REFERENCES = [
"ref",
"this is a test",
"producţia de zahăr brut se exprimă în zahăr alb;",
]
SCORES = torch.Tensor(
[
[0.97671, 1.00000, 0.49054],
[0.94399, 0.99120, 0.43007],
[0.71786, 0.71210, 0.30775],
[0.21788, 0.22079, 0.61004],
]
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available on this machine.")
def test_to_device():
device = torch.device("cuda:0")

for x in [1, 1.0, "a", True]:
assert to_device(x, device) == x
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA is not available on this machine."
)
class TestMetricXCOMET:
def test_score(self, metric_xcomet: MetricXCOMET):
for i, hyp in enumerate(HYPOTHESES):
for j, ref in enumerate(REFERENCES):
assert torch.isclose(
SCORES[i, j],
torch.tensor(metric_xcomet.score(hyp, ref, SOURCE)),
atol=0.0005 / 100,
)

assert to_device(torch.ones(1), device).device == device
assert to_device({"a": torch.ones(1)}, device)["a"].device == device
assert to_device([torch.ones(1)], device)[0].device == device
assert to_device((torch.ones(1),), device)[0].device == device
def test_scores(self, metric_xcomet: MetricXCOMET):
hyps = ["another test", "this is a test", "this is an test"]
refs = ["another test", "this is a fest", "this is a test"]
src = SOURCE

torch.testing.assert_close(
metric_xcomet.scores(hyps, refs, src).cpu().float(),
torch.FloatTensor([1.00000, 0.90545, 1.00000]),
atol=0.0005 / 100,
rtol=1e-6,
)
torch.testing.assert_close(
metric_xcomet.scores(hyps, source=src).cpu().float(),
torch.FloatTensor([0.99120, 0.99120, 0.99120]),
atol=0.0005 / 100,
rtol=1e-6,
)
torch.testing.assert_close(
metric_xcomet.scores(hyps, references=refs).cpu().float(),
torch.FloatTensor([1.00000, 0.77420, 1.00000]),
atol=0.0005 / 100,
rtol=1e-6,
)

def test_expected_scores(self, metric_xcomet: MetricXCOMET):
expected_scores = metric_xcomet.expected_scores(HYPOTHESES, REFERENCES, SOURCE)
torch.testing.assert_close(
expected_scores,
SCORES.mean(dim=1).to(metric_xcomet.device),
atol=0.0005 / 100,
rtol=1e-6,
)
21 changes: 21 additions & 0 deletions mbrs/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Any

import torch


def to_device(sample: Any, device: torch.device) -> Any:
    """Recursively move every tensor found in ``sample`` onto ``device``.

    Dicts, lists, tuples (including namedtuples) and sets are rebuilt with
    their elements converted; any other non-tensor object is returned as is.

    Args:
        sample (Any): A tensor, a possibly nested container of tensors, or
          any other object.
        device (torch.device): The destination device.

    Returns:
        Any: A structure mirroring ``sample`` whose tensors are on ``device``.
    """

    def _to_device(x):
        if torch.is_tensor(x):
            return x.to(device=device, non_blocking=True)
        elif isinstance(x, dict):
            return {key: _to_device(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_to_device(item) for item in x]
        elif isinstance(x, tuple):
            converted = (_to_device(item) for item in x)
            # Namedtuples take their fields as positional arguments; rebuild
            # them with their own type so the field names are not lost.
            if hasattr(x, "_fields"):
                return type(x)(*converted)
            return tuple(converted)
        elif isinstance(x, set):
            return {_to_device(item) for item in x}
        else:
            return x

    return _to_device(sample)
19 changes: 19 additions & 0 deletions mbrs/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest
import torch

from . import utils


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA is not available on this machine."
)
def test_to_device():
    """Check utils.to_device on plain values, a bare tensor and simple containers."""
    target = torch.device("cuda:0")

    # Non-tensor values must come back equal to what went in.
    for plain_value in (1, 1.0, "a", True):
        assert utils.to_device(plain_value, target) == plain_value

    # A bare tensor ends up on the requested device.
    moved = utils.to_device(torch.ones(1), target)
    assert moved.device == target

    # The same holds for a tensor held in a dict, a list or a tuple.
    containers_and_keys = (
        ({"a": torch.ones(1)}, "a"),
        ([torch.ones(1)], 0),
        ((torch.ones(1),), 0),
    )
    for container, key in containers_and_keys:
        assert utils.to_device(container, target)[key].device == target