-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from naist-nlp/pmbr
Implement Probabilistic MBR (Trabelsi et al., 2024)
- Loading branch information
Showing
10 changed files
with
535 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
from __future__ import annotations | ||
|
||
import math | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
import torch | ||
|
||
from mbrs import timer | ||
from mbrs.metrics import MetricCacheable | ||
from mbrs.modules.als import MatrixFactorizationALS | ||
|
||
from . import register | ||
from .mbr import DecoderMBR | ||
|
||
|
||
@register("pmbr") | ||
class DecoderProbabilisticMBR(DecoderMBR): | ||
"""Probablistic MBR decoder using alternating least squares (ALS) approximation. | ||
References: | ||
F. Trabelsi et al., 2024, | ||
"Efficient Minimum Bayes Risk Decoding using Low-Rank Matrix Completion Algorithms". | ||
https://arxiv.org/abs/2406.02832 | ||
""" | ||
|
||
cfg: Config | ||
|
||
@dataclass | ||
class Config(DecoderMBR.Config): | ||
"""Configuration for the decoder. | ||
- reduction_factor (float): Reduction factor. | ||
The computational budget will be reduced to `1 / reduction_factor`. | ||
- regularization_weight (float): Weight of L2 regularization. | ||
- rank (int): Rank of the factarized matrices. | ||
- niter (int): The number of alternating steps performed. | ||
- seed (int): Random seed. | ||
""" | ||
|
||
reduction_factor: float = 8 | ||
regularization_weight: float = 0.1 | ||
rank: int = 8 | ||
niter: int = 10 | ||
seed: int = 0 | ||
|
||
def expected_scores_probabilistic( | ||
self, | ||
hypotheses: list[str], | ||
references: list[str], | ||
source: Optional[str] = None, | ||
) -> torch.Tensor: | ||
"""Compute the expected scores using the probabilistic MBR algorithm. | ||
Args: | ||
hypotheses (list[str]): Hypotheses. | ||
references (list[str]): References. | ||
source (str, optional): A source. | ||
Returns: | ||
torch.Tensor: Expected scores of shape `(H,)`. | ||
""" | ||
rng = torch.Generator().manual_seed(self.cfg.seed) | ||
H = len(hypotheses) | ||
R = len(references) | ||
num_ucalcs = math.ceil(H * R / self.cfg.reduction_factor) | ||
|
||
pairwise_scores = torch.zeros((H, R), device=self.metric.device) | ||
pairwise_sample_indices = torch.randperm(H * R, generator=rng)[:num_ucalcs] | ||
hypothesis_sample_indices: list[int] = (pairwise_sample_indices // R).tolist() | ||
reference_sample_indices: list[int] = (pairwise_sample_indices % R).tolist() | ||
hypothesis_samples = [hypotheses[i] for i in hypothesis_sample_indices] | ||
reference_samples = [references[j] for j in reference_sample_indices] | ||
|
||
# For COMET-22 | ||
if isinstance(self.metric, MetricCacheable): | ||
hypothesis_sample_indices_set = set(hypothesis_sample_indices) | ||
reference_sample_indices_set = set(reference_sample_indices) | ||
hypothesis_samples_deduped = [ | ||
hypotheses[i] for i in hypothesis_sample_indices_set | ||
] | ||
reference_samples_deduped = [ | ||
references[j] for j in reference_sample_indices_set | ||
] | ||
with timer.measure("encode/hypotheses"): | ||
ir = self.metric.encode(hypothesis_samples_deduped) | ||
hypotheses_ir = ir.new_zeros((H, self.metric.embed_dim)) | ||
references_ir = ir.new_zeros((R, self.metric.embed_dim)) | ||
hypotheses_ir[list(hypothesis_sample_indices_set)] = ir | ||
with timer.measure("encode/references"): | ||
if hypotheses == references: | ||
seen_indices = list( | ||
hypothesis_sample_indices_set & reference_sample_indices_set | ||
) | ||
unseen_indices = list( | ||
reference_sample_indices_set - hypothesis_sample_indices_set | ||
) | ||
if len(seen_indices) > 0: | ||
references_ir[seen_indices] = hypotheses_ir[seen_indices] | ||
if len(unseen_indices) > 0: | ||
references_ir[unseen_indices] = self.metric.encode( | ||
[references[j] for j in unseen_indices] | ||
) | ||
else: | ||
references_ir[list(reference_sample_indices_set)] = ( | ||
self.metric.encode(reference_samples_deduped) | ||
) | ||
if source is None: | ||
source_ir = None | ||
else: | ||
with timer.measure("encode/source"): | ||
source_ir = self.metric.encode([source]) | ||
|
||
# Algorithm 2 in the paper. | ||
if isinstance(self.metric, MetricCacheable): | ||
for i in range(0, len(hypothesis_sample_indices), H): | ||
pairwise_scores[ | ||
hypothesis_sample_indices[i : i + H], | ||
reference_sample_indices[i : i + H], | ||
] = self.metric.scores_from_ir( | ||
hypotheses_ir[hypothesis_sample_indices[i : i + H]], | ||
references_ir[reference_sample_indices[i : i + H]], | ||
source_ir, | ||
).float() | ||
else: | ||
pairwise_scores[hypothesis_sample_indices, reference_sample_indices] = ( | ||
self.metric.scores( | ||
hypothesis_samples, reference_samples, source | ||
).float() | ||
) | ||
observed_mask = pairwise_scores.new_zeros((H, R), dtype=torch.bool) | ||
observed_mask[hypothesis_sample_indices, reference_sample_indices] = True | ||
|
||
# Algorithm 1 in the paper. | ||
als = MatrixFactorizationALS( | ||
regularization_weight=self.cfg.regularization_weight, rank=self.cfg.rank | ||
) | ||
X, Y = als.factorize( | ||
pairwise_scores, | ||
observed_mask=observed_mask, | ||
niter=self.cfg.niter, | ||
seed=self.cfg.seed, | ||
) | ||
pairwise_scores = X @ Y.T | ||
return pairwise_scores.mean(dim=-1) | ||
|
||
def decode( | ||
self, | ||
hypotheses: list[str], | ||
references: list[str], | ||
source: Optional[str] = None, | ||
nbest: int = 1, | ||
) -> DecoderMBR.Output: | ||
"""Select the n-best hypotheses based on the strategy. | ||
Args: | ||
hypotheses (list[str]): Hypotheses. | ||
references (list[str]): References. | ||
source (str, optional): A source. | ||
nbest (int): Return the n-best hypotheses. | ||
Returns: | ||
DecoderMBR.Output: The n-best hypotheses. | ||
""" | ||
|
||
if self.cfg.reduction_factor <= 1.0: | ||
expected_scores = self.metric.expected_scores( | ||
hypotheses, references, source | ||
) | ||
else: # Probabilistic MBR decoding | ||
expected_scores = self.expected_scores_probabilistic( | ||
hypotheses, references, source | ||
) | ||
topk_scores, topk_indices = self.metric.topk(expected_scores, k=nbest) | ||
return self.Output( | ||
idx=topk_indices, | ||
sentence=[hypotheses[idx] for idx in topk_indices], | ||
score=topk_scores, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import numpy as np | ||
|
||
from mbrs.metrics.chrf import MetricChrF | ||
from mbrs.metrics.comet import MetricCOMET | ||
|
||
from .pmbr import DecoderProbabilisticMBR | ||
|
||
SOURCE = [ | ||
"これはテストです", | ||
"これはテストです", | ||
"これはテストです", | ||
"これはテストです", | ||
] | ||
HYPOTHESES = [ | ||
["another test", "this is a test", "this is a fest", "x", "this is test"], | ||
["another test", "this is a fest", "this is a test"], | ||
["this is a test"], | ||
["Producția de zahăr primă va fi exprimată în ceea ce privește zahărul alb;"], | ||
] | ||
REFERENCES = [ | ||
["another test", "this is a test", "this is a fest", "x", "this is test"], | ||
["this is a test", "ref", "these are tests", "this is the test"], | ||
["this is a test"], | ||
["producţia de zahăr brut se exprimă în zahăr alb;"], | ||
] | ||
|
||
BEST_INDICES = [1, 2, 0, 0] | ||
BEST_SENTENCES = [ | ||
"this is a test", | ||
"this is a test", | ||
"this is a test", | ||
"Producția de zahăr primă va fi exprimată în ceea ce privește zahărul alb;", | ||
] | ||
SCORES_COMET = np.array([0.84780, 0.85304, 0.99257, 0.78060], dtype=np.float32) | ||
SCORES_CHRF = np.array([48.912, 44.239, 100.0, 46.161], dtype=np.float32) | ||
|
||
NITER = 30 | ||
FACTOR = 1.25 | ||
RANK = 2 | ||
|
||
|
||
class TestDecoderProbabilisticMBR: | ||
def test_decode_chrf(self): | ||
metric = MetricChrF(MetricChrF.Config()) | ||
decoder = DecoderProbabilisticMBR( | ||
DecoderProbabilisticMBR.Config( | ||
reduction_factor=FACTOR, rank=RANK, niter=NITER | ||
), | ||
metric, | ||
) | ||
for i, (hyps, refs) in enumerate(zip(HYPOTHESES, REFERENCES)): | ||
output = decoder.decode(hyps, refs, SOURCE[i], nbest=1) | ||
assert output.idx[0] == BEST_INDICES[i] | ||
assert output.sentence[0] == BEST_SENTENCES[i] | ||
|
||
def test_decode_comet(self, metric_comet: MetricCOMET): | ||
decoder = DecoderProbabilisticMBR( | ||
DecoderProbabilisticMBR.Config( | ||
reduction_factor=FACTOR, rank=RANK, niter=NITER | ||
), | ||
metric_comet, | ||
) | ||
for i, (hyps, refs) in enumerate(zip(HYPOTHESES, REFERENCES)): | ||
output = decoder.decode(hyps, refs, SOURCE[i], nbest=1) | ||
assert output.idx[0] == BEST_INDICES[i] | ||
assert output.sentence[0] == BEST_SENTENCES[i] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.