Skip to content

Commit

Permalink
Adapter new type promotion rule for Paddle 2.6 (#8421)
Browse files Browse the repository at this point in the history
  • Loading branch information
zxcd authored May 13, 2024
1 parent fe6277f commit d9f555e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ def prepare_attention_mask_for_generation(self, input_ids, pad_token_id, eos_tok
def update_scores_for_generation(self, scores, next_scores, length, unfinished_flag):
# update scores

unfinished_scores = (scores * length + next_scores) / (length + 1)
unfinished_scores = (scores * length.astype(scores.dtype) + next_scores) / (length.astype(scores.dtype) + 1)
scores = paddle.where(unfinished_flag, unfinished_scores, scores)
return scores

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,7 @@ def prepare_attention_mask_for_generation(self, input_ids, pad_token_id, eos_tok
def update_scores_for_generation(self, scores, next_scores, length, unfinished_flag):
# update scores

unfinished_scores = (scores * length + next_scores) / (length + 1)
unfinished_scores = (scores * length.astype(scores.dtype) + next_scores) / (length.astype(scores.dtype) + 1)
scores = paddle.where(unfinished_flag, unfinished_scores, scores)
return scores

Expand Down

0 comments on commit d9f555e

Please sign in to comment.