
Commit 3f1fc85

refactor and add
1 parent bfad393 commit 3f1fc85

File tree

4 files changed: +115 −26 lines

colossalai/inference/config.py

Lines changed: 1 addition & 0 deletions

@@ -137,6 +137,7 @@ class InferenceConfig:
     top_k: Optional[int] = None
     top_p: Optional[float] = None
     min_p: Optional[float] = None
+    forced_eos_token_id: int = None

     # speculative decoding configs
     max_n_spec_tokens: int = 5
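
The new field sits beside the other sampling options on InferenceConfig. A minimal sketch of enabling it (hedged: token id 2 is an arbitrary example value, and constructor arguments not shown are assumed to keep their defaults):

    from colossalai.inference.config import InferenceConfig

    # Force an EOS token to be emitted once a sequence reaches its
    # maximum output length; id 2 is an arbitrary example value.
    config = InferenceConfig(
        top_k=50,
        top_p=0.9,
        forced_eos_token_id=2,
    )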

colossalai/inference/core/engine.py

Lines changed: 3 additions & 3 deletions

@@ -424,7 +424,7 @@ def steps_spec_dec(self) -> List[Sequence]:

         # 2. Prefill main model (Verifier) - fill past kv cache for main model
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
-        next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
+        next_tokens = self.request_handler.search_tokens(logits, batch, self.generation_config)
         # append new inputs to the batch, temporarily
         batch.append_batch_tokens(next_tokens)
         self.request_handler.allocate_batch_spec_dec(batch, 1)

@@ -472,7 +472,7 @@ def steps_spec_dec(self) -> List[Sequence]:
         input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)

-        next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
+        next_tokens = self.request_handler.search_tokens(logits, batch, self.generation_config)

         # 5. Compare and process the results
         diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))

@@ -738,7 +738,7 @@ def step(self) -> List[str]:
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
-        next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
+        next_tokens = self.request_handler.search_tokens(logits, batch, self.generation_config)
         self.request_handler.append_next_tokens(next_tokens)
         finished_sequences = self.request_handler.update()
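
Why thread the batch through? search_tokens previously needed only the GenerationConfig and the logits; the forced-EOS processor additionally needs each sequence's current and maximum output length, which live on the BatchBucket. A hypothetical stub (not ColossalAI code) showing just the new calling convention:

    import torch

    class DummyHandler:
        # Hypothetical stand-in that mirrors the new search_tokens signature:
        # logits first, then the live batch, then the generation config.
        def search_tokens(self, logits, batch, generation_config):
            return logits.argmax(dim=-1)

    next_tokens = DummyHandler().search_tokens(torch.randn(2, 32000), batch=None, generation_config=None)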

colossalai/inference/core/request_handler.py

Lines changed: 24 additions & 3 deletions

@@ -8,7 +8,7 @@
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.kv_cache import KVCacheManager
-from colossalai.inference.logit_processors import logit_processor
+from colossalai.inference.logit_processors import logits_processor
 from colossalai.inference.sampler import *
 from colossalai.inference.struct import RequestStatus, Sequence
 from colossalai.logging import get_dist_logger

@@ -331,9 +331,19 @@ def check_unfinished_seqs(self) -> bool:
     def total_requests_in_batch_bucket(self) -> int:
         return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size

-    def search_tokens(self, generation_config: GenerationConfig, logits):
+    def search_tokens(
+        self,
+        logits: torch.Tensor,
+        batch_bucket: BatchBucket,
+        generation_config: GenerationConfig,
+    ):
         """
         Sample tokens for finished requests.
+
+        Args:
+            logits (torch.Tensor): [num_seqs, vocab_size] The logits tensor.
+            batch_bucket (BatchBucket): The batch whose sequences carry the per-sequence output-length bounds.
+            generation_config (GenerationConfig): The generation configuration.
         """

         # do logit processor

@@ -342,7 +352,18 @@ def search_tokens(self, generation_config: GenerationConfig, logits):
         config_dict = generation_config.to_dict()
         for type in ["temperature", "top_k", "top_p"]:
             if type in config_dict and config_dict[type] is not None:
-                logits = logit_processor(type, logits, config_dict[type])
+                logits = logits_processor(type, logits, config_dict[type])
+
+        forced_eos_token_id = config_dict.get("forced_eos_token_id", None)
+        if forced_eos_token_id is not None:
+            num_seqs = len(batch_bucket)
+            seq_out_lengths = [0] * num_seqs
+            max_out_lengths = [0] * num_seqs
+            for i, seq in enumerate(batch_bucket.seqs_li):
+                # retrieve the current output length and the maximum output length bound to each Sequence
+                seq_out_lengths[i], max_out_lengths[i] = seq.output_len, seq.max_output_len
+            logits = logits_processor("forced_eos_token_id", logits, seq_out_lengths, max_out_lengths, forced_eos_token_id)

         # calculate probs
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
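
Because search_tokens pulls the setting out of generation_config.to_dict(), the feature activates whenever the GenerationConfig carries a non-None forced_eos_token_id. A small sketch using Hugging Face's GenerationConfig (token id 2 is an arbitrary example):

    from transformers import GenerationConfig

    gen_config = GenerationConfig(temperature=0.8, top_k=50, forced_eos_token_id=2)

    # search_tokens only checks the raw dict, so the key just has to be present and non-None
    config_dict = gen_config.to_dict()
    assert config_dict.get("forced_eos_token_id", None) == 2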

colossalai/inference/logit_processors.py

Lines changed: 87 additions & 20 deletions

@@ -1,24 +1,27 @@
+import logging
+from typing import List, Union
+
 import torch
 import torch.nn.functional as F

-_LOGIT_PROCESSOR_MAP = {}
+_LOGITS_PROCESSOR_MAP = {}


-def register_logit_processor(process_type):
+def register_logits_processor(process_type):
     """
     register flops computation function for operation.
     """

     def register(func):
-        global _LOGIT_PROCESSOR_MAP
-        _LOGIT_PROCESSOR_MAP[process_type] = func
+        global _LOGITS_PROCESSOR_MAP
+        _LOGITS_PROCESSOR_MAP[process_type] = func
         return func

     return register


-@register_logit_processor("temperature")
-def temperature_logit_process(logits, temperature: float):
+@register_logits_processor("temperature")
+def temperature_logits_process(logits, temperature: float):
     """
     apply temperature scaling.
     """

@@ -32,8 +35,8 @@ def temperature_logit_process(logits, temperature: float):
     return logits if temperature == 1.0 else logits / temperature


-@register_logit_processor("top_k")
-def top_k_logit_processor(logits, top_k: int):
+@register_logits_processor("top_k")
+def top_k_logits_processor(logits, top_k: int):
     """
     top_k logit processor
     """

@@ -46,8 +49,8 @@ def top_k_logit_processor(logits, top_k: int):
     return logits


-@register_logit_processor("top_p")
-def top_p_logit_processor(logits, top_p: float):
+@register_logits_processor("top_p")
+def top_p_logits_processor(logits, top_p: float):
     """
     top_p logit processor
     """

@@ -68,24 +71,88 @@ def top_p_logit_processor(logits, top_p: float):
     return logits


-def logit_processor(processor: str, logits, attrs):
+@register_logits_processor("forced_bos_token_id")
+def forced_bos_token_processor(
+    logits: torch.Tensor,
+    sequence_lengths: Union[torch.Tensor, List[int]],
+    max_out_lengths: Union[torch.Tensor, List[int]],
+    bos_token_id: int,
+):
+    # NOTE For now, optimizations for encoder-decoder models have not been supported yet,
+    # and this function will never be called in the current implementation.
+    if isinstance(sequence_lengths, torch.Tensor):
+        sequence_lengths = sequence_lengths.tolist()
+    if isinstance(max_out_lengths, torch.Tensor):
+        max_out_lengths = max_out_lengths.tolist()
+
+    select_indexes = []
+    num_sequences = logits.shape[0]
+    sequence_lengths = sequence_lengths[:num_sequences]
+    max_out_lengths = max_out_lengths[:num_sequences]
+    for i, sequence_length in enumerate(sequence_lengths):
+        if sequence_length == 1:
+            select_indexes.append(i)
+    if select_indexes:
+        logits[select_indexes, :] = -float("inf")
+        logits[select_indexes, bos_token_id] = 0
+
+    return logits
+
+
+@register_logits_processor("forced_eos_token_id")
+def forced_eos_token_processor(
+    logits: torch.Tensor,
+    sequence_lengths: Union[torch.Tensor, List[int]],
+    max_out_lengths: Union[torch.Tensor, List[int]],
+    eos_token_id: Union[int, List[int]],
+):
+    """
+    Enforces the specified token as the last generated token when the maximum output length
+    is reached. Notice that the maximum output lengths for different sequences, even if they're
+    in the same batch, can be different.
+
+    Args:
+        logits (torch.Tensor): [num_seqs, vocab_size] The logits tensor.
+        sequence_lengths (Union[torch.Tensor, List[int]]): The current output length of each sequence.
+        max_out_lengths (Union[torch.Tensor, List[int]]): The maximum output length of each sequence.
+        eos_token_id (Union[int, List[int]]): The forced eos token id(s).
+    """
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    if isinstance(sequence_lengths, torch.Tensor):
+        sequence_lengths = sequence_lengths.tolist()
+    if isinstance(max_out_lengths, torch.Tensor):
+        max_out_lengths = max_out_lengths.tolist()
+
+    select_indexes = []
+    num_sequences = logits.shape[0]
+    sequence_lengths = sequence_lengths[:num_sequences]
+    max_out_lengths = max_out_lengths[:num_sequences]
+    for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_out_lengths)):
+        if sequence_length == max_out_length - 1:
+            select_indexes.append(i)
+    if select_indexes:
+        logits[select_indexes, :] = -float("inf")
+        logits[select_indexes, eos_token_id] = 0
+
+    return logits
+
+
+def logits_processor(processor: str, logits, *args):
     """
     do logit process for given logits.

     Args:
         processor(str): the type of logit processor
         logits(torch.Tensor): input logits
-        attrs(dict): attrs of the logit processor

     Returns:
         logits after process
     """
-    if processor not in _LOGIT_PROCESSOR_MAP:
-        return logits
+    if processor not in _LOGITS_PROCESSOR_MAP:
+        logging.warning(f"Unsupported processor {processor}. Fall back to the original logits.")
     else:
-        func = _LOGIT_PROCESSOR_MAP[processor]
-        try:
-            logits = func(logits, attrs)
-        except Exception:
-            return logits
-    return logits
+        func = _LOGITS_PROCESSOR_MAP[processor]
+        logits = func(logits, *args)
+
+    return logits
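
A quick sanity check of the new processor on a toy batch: the sequence sitting one step short of its maximum output length gets its whole row masked to -inf except the EOS position, so sampling can only pick EOS, while the other sequence is left untouched (a standalone sketch; the import path matches the file above):

    import torch

    from colossalai.inference.logit_processors import logits_processor

    logits = torch.zeros(2, 8)
    seq_out_lengths = [4, 2]   # current output lengths of the two sequences
    max_out_lengths = [5, 10]  # seq 0 is one step away from its cap
    eos_token_id = 7

    logits = logits_processor("forced_eos_token_id", logits, seq_out_lengths, max_out_lengths, eos_token_id)

    probs = torch.softmax(logits, dim=-1)
    assert probs[0].argmax().item() == eos_token_id           # seq 0 is forced onto EOS
    assert torch.allclose(probs[1], torch.full((8,), 1 / 8))  # seq 1 keeps a uniform distribution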
