diff --git a/mbrs/metrics/bleurt.py b/mbrs/metrics/bleurt.py index 9c6048d..6173c15 100644 --- a/mbrs/metrics/bleurt.py +++ b/mbrs/metrics/bleurt.py @@ -24,6 +24,7 @@ class MetricBLEURT(Metric): (thanks to @lucadiliello) Available checkpoints: + - lucadiliello/BLEURT-20 - lucadiliello/BLEURT-20-D12 - lucadiliello/BLEURT-20-D3 @@ -140,16 +141,19 @@ def pairwise_scores( of hypotheses and `R` is the number of references. """ scores = [] - pairwise_iter = itertools.product(hypotheses, references) - while batch := tuple( - zip(*itertools.islice(pairwise_iter, self.cfg.batch_size)) - ): + hypotheses_ids = [ + self.tokenizer.encode(h, add_special_tokens=False) for h in hypotheses + ] + references_ids = [ + self.tokenizer.encode(r, add_special_tokens=False) for r in references + ] + pairwise_iter = itertools.product(references_ids, hypotheses_ids) + + while batch := list(itertools.islice(pairwise_iter, self.cfg.batch_size)): with timer.measure("score") as t: - hyps, refs = batch - t.set_delta_ncalls(len(hyps)) - batch = self.tokenizer( - refs, - hyps, + t.set_delta_ncalls(len(batch)) + batch = self.tokenizer.batch_encode_plus( + batch, truncation=True, padding=True, max_length=self.max_length, @@ -157,7 +161,7 @@ def pairwise_scores( ).to(self.device) model_output = self.scorer(**batch) scores.append(model_output.logits.flatten()) - return torch.cat(scores).view(len(hypotheses), len(references)) + return torch.cat(scores).view(len(references), len(hypotheses)).transpose(0, 1) def corpus_score(self, hypotheses: list[str], references: list[str]) -> float: """Calculate the corpus-level score.