Skip to content

Commit b06efbb

Browse files
committed
fix no_repeat_ngram_size_logit_process
1 parent 2d94b8c commit b06efbb

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

colossalai/inference/logit_processors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids:
3636
batch_size = len(batch_token_ids)
3737

3838
for batch_id in range(batch_size):
39+
current_token_ids = batch_token_ids[batch_id]
3940
current_len = current_token_ids.size(0)
4041
if current_len + 1 < ngram_size:
4142
continue
4243

43-
current_token_ids = batch_token_ids[batch_id]
4444
token_ids_list = current_token_ids.tolist()
4545

4646
ngrams_dict = {}

0 commit comments

Comments
 (0)