Commit 4448dfc
Merge pull request #3 from naist-nlp/pmbr

Implement Probabilistic MBR (Trabelsi et al., 2024)

de9uch1 authored Jul 8, 2024
2 parents c08b72c + 1b361d5

Showing 10 changed files with 535 additions and 12 deletions.
README.rst (19 changes: 10 additions & 9 deletions)
@@ -102,12 +102,12 @@ List of implemented methods

Currently, the following metrics are supported:

-- BLEU `(Papineni et al., 2002) <https://aclanthology.org/P02-1040>`_
-- TER `(Snover et al., 2006) <https://aclanthology.org/2006.amta-papers.25>`_
-- chrF `(Popović et al., 2015) <https://aclanthology.org/W15-3049>`_
-- COMET `(Rei et al., 2020) <https://aclanthology.org/2020.emnlp-main.213>`_
-- COMETkiwi `(Rei et al., 2022) <https://aclanthology.org/2022.wmt-1.60>`_
-- XCOMET `(Guerreiro et al., 2023) <https://arxiv.org/abs/2310.10482>`_
+- BLEU: :code:`bleu` `(Papineni et al., 2002) <https://aclanthology.org/P02-1040>`_
+- TER: :code:`ter` `(Snover et al., 2006) <https://aclanthology.org/2006.amta-papers.25>`_
+- chrF: :code:`chrf` `(Popović et al., 2015) <https://aclanthology.org/W15-3049>`_
+- COMET: :code:`comet` `(Rei et al., 2020) <https://aclanthology.org/2020.emnlp-main.213>`_
+- COMETkiwi: :code:`cometqe` `(Rei et al., 2022) <https://aclanthology.org/2022.wmt-1.60>`_
+- XCOMET: :code:`xcomet` `(Guerreiro et al., 2023) <https://arxiv.org/abs/2310.10482>`_

The following decoding methods are implemented:

@@ -116,9 +116,10 @@ The following decoding methods are implemented:

Specifically, the following methods of MBR decoding are included:

-- Monte Carlo estimation (`Eikema and Aziz, 2020 <https://aclanthology.org/2020.coling-main.398>`_; `Eikema and Aziz, 2022 <https://aclanthology.org/2022.emnlp-main.754>`_)
-- Confidence-based pruning `(Cheng and Vlachos, 2023) <https://aclanthology.org/2023.emnlp-main.767>`_
-- Centroid-based MBR `(Deguchi et al., 2024) <https://arxiv.org/abs/2402.11197>`_
+- Monte Carlo estimation: :code:`mbr` (`Eikema and Aziz, 2020 <https://aclanthology.org/2020.coling-main.398>`_; `Eikema and Aziz, 2022 <https://aclanthology.org/2022.emnlp-main.754>`_)
+- Confidence-based pruning: :code:`pruning_mbr` `(Cheng and Vlachos, 2023) <https://aclanthology.org/2023.emnlp-main.767>`_
+- Centroid-based MBR: :code:`cbmbr` `(Deguchi et al., 2024) <https://arxiv.org/abs/2402.11197>`_
+- Probabilistic MBR: :code:`pmbr` `(Trabelsi et al., 2024) <https://arxiv.org/abs/2406.02832>`_
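
For example, the new :code:`pmbr` decoder is constructed and invoked like the other decoders. A minimal sketch, based on the tests added in this commit; the configuration values here are illustrative, not recommended defaults:

.. code:: python

   from mbrs.decoders.pmbr import DecoderProbabilisticMBR
   from mbrs.metrics.chrf import MetricChrF

   metric = MetricChrF(MetricChrF.Config())
   decoder = DecoderProbabilisticMBR(
       DecoderProbabilisticMBR.Config(reduction_factor=2, rank=2, niter=10),
       metric,
   )
   hypotheses = ["another test", "this is a test", "this is a fest"]
   # Sampled candidates commonly serve as both hypotheses and pseudo-references.
   output = decoder.decode(hypotheses, hypotheses, nbest=1)
   print(output.sentence[0], output.score[0])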

Citation
========
Expand Down
mbrs/decoders/__init__.py (4 changes: 3 additions & 1 deletion)
@@ -4,8 +4,9 @@

register, get_decoder = registry.setup("decoder")

-from .cbmbr import DecoderCBMBR
from .mbr import DecoderMBR
+from .cbmbr import DecoderCBMBR
+from .pmbr import DecoderProbabilisticMBR
from .pruning_mbr import DecoderPruningMBR
from .rerank import DecoderRerank

@@ -14,6 +15,7 @@
"DecoderReferenceless",
"DecoderMBR",
"DecoderCBMBR",
"DecoderProbabilisticMBR",
"DecoderPruningMBR",
"DecoderRerank",
]
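
Because pmbr.py registers the new class with `@register("pmbr")` (see below), the decoder should also be resolvable by its CLI name through the registry set up above. A sketch, assuming `get_decoder` returns the registered class:

from mbrs.decoders import get_decoder

decoder_cls = get_decoder("pmbr")  # presumably DecoderProbabilisticMBR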
mbrs/decoders/pmbr.py (179 changes: 179 additions & 0 deletions)
@@ -0,0 +1,179 @@
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Optional

import torch

from mbrs import timer
from mbrs.metrics import MetricCacheable
from mbrs.modules.als import MatrixFactorizationALS

from . import register
from .mbr import DecoderMBR


@register("pmbr")
class DecoderProbabilisticMBR(DecoderMBR):
"""Probablistic MBR decoder using alternating least squares (ALS) approximation.
References:
F. Trabelsi et al., 2024,
"Efficient Minimum Bayes Risk Decoding using Low-Rank Matrix Completion Algorithms".
https://arxiv.org/abs/2406.02832
"""

cfg: Config

@dataclass
class Config(DecoderMBR.Config):
"""Configuration for the decoder.
- reduction_factor (float): Reduction factor.
The computational budget will be reduced to `1 / reduction_factor`.
- regularization_weight (float): Weight of L2 regularization.
        - rank (int): Rank of the factorized matrices.
- niter (int): The number of alternating steps performed.
- seed (int): Random seed.
"""

reduction_factor: float = 8
regularization_weight: float = 0.1
rank: int = 8
niter: int = 10
seed: int = 0

def expected_scores_probabilistic(
self,
hypotheses: list[str],
references: list[str],
source: Optional[str] = None,
) -> torch.Tensor:
"""Compute the expected scores using the probabilistic MBR algorithm.
Args:
hypotheses (list[str]): Hypotheses.
references (list[str]): References.
source (str, optional): A source.
Returns:
torch.Tensor: Expected scores of shape `(H,)`.
"""
rng = torch.Generator().manual_seed(self.cfg.seed)
H = len(hypotheses)
R = len(references)
num_ucalcs = math.ceil(H * R / self.cfg.reduction_factor)

pairwise_scores = torch.zeros((H, R), device=self.metric.device)
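        # Uniformly sample num_ucalcs of the H * R (hypothesis, reference)
        # pairs; only these entries of the score matrix are computed directly.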
pairwise_sample_indices = torch.randperm(H * R, generator=rng)[:num_ucalcs]
hypothesis_sample_indices: list[int] = (pairwise_sample_indices // R).tolist()
reference_sample_indices: list[int] = (pairwise_sample_indices % R).tolist()
hypothesis_samples = [hypotheses[i] for i in hypothesis_sample_indices]
reference_samples = [references[j] for j in reference_sample_indices]

        # For cacheable metrics (e.g., COMET-22), encode each distinct
        # hypothesis and reference once and reuse the cached representations.
if isinstance(self.metric, MetricCacheable):
hypothesis_sample_indices_set = set(hypothesis_sample_indices)
reference_sample_indices_set = set(reference_sample_indices)
hypothesis_samples_deduped = [
hypotheses[i] for i in hypothesis_sample_indices_set
]
reference_samples_deduped = [
references[j] for j in reference_sample_indices_set
]
with timer.measure("encode/hypotheses"):
ir = self.metric.encode(hypothesis_samples_deduped)
hypotheses_ir = ir.new_zeros((H, self.metric.embed_dim))
references_ir = ir.new_zeros((R, self.metric.embed_dim))
hypotheses_ir[list(hypothesis_sample_indices_set)] = ir
with timer.measure("encode/references"):
if hypotheses == references:
seen_indices = list(
hypothesis_sample_indices_set & reference_sample_indices_set
)
unseen_indices = list(
reference_sample_indices_set - hypothesis_sample_indices_set
)
if len(seen_indices) > 0:
references_ir[seen_indices] = hypotheses_ir[seen_indices]
if len(unseen_indices) > 0:
references_ir[unseen_indices] = self.metric.encode(
[references[j] for j in unseen_indices]
)
else:
references_ir[list(reference_sample_indices_set)] = (
self.metric.encode(reference_samples_deduped)
)
if source is None:
source_ir = None
else:
with timer.measure("encode/source"):
source_ir = self.metric.encode([source])

        # Algorithm 2 in the paper: compute metric scores only for the
        # sampled (hypothesis, reference) pairs.
if isinstance(self.metric, MetricCacheable):
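            # Score the sampled pairs in chunks of H, reusing the cached
            # encoder outputs instead of re-encoding sentences.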
for i in range(0, len(hypothesis_sample_indices), H):
pairwise_scores[
hypothesis_sample_indices[i : i + H],
reference_sample_indices[i : i + H],
] = self.metric.scores_from_ir(
hypotheses_ir[hypothesis_sample_indices[i : i + H]],
references_ir[reference_sample_indices[i : i + H]],
source_ir,
).float()
else:
pairwise_scores[hypothesis_sample_indices, reference_sample_indices] = (
self.metric.scores(
hypothesis_samples, reference_samples, source
).float()
)
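        # Mark which entries of the H x R score matrix were actually computed.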
observed_mask = pairwise_scores.new_zeros((H, R), dtype=torch.bool)
observed_mask[hypothesis_sample_indices, reference_sample_indices] = True

        # Algorithm 1 in the paper: complete the partially observed score
        # matrix via low-rank matrix factorization with ALS.
als = MatrixFactorizationALS(
regularization_weight=self.cfg.regularization_weight, rank=self.cfg.rank
)
X, Y = als.factorize(
pairwise_scores,
observed_mask=observed_mask,
niter=self.cfg.niter,
seed=self.cfg.seed,
)
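        # Reconstruct the full H x R score matrix from the low-rank factors;
        # each hypothesis's expected utility is then its row mean.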
pairwise_scores = X @ Y.T
return pairwise_scores.mean(dim=-1)

def decode(
self,
hypotheses: list[str],
references: list[str],
source: Optional[str] = None,
nbest: int = 1,
) -> DecoderMBR.Output:
"""Select the n-best hypotheses based on the strategy.
Args:
hypotheses (list[str]): Hypotheses.
references (list[str]): References.
source (str, optional): A source.
nbest (int): Return the n-best hypotheses.
Returns:
DecoderMBR.Output: The n-best hypotheses.
"""

if self.cfg.reduction_factor <= 1.0:
expected_scores = self.metric.expected_scores(
hypotheses, references, source
)
else: # Probabilistic MBR decoding
expected_scores = self.expected_scores_probabilistic(
hypotheses, references, source
)
topk_scores, topk_indices = self.metric.topk(expected_scores, k=nbest)
return self.Output(
idx=topk_indices,
sentence=[hypotheses[idx] for idx in topk_indices],
score=topk_scores,
)
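
The `MatrixFactorizationALS` module imported above (mbrs/modules/als.py) is not among the files shown in this diff. For orientation, here is a minimal sketch of the kind of masked ALS matrix completion that `factorize` performs, assuming alternating closed-form ridge-regression updates; the function name and details below are illustrative, not the repository's implementation:

import torch

def als_factorize_sketch(
    scores: torch.Tensor,         # (H, R) partially observed score matrix
    observed_mask: torch.Tensor,  # (H, R) bool, True where scores are known
    rank: int = 8,
    regularization_weight: float = 0.1,
    niter: int = 10,
    seed: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    H, R = scores.shape
    g = torch.Generator().manual_seed(seed)
    X = torch.randn(H, rank, generator=g)  # hypothesis factors
    Y = torch.randn(R, rank, generator=g)  # reference factors
    reg = regularization_weight * torch.eye(rank)
    for _ in range(niter):
        # Fix Y; solve a ridge regression for each row of X over its
        # observed entries.
        for i in range(H):
            m = observed_mask[i]
            if m.any():
                Yo = Y[m]
                X[i] = torch.linalg.solve(Yo.T @ Yo + reg, Yo.T @ scores[i, m])
        # Fix X; solve a ridge regression for each row of Y.
        for j in range(R):
            m = observed_mask[:, j]
            if m.any():
                Xo = X[m]
                Y[j] = torch.linalg.solve(Xo.T @ Xo + reg, Xo.T @ scores[m, j])
    return X, Y  # X @ Y.T approximates the completed score matrix

Each update solves a small rank-by-rank linear system, so the overall cost is driven by the number of observed entries rather than the full H * R matrix.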
mbrs/decoders/pmbr_test.py (66 changes: 66 additions & 0 deletions)
@@ -0,0 +1,66 @@
import numpy as np

from mbrs.metrics.chrf import MetricChrF
from mbrs.metrics.comet import MetricCOMET

from .pmbr import DecoderProbabilisticMBR

SOURCE = [
"これはテストです",
"これはテストです",
"これはテストです",
"これはテストです",
]
HYPOTHESES = [
["another test", "this is a test", "this is a fest", "x", "this is test"],
["another test", "this is a fest", "this is a test"],
["this is a test"],
["Producția de zahăr primă va fi exprimată în ceea ce privește zahărul alb;"],
]
REFERENCES = [
["another test", "this is a test", "this is a fest", "x", "this is test"],
["this is a test", "ref", "these are tests", "this is the test"],
["this is a test"],
["producţia de zahăr brut se exprimă în zahăr alb;"],
]

BEST_INDICES = [1, 2, 0, 0]
BEST_SENTENCES = [
"this is a test",
"this is a test",
"this is a test",
"Producția de zahăr primă va fi exprimată în ceea ce privește zahărul alb;",
]
SCORES_COMET = np.array([0.84780, 0.85304, 0.99257, 0.78060], dtype=np.float32)
SCORES_CHRF = np.array([48.912, 44.239, 100.0, 46.161], dtype=np.float32)

NITER = 30
FACTOR = 1.25
RANK = 2
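# With reduction_factor=1.25, ceil(H * R / 1.25) pairs (roughly 80% of the
# score matrix) are computed directly; the rest are imputed by ALS.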


class TestDecoderProbabilisticMBR:
def test_decode_chrf(self):
metric = MetricChrF(MetricChrF.Config())
decoder = DecoderProbabilisticMBR(
DecoderProbabilisticMBR.Config(
reduction_factor=FACTOR, rank=RANK, niter=NITER
),
metric,
)
for i, (hyps, refs) in enumerate(zip(HYPOTHESES, REFERENCES)):
output = decoder.decode(hyps, refs, SOURCE[i], nbest=1)
assert output.idx[0] == BEST_INDICES[i]
assert output.sentence[0] == BEST_SENTENCES[i]

def test_decode_comet(self, metric_comet: MetricCOMET):
decoder = DecoderProbabilisticMBR(
DecoderProbabilisticMBR.Config(
reduction_factor=FACTOR, rank=RANK, niter=NITER
),
metric_comet,
)
for i, (hyps, refs) in enumerate(zip(HYPOTHESES, REFERENCES)):
output = decoder.decode(hyps, refs, SOURCE[i], nbest=1)
assert output.idx[0] == BEST_INDICES[i]
assert output.sentence[0] == BEST_SENTENCES[i]
mbrs/metrics/base.py (70 changes: 70 additions & 0 deletions)
@@ -73,6 +73,28 @@ def score(
float: The score of the given hypothesis.
"""

def scores(
self, hypotheses: list[str], references: list[str], source: Optional[str] = None
) -> Tensor:
"""Calculate the scores of the given hypotheses.
Args:
            hypotheses (list[str]): N hypotheses.
            references (list[str]): N references.
            source (str, optional): A source.
Returns:
Tensor: The N scores of the given hypotheses.
"""
with timer.measure("score") as t:
t.set_delta_ncalls(len(hypotheses))
return Tensor(
[
self.score(hyp, ref, source)
for hyp, ref in zip(hypotheses, references)
]
)

def pairwise_scores(
self, hypotheses: list[str], references: list[str], source: Optional[str] = None
) -> Tensor:
@@ -118,6 +140,11 @@ class MetricCacheable(Metric, metaclass=abc.ABCMeta):
This class supports to cache intermediate representations of the encoder."""

@property
@abc.abstractmethod
def embed_dim(self) -> int:
"""Return the size of embedding dimension."""

@abc.abstractmethod
def encode(self, sentences: list[str]) -> Tensor:
"""Encode the given sentences into their intermediate representations.
@@ -170,6 +197,49 @@ def score(
self.encode([source]) if source is not None else None,
).item()

def scores_from_ir(
self,
hypotheses_ir: Tensor,
references_ir: Tensor,
source_ir: Optional[Tensor] = None,
) -> Tensor:
"""Calculate the scores of the given hypotheses from the intermediate representations.
Args:
            hypotheses_ir (Tensor): Intermediate representations of N hypotheses.
            references_ir (Tensor): Intermediate representations of N references.
            source_ir (Tensor, optional): Intermediate representation of a source.
Returns:
Tensor: The N scores of the given hypotheses.
"""
H = len(hypotheses_ir)
if source_ir is not None:
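            # Broadcast the single source representation across all H pairs.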
source_ir = source_ir.repeat(H, 1)
with timer.measure("score") as t:
t.set_delta_ncalls(H)
return self.out_proj(hypotheses_ir, references_ir, source_ir)

def scores(
self, hypotheses: list[str], references: list[str], source: Optional[str] = None
) -> Tensor:
"""Calculate the scores of the given hypotheses.
Args:
            hypotheses (list[str]): N hypotheses.
            references (list[str]): N references.
            source (str, optional): A source.
Returns:
Tensor: The N scores of the given hypotheses.
"""

return self.scores_from_ir(
self.encode(hypotheses),
self.encode(references),
self.encode([source]) if source is not None else None,
)

def pairwise_scores_from_ir(
self,
hypotheses_ir: Tensor,