Skip to content

Commit

Permalink
Merge pull request #13 from naist-nlp/speedup-bleurt
Browse files Browse the repository at this point in the history
Speed up pairwise_scores of BLEURT
  • Loading branch information
de9uch1 authored Jul 27, 2024
2 parents 2aa1c21 + dd077cd commit 1a88001
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions mbrs/metrics/bleurt.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class MetricBLEURT(Metric):
(thanks to @lucadiliello)
Available checkpoints:
- lucadiliello/BLEURT-20
- lucadiliello/BLEURT-20-D12
- lucadiliello/BLEURT-20-D3
Expand Down Expand Up @@ -140,24 +141,27 @@ def pairwise_scores(
of hypotheses and `R` is the number of references.
"""
scores = []
pairwise_iter = itertools.product(hypotheses, references)
while batch := tuple(
zip(*itertools.islice(pairwise_iter, self.cfg.batch_size))
):
hypotheses_ids = [
self.tokenizer.encode(h, add_special_tokens=False) for h in hypotheses
]
references_ids = [
self.tokenizer.encode(r, add_special_tokens=False) for r in references
]
pairwise_iter = itertools.product(references_ids, hypotheses_ids)

while batch := list(itertools.islice(pairwise_iter, self.cfg.batch_size)):
with timer.measure("score") as t:
hyps, refs = batch
t.set_delta_ncalls(len(hyps))
batch = self.tokenizer(
refs,
hyps,
t.set_delta_ncalls(len(batch))
batch = self.tokenizer.batch_encode_plus(
batch,
truncation=True,
padding=True,
max_length=self.max_length,
return_tensors="pt",
).to(self.device)
model_output = self.scorer(**batch)
scores.append(model_output.logits.flatten())
return torch.cat(scores).view(len(hypotheses), len(references))
return torch.cat(scores).view(len(references), len(hypotheses)).transpose(0, 1)

def corpus_score(self, hypotheses: list[str], references: list[str]) -> float:
"""Calculate the corpus-level score.
Expand Down

0 comments on commit 1a88001

Please sign in to comment.