Skip to content

Commit

Permalink
Merge pull request #15 from naist-nlp/fix-mp-bleu
Browse files Browse the repository at this point in the history
Fix a bug of multiprocessing in mecab-python3 tokenizer used in BLEU
  • Loading branch information
de9uch1 authored Jul 29, 2024
2 parents 40adf8e + 3988af1 commit 8c1f5c4
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 3 deletions.
38 changes: 35 additions & 3 deletions mbrs/metrics/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class Config(MetricAggregatable.Config):
effective_order: bool = True
trg_lang: str = ""

cfg: Config

@dataclass
class AggregatedReference:
"""Aggregated reference representation.
Expand All @@ -56,7 +58,12 @@ class AggregatedReference:
length: float

def __init__(self, cfg: MetricBLEU.Config):
self.scorer = BLEU(
super().__init__(cfg)
self.scorer = self.__initialize_scorer(cfg)

@staticmethod
def __initialize_scorer(cfg: MetricBLEU.Config) -> BLEU:
scorer = BLEU(
lowercase=cfg.lowercase,
force=cfg.force,
tokenize=cfg.tokenize,
Expand All @@ -66,6 +73,8 @@ def __init__(self, cfg: MetricBLEU.Config):
effective_order=cfg.effective_order,
trg_lang=cfg.trg_lang,
)
MetricBLEU._score_worker.scorer = scorer
return scorer

def score(self, hypothesis: str, reference: str, *_) -> float:
"""Calculate the score of the given hypothesis.
Expand All @@ -79,6 +88,27 @@ def score(self, hypothesis: str, reference: str, *_) -> float:
"""
return self.scorer.sentence_score(hypothesis, [reference]).score

@staticmethod
def _score_worker(hypothesis: str, reference: str, *_) -> float:
    """Score a single hypothesis/reference pair inside a worker process.

    Because the ja-mecab tokenizer cannot be pickled, the scorer is not
    passed through the process pool; instead each worker finds it attached
    as an attribute on this function (set by the pool initializer).

    Args:
        hypothesis (str): Hypothesis.
        reference (str): Reference.

    Returns:
        float: The score of the given hypothesis.

    Todo:
        - Replace this function-attribute mechanism with a better logic.
    """
    # The scorer was stashed on the function object by the pool initializer.
    scorer = MetricBLEU._score_worker.scorer
    return scorer.sentence_score(hypothesis, [reference]).score

def pairwise_scores(
self, hypotheses: list[str], references: list[str], *_
) -> Tensor:
Expand All @@ -92,14 +122,16 @@ def pairwise_scores(
Tensor: Score matrix of shape `(H, R)`, where `H` is the number
of hypotheses and `R` is the number of references.
"""
with concurrent.futures.ProcessPoolExecutor() as executor:
with concurrent.futures.ProcessPoolExecutor(
initializer=self.__initialize_scorer, initargs=(self.cfg,)
) as executor:
with timer.measure("score") as t:
t.set_delta_ncalls(len(hypotheses) * len(references))

return Tensor(
list(
executor.map(
self.score,
self._score_worker,
*zip(*itertools.product(hypotheses, references)),
chunksize=len(hypotheses),
)
Expand Down
11 changes: 11 additions & 0 deletions mbrs/metrics/bleu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ def test_expected_scores(self, effective_order: bool):
rtol=1e-4,
)

def test_expected_scores_ja(self):
    """Check expected scores with the ja-mecab tokenizer.

    This exercises the multiprocessing path with a tokenizer that cannot
    be pickled, so it guards against regressions in the worker setup.
    """
    config = MetricBLEU.Config(tokenize="ja-mecab", effective_order=True)
    metric = MetricBLEU(config)
    hypotheses = ["ありがとうございます", "どうも"]
    references = ["ありがとう", "どうもありがとうございます"]
    scores = metric.expected_scores(hypotheses, references)
    reference_scores = torch.Tensor([49.5846, 2.4894])
    torch.testing.assert_close(scores, reference_scores, atol=0.0005, rtol=1e-4)

def test_corpus_score(self):
hyps = [
"this is a test",
Expand Down

0 comments on commit 8c1f5c4

Please sign in to comment.