Skip to content

Commit

Permalink
generator
Browse files Browse the repository at this point in the history
  • Loading branch information
实一 committed Dec 14, 2022
1 parent 601f05b commit ca89b2c
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions models/sequence_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,13 +369,14 @@ def _generate(
)
probs = probs[:, -1, :] * self.lm_weight

if probs.shape[1] > lprobs.shape[1]:
probs = probs[:, :lprobs.shape[1]]
# align lm vocab with ofa vocab
if probs.shape[1] < lprobs.shape[1]:
probs = torch.cat([probs, probs.new_zeros([probs.size(0), lprobs.shape[1] - probs.shape[1]],
dtype=torch.float64)], dim=1)
probs[:, self.constraint_end:] = -math.inf
probs[:, 4:self.constraint_start] = -math.inf
elif probs.shape[1] > lprobs.shape[1]:
raise NotImplementedError

lprobs += probs
# handle prefix tokens (possibly with different lengths)
Expand Down

0 comments on commit ca89b2c

Please sign in to comment.