From ea60786638b330645a39f1d7ab95c78096e1a6fe Mon Sep 17 00:00:00 2001 From: Hiroyuki Deguchi Date: Tue, 2 Jul 2024 17:33:40 +0900 Subject: [PATCH 1/2] Support UnifiedMetric to calculate multiple scores in a single model --- mbrs/decoders/rerank.py | 2 +- mbrs/metrics/cometqe.py | 14 +++++----- mbrs/metrics/xcomet.py | 53 ++++++++++++++++--------------------- mbrs/metrics/xcomet_test.py | 18 +------------ mbrs/utils.py | 21 +++++++++++++++ mbrs/utils_test.py | 19 +++++++++++++ 6 files changed, 72 insertions(+), 55 deletions(-) create mode 100644 mbrs/utils.py create mode 100644 mbrs/utils_test.py diff --git a/mbrs/decoders/rerank.py b/mbrs/decoders/rerank.py index 6799c79..3d99191 100644 --- a/mbrs/decoders/rerank.py +++ b/mbrs/decoders/rerank.py @@ -27,7 +27,7 @@ def decode( DecoderRerank.Output: The n-best hypotheses. """ with timer.measure("rerank"): - scores = self.metric.scores(hypotheses, source) + scores = self.metric.scores(hypotheses, source=source) topk_scores, topk_indices = self.metric.topk(scores, k=nbest) return self.Output( idx=topk_indices, diff --git a/mbrs/metrics/cometqe.py b/mbrs/metrics/cometqe.py index 3db364f..b7ac239 100644 --- a/mbrs/metrics/cometqe.py +++ b/mbrs/metrics/cometqe.py @@ -4,7 +4,8 @@ import torch from comet import download_model, load_from_checkpoint -from transformers import BatchEncoding + +from mbrs import utils from . import MetricReferenceless, register @@ -38,11 +39,11 @@ def __init__(self, cfg: MetricCOMETQE.Config): param.requires_grad = False if not cfg.cpu and torch.cuda.is_available(): - self.scorer = self.scorer.cuda() if cfg.fp16: self.scorer = self.scorer.half() elif cfg.bf16: self.scorer = self.scorer.bfloat16() + self.scorer = self.scorer.cuda() @property def device(self) -> torch.device: @@ -74,9 +75,8 @@ def scores(self, hypotheses: list[str], source: str) -> torch.Tensor: data = [{"src": source, "mt": hyp} for hyp in hypotheses] scores = [] for i in range(0, len(data), self.cfg.batch_size): - batch = BatchEncoding( - self.scorer.prepare_for_inference(data[i : i + self.cfg.batch_size])[0] - ).to(self.scorer.device) - model_output = self.scorer.predict_step((batch,)) + batch = self.scorer.prepare_for_inference(data[i : i + self.cfg.batch_size]) + batch = utils.to_device(batch, self.device) + model_output = self.scorer.predict_step(batch) scores.append(model_output.scores) - return torch.cat(scores) + return torch.cat(scores).view(len(hypotheses)) diff --git a/mbrs/metrics/xcomet.py b/mbrs/metrics/xcomet.py index c022bb1..ccd529d 100644 --- a/mbrs/metrics/xcomet.py +++ b/mbrs/metrics/xcomet.py @@ -1,36 +1,18 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Optional +from typing import Optional import torch from comet import download_model, load_from_checkpoint from comet.models import XCOMETMetric from torch import Tensor -from mbrs import timer +from mbrs import timer, utils from . import Metric, register -def to_device(sample: Any, device: torch.device): - def _to_device(x): - if torch.is_tensor(x): - return x.to(device=device, non_blocking=True) - elif isinstance(x, dict): - return {key: _to_device(value) for key, value in x.items()} - elif isinstance(x, list): - return [_to_device(x) for x in x] - elif isinstance(x, tuple): - return tuple(_to_device(x) for x in x) - elif isinstance(x, set): - return {_to_device(x) for x in x} - else: - return x - - return _to_device(sample) - - @register("xcomet") class MetricXCOMET(Metric): """XCOMET metric class.""" @@ -74,41 +56,52 @@ def device(self) -> torch.device: return self.scorer.device def score( - self, hypothesis: str, reference: str, source: Optional[str] = None + self, + hypothesis: str, + reference: Optional[str] = None, + source: Optional[str] = None, ) -> float: """Calculate the score of the given hypothesis. Args: hypothesis (str): A hypothesis. - reference (str): A reference. + reference (str, optional): A reference. source (str, optional): A source. Returns: float: The score of the given hypothesis. """ - inputs = {"mt": hypothesis, "ref": reference} + inputs = {"mt": hypothesis} + if reference is not None: + inputs["ref"] = reference if source is not None: inputs["src"] = source batch = self.scorer.prepare_for_inference([inputs]) - batch = to_device(batch, self.device) + batch = utils.to_device(batch, self.device) model_output = self.scorer.predict_step(batch) return model_output.scores.item() def scores( - self, hypotheses: list[str], references: list[str], source: Optional[str] = None + self, + hypotheses: list[str], + references: Optional[list[str]] = None, + source: Optional[str] = None, ) -> Tensor: """Calculate the scores of the given hypothesis. Args: - hypotheses (str): N hypotheses. - references (str): N references. + hypotheses (list[str]): N hypotheses. + references (list[str], optional): N references. source (str, optional): A source. Returns: Tensor: The N scores of the given hypotheses. """ - inputs = [{"mt": hyp, "ref": ref} for hyp, ref in zip(hypotheses, references)] + inputs = [{"mt": hyp} for hyp in hypotheses] + if references is not None: + for d, ref in zip(inputs, references): + d["ref"] = ref if source is not None: for d in inputs: d["src"] = source @@ -120,7 +113,7 @@ def scores( batch = self.scorer.prepare_for_inference( inputs[i : i + self.cfg.batch_size] ) - batch = to_device(batch, self.device) + batch = utils.to_device(batch, self.device) model_output = self.scorer.predict_step(batch) scores.append(model_output.scores) return torch.cat(scores).view(len(hypotheses)) @@ -151,7 +144,7 @@ def pairwise_scores( batch = self.scorer.prepare_for_inference( data[i : i + self.cfg.batch_size] ) - batch = to_device(batch, self.device) + batch = utils.to_device(batch, self.device) model_output = self.scorer.predict_step(batch) scores.append(model_output.scores) return torch.cat(scores).view(len(hypotheses), len(references)) diff --git a/mbrs/metrics/xcomet_test.py b/mbrs/metrics/xcomet_test.py index 5e736dd..254b38a 100644 --- a/mbrs/metrics/xcomet_test.py +++ b/mbrs/metrics/xcomet_test.py @@ -1,17 +1 @@ -import pytest -import torch - -from .xcomet import to_device - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available on this machine.") -def test_to_device(): - device = torch.device("cuda:0") - - for x in [1, 1.0, "a", True]: - assert to_device(x, device) == x - - assert to_device(torch.ones(1), device).device == device - assert to_device({"a": torch.ones(1)}, device)["a"].device == device - assert to_device([torch.ones(1)], device)[0].device == device - assert to_device((torch.ones(1),), device)[0].device == device +# TODO(deguchi): Add unit tests for XCOMET diff --git a/mbrs/utils.py b/mbrs/utils.py new file mode 100644 index 0000000..ddf6ec5 --- /dev/null +++ b/mbrs/utils.py @@ -0,0 +1,21 @@ +from typing import Any + +import torch + + +def to_device(sample: Any, device: torch.device): + def _to_device(x): + if torch.is_tensor(x): + return x.to(device=device, non_blocking=True) + elif isinstance(x, dict): + return {key: _to_device(value) for key, value in x.items()} + elif isinstance(x, list): + return [_to_device(x) for x in x] + elif isinstance(x, tuple): + return tuple(_to_device(x) for x in x) + elif isinstance(x, set): + return {_to_device(x) for x in x} + else: + return x + + return _to_device(sample) diff --git a/mbrs/utils_test.py b/mbrs/utils_test.py new file mode 100644 index 0000000..821e27c --- /dev/null +++ b/mbrs/utils_test.py @@ -0,0 +1,19 @@ +import pytest +import torch + +from . import utils + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA is not available on this machine." +) +def test_to_device(): + device = torch.device("cuda:0") + + for x in [1, 1.0, "a", True]: + assert utils.to_device(x, device) == x + + assert utils.to_device(torch.ones(1), device).device == device + assert utils.to_device({"a": torch.ones(1)}, device)["a"].device == device + assert utils.to_device([torch.ones(1)], device)[0].device == device + assert utils.to_device((torch.ones(1),), device)[0].device == device From 930991936be0af4fd2ffccea1e924f032609b4cf Mon Sep 17 00:00:00 2001 From: Hiroyuki Deguchi Date: Mon, 8 Jul 2024 17:48:24 +0900 Subject: [PATCH 2/2] Add unit tests for XCOMET --- mbrs/conftest.py | 11 +++++- mbrs/metrics/xcomet_test.py | 73 ++++++++++++++++++++++++++++++++++++- 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/mbrs/conftest.py b/mbrs/conftest.py index e013264..f6132a6 100644 --- a/mbrs/conftest.py +++ b/mbrs/conftest.py @@ -1,6 +1,7 @@ import pytest +import torch -from mbrs.metrics import MetricCOMET, MetricCOMETQE +from mbrs.metrics import MetricCOMET, MetricCOMETQE, MetricXCOMET @pytest.fixture(scope="session") @@ -11,3 +12,11 @@ def metric_comet(): @pytest.fixture(scope="session") def metric_cometqe(): return MetricCOMETQE(MetricCOMETQE.Config()) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA is not available on this machine." +) +@pytest.fixture(scope="session") +def metric_xcomet(): + return MetricXCOMET(MetricXCOMET.Config()) diff --git a/mbrs/metrics/xcomet_test.py b/mbrs/metrics/xcomet_test.py index 254b38a..1d5b0d0 100644 --- a/mbrs/metrics/xcomet_test.py +++ b/mbrs/metrics/xcomet_test.py @@ -1 +1,72 @@ -# TODO(deguchi): Add unit tests for XCOMET +import pytest +import torch + +from .xcomet import MetricXCOMET + +SOURCE = "これはテストです" +HYPOTHESES = [ + "this is a test", + "another test", + "this is a fest", + "Producția de zahăr primă va fi exprimată în ceea ce privește zahărul alb;", +] +REFERENCES = [ + "ref", + "this is a test", + "producţia de zahăr brut se exprimă în zahăr alb;", +] +SCORES = torch.Tensor( + [ + [0.97671, 1.00000, 0.49054], + [0.94399, 0.99120, 0.43007], + [0.71786, 0.71210, 0.30775], + [0.21788, 0.22079, 0.61004], + ] +) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA is not available on this machine." +) +class TestMetricXCOMET: + def test_score(self, metric_xcomet: MetricXCOMET): + for i, hyp in enumerate(HYPOTHESES): + for j, ref in enumerate(REFERENCES): + assert torch.isclose( + SCORES[i, j], + torch.tensor(metric_xcomet.score(hyp, ref, SOURCE)), + atol=0.0005 / 100, + ) + + def test_scores(self, metric_xcomet: MetricXCOMET): + hyps = ["another test", "this is a test", "this is an test"] + refs = ["another test", "this is a fest", "this is a test"] + src = SOURCE + + torch.testing.assert_close( + metric_xcomet.scores(hyps, refs, src).cpu().float(), + torch.FloatTensor([1.00000, 0.90545, 1.00000]), + atol=0.0005 / 100, + rtol=1e-6, + ) + torch.testing.assert_close( + metric_xcomet.scores(hyps, source=src).cpu().float(), + torch.FloatTensor([0.99120, 0.99120, 0.99120]), + atol=0.0005 / 100, + rtol=1e-6, + ) + torch.testing.assert_close( + metric_xcomet.scores(hyps, references=refs).cpu().float(), + torch.FloatTensor([1.00000, 0.77420, 1.00000]), + atol=0.0005 / 100, + rtol=1e-6, + ) + + def test_expected_scores(self, metric_xcomet: MetricXCOMET): + expected_scores = metric_xcomet.expected_scores(HYPOTHESES, REFERENCES, SOURCE) + torch.testing.assert_close( + expected_scores, + SCORES.mean(dim=1).to(metric_xcomet.device), + atol=0.0005 / 100, + rtol=1e-6, + )