Skip to content

Commit e6496dd

Browse files
authored
[Inference] Optimize request handler of llama (#5512)
* optimize request_handler * fix ways of writing
1 parent 6251d68 commit e6496dd

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

colossalai/inference/core/request_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,8 @@ def search_tokens(self, generation_config: GenerationConfig, logits):
298298
"""
299299
# do logit processor
300300
# NOTE: need to decide the granularity to process logits (sequence or batch)
301+
config_dict = generation_config.to_dict()
301302
for type in ["top_k", "top_p", "min_p"]:
302-
config_dict = generation_config.to_dict()
303303
if type in config_dict and config_dict[type] is not None:
304304
logits = logit_processor(type, logits, config_dict[type])
305305

colossalai/inference/logit_processors.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,21 +36,23 @@ def top_p_logit_processor(logits, top_p: float):
3636
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
3737

3838
sorted_indices_to_remove = cumulative_probs > top_p
39-
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
39+
40+
sorted_indices_to_remove = torch.roll(sorted_indices_to_remove, 1, -1)
4041
sorted_indices_to_remove[..., 0] = 0
4142

4243
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
4344
logits[indices_to_remove] = -float("inf")
4445
return logits
4546

46-
def logit_processor(processor:str, logits , attrs):
47+
48+
def logit_processor(processor: str, logits, attrs):
4749
"""
4850
do logit process for given logits.
4951
5052
Args:
51-
processor(str): the type of logit processor
53+
processor(str): the type of logit processor
5254
logits(torch.Tensor): input logits
53-
attrs(dict): attrs of the logit processor
55+
attrs(dict): attrs of the logit processor
5456
5557
Returns:
5658
logits after process
@@ -61,6 +63,6 @@ def logit_processor(processor:str, logits , attrs):
6163
func = _LOGIT_PROCESSOR_MAP[processor]
6264
try:
6365
logits = func(logits, attrs)
64-
except Exception as e:
66+
except Exception:
6567
return logits
66-
return logits
68+
return logits

0 commit comments

Comments
 (0)