
Commit 5a526b8

fix gen config passing
1 parent d90f889 · commit 5a526b8

4 files changed, +37 -28 lines changed

colossalai/inference/core/engine.py

Lines changed: 6 additions & 5 deletions

```diff
@@ -688,11 +688,12 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor,
         )
 
         batch_token_ids = None
-        config_dict = self.generation_config.to_dict()
-        # process repetition_penalty, no_repeat_ngram_size
-        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
-            if type in config_dict and config_dict[type] is not None:
-                batch_token_ids = batch.batch_token_ids
+        if (
+            self.generation_config.repetition_penalty != 1.0
+            or self.generation_config.no_repeat_ngram_size > 0
+            or self.generation_config.forced_eos_token_id is not None
+        ):
+            batch_token_ids = batch.batch_token_ids
 
         # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
         use_cuda_graph = False
```
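
In transformers' `GenerationConfig`, `repetition_penalty` defaults to 1.0, `no_repeat_ngram_size` to 0, and `forced_eos_token_id` to None, so the rewritten guard only materializes `batch_token_ids` when a processor that actually consumes the token history is enabled (`forced_eos_token_id` is included because `search_tokens` derives per-sequence lengths from `batch_token_ids` when it is set). A minimal sketch of that check in isolation; the helper `needs_batch_token_ids` is illustrative and not part of the commit:

```python
from transformers import GenerationConfig


def needs_batch_token_ids(generation_config: GenerationConfig) -> bool:
    # Illustrative helper (not in the commit): mirrors the new guard in
    # prepare_input -- the token-id history is only gathered when a processor
    # that actually consumes it is enabled.
    return (
        generation_config.repetition_penalty != 1.0
        or generation_config.no_repeat_ngram_size > 0
        or generation_config.forced_eos_token_id is not None
    )


print(needs_batch_token_ids(GenerationConfig()))                        # False: all defaults
print(needs_batch_token_ids(GenerationConfig(repetition_penalty=1.2)))  # True
```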

colossalai/inference/core/request_handler.py

Lines changed: 0 additions & 2 deletions

```diff
@@ -8,8 +8,6 @@
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
-from colossalai.inference.logit_processors import logits_processor
-from colossalai.inference.sampler import *
 from colossalai.inference.struct import RequestStatus, Sequence
 from colossalai.logging import get_dist_logger
 
```

colossalai/inference/logit_processors.py

Lines changed: 11 additions & 11 deletions

```diff
@@ -134,20 +134,20 @@ def apply_top_p(logits, top_p: float):
 def apply_forced_bos_token_id(
     logits: torch.Tensor,
     sequence_lengths: Union[torch.Tensor, List[int]],
-    max_out_lengths: Union[torch.Tensor, List[int]],
+    max_lengths: Union[torch.Tensor, List[int]],
     bos_token_id: int,
 ):
     # NOTE For now, optimizations for encoder-decoder models have not been supported yet
     # And this function will never be called in the current implementation.
     if isinstance(sequence_lengths, torch.Tensor):
         sequence_lengths = sequence_lengths.tolist()
-    if isinstance(max_out_lengths, torch.Tensor):
-        max_out_lengths = max_out_lengths.tolist()
+    if isinstance(max_lengths, torch.Tensor):
+        max_lengths = max_lengths.tolist()
 
     select_indexes = []
     num_sequences = logits.shape[0]
     sequence_lengths = sequence_lengths[:num_sequences]
-    max_out_lengths = max_out_lengths[:num_sequences]
+    max_lengths = max_lengths[:num_sequences]
     for i, sequence_length in enumerate(sequence_lengths):
         if sequence_length == 1:
             select_indexes.append(i)
@@ -162,7 +162,7 @@ def apply_forced_bos_token_id(
 def apply_forced_eos_token_id(
     logits: torch.Tensor,
     sequence_lengths: Union[torch.Tensor, List[int]],
-    max_out_lengths: Union[torch.Tensor, List[int]],
+    max_lengths: Union[torch.Tensor, List[int]],
     eos_token_id: Union[int, List[int]],
 ):
     """
@@ -172,22 +172,22 @@ def apply_forced_eos_token_id(
 
     Args:
         logits(torch.Tensor): logits
-        sequence_lengths(torch.Tensor): sequence lengths
-        max_out_lengths(torch.Tensor): maximum output lengths for each sequence
+        sequence_lengths(torch.Tensor): sequence lengths including prompt and output tokens
+        max_lengths(torch.Tensor): the maximum length for each sequence
         eos_token_id(Union[int, List[int]]): forced eos token id
     """
     if isinstance(eos_token_id, int):
         eos_token_id = [eos_token_id]
     if isinstance(sequence_lengths, torch.Tensor):
         sequence_lengths = sequence_lengths.tolist()
-    if isinstance(max_out_lengths, torch.Tensor):
-        max_out_lengths = max_out_lengths.tolist()
+    if isinstance(max_lengths, torch.Tensor):
+        max_lengths = max_lengths.tolist()
 
     select_indexes = []
     num_sequences = logits.shape[0]
     sequence_lengths = sequence_lengths[:num_sequences]
-    max_out_lengths = max_out_lengths[:num_sequences]
-    for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_out_lengths)):
+    max_lengths = max_lengths[:num_sequences]
+    for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_lengths)):
         if sequence_length == max_out_length - 1:
             select_indexes.append(i)
     if select_indexes:
```
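
The hunks above only cover the argument rename and the row-selection loop; the masking performed under `if select_indexes:` lies outside the diff. As a rough sketch of what such a forced-EOS step typically does, assuming a hypothetical helper `force_eos_rows` that is not the repository's implementation: for every sequence whose length has reached `max_length - 1`, suppress all logits except the forced EOS id(s).

```python
from typing import List

import torch


def force_eos_rows(logits: torch.Tensor, select_indexes: List[int], eos_token_id: List[int]) -> torch.Tensor:
    # Hypothetical sketch (not the commit's code): for the selected rows,
    # rule out every vocabulary entry except the forced EOS id(s).
    if select_indexes:
        rows = torch.tensor(select_indexes)
        cols = torch.tensor(eos_token_id)
        logits[rows] = float("-inf")            # suppress all tokens in those rows
        logits[rows.unsqueeze(1), cols] = 0.0   # re-enable only the EOS token(s)
    return logits


# Toy usage: 3 sequences, vocab size 5; rows 0 and 2 have hit max_length - 1.
logits = force_eos_rows(torch.randn(3, 5), select_indexes=[0, 2], eos_token_id=[4])
```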

colossalai/inference/sampler.py

Lines changed: 20 additions & 10 deletions

```diff
@@ -3,7 +3,7 @@
 import torch
 from transformers.generation import GenerationConfig
 
-from colossalai.inference.logit_processors import logit_processor
+from colossalai.inference.logit_processors import get_logits_processor
 
 
 def greedy_sample(
@@ -86,18 +86,28 @@ def search_tokens(
     Sample tokens for finished requests.
     """
     # NOTE: need to decide the granularity to process logits (sequence or batch)
+    print(
+        f"CHECK search_tokens max_length {generation_config.max_length}; max_new_tokens {generation_config.max_new_tokens}"
+    )
     config_dict = generation_config.to_dict()
-    # process repetition_penalty, no_repeat_ngram_size
-    for type in ["repetition_penalty", "no_repeat_ngram_size"]:
-        if type in config_dict and config_dict[type] is not None:
-            logits = logit_processor(type, logits, config_dict[type], batch_token_ids)
+    if (repetition_penalty := config_dict.get("repetition_penalty", 1.0)) != 1.0:
+        logits = get_logits_processor("repetition_penalty", logits, repetition_penalty, batch_token_ids)
+    if (no_repeat_ngram_size := config_dict.get("no_repeat_ngram_size", 0)) > 0:
+        logits = get_logits_processor("no_repeat_ngram_size", logits, no_repeat_ngram_size, batch_token_ids)
+    if (forced_eos_token_id := config_dict.get("forced_eos_token_id", None)) is not None:
+        sequence_lengths = [len(batch_token_ids[i]) for i in range(len(batch_token_ids))]
+        max_out_lengths = [generation_config.max_length for _ in range(len(batch_token_ids))]
+        logits = get_logits_processor(
+            "forced_eos_token_id", logits, sequence_lengths, max_out_lengths, forced_eos_token_id
+        )
 
-    # do logit processor
     if generation_config.do_sample:
-        # process temperature, top_k, top_p
-        for type in ["temperature", "top_k", "top_p"]:
-            if type in config_dict and config_dict[type] is not None:
-                logits = logit_processor(type, logits, config_dict[type])
+        if (temperature := config_dict.get("temperature", 1.0)) != 1.0:
+            logits = get_logits_processor("temperature", logits, temperature)
+        if (top_k := config_dict.get("top_k", 0)) != 0:
+            logits = get_logits_processor("top_k", logits, top_k)
+        if (top_p := config_dict.get("top_p", 1.0)) < 1.0:
+            logits = get_logits_processor("top_p", logits, top_p)
 
     # calculate probs
     probs = torch.softmax(logits, dim=-1, dtype=torch.float)
```
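
With this change, `search_tokens` reads each option out of `generation_config.to_dict()` with an explicit no-op default and only invokes a logits processor when the value deviates from that default. A standalone sketch of the dispatch decision, using a hypothetical helper `active_processors` that returns processor names instead of calling `get_logits_processor` on real logits:

```python
from typing import List

from transformers import GenerationConfig


def active_processors(generation_config: GenerationConfig) -> List[str]:
    # Rough stand-in for the guards in search_tokens: a processor only runs
    # when its config value differs from the "no-op" default.
    cfg = generation_config.to_dict()
    active = []
    if cfg.get("repetition_penalty", 1.0) != 1.0:
        active.append("repetition_penalty")
    if cfg.get("no_repeat_ngram_size", 0) > 0:
        active.append("no_repeat_ngram_size")
    if cfg.get("forced_eos_token_id", None) is not None:
        active.append("forced_eos_token_id")
    if generation_config.do_sample:
        if cfg.get("temperature", 1.0) != 1.0:
            active.append("temperature")
        if cfg.get("top_k", 0) != 0:
            active.append("top_k")
        if cfg.get("top_p", 1.0) < 1.0:
            active.append("top_p")
    return active


print(active_processors(GenerationConfig(do_sample=True, temperature=0.7, top_p=0.9)))
# -> ['temperature', 'top_k', 'top_p']
```

Note that transformers' `top_k` default is 50, so with `do_sample=True` the top-k guard still fires unless `top_k` is explicitly set to 0.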
