
Commit 25928d8

[Inference/Spec-Dec] Merge pull request #5565 from hpcaitech/feat/speculative-decoding
Add Speculative Decoding and GLIDE Spec-Dec
2 parents d56c963 + f8598e3 commit 25928d8

22 files changed: +1688 −192 lines

colossalai/inference/README.md

Lines changed: 2 additions & 2 deletions
@@ -133,7 +133,7 @@ We offer 3 main sampling strategies now (i.e. `greedy sample`, `multinomial samp
 
 | Model | KV Cache | Paged Attention | Kernels | Tensor Parallelism | Speculative Decoding |
 | - | - | - | - | - | - |
-| Llama | ✅ | ✅ | ✅ | 🔜 | 🔜 |
+| Llama | ✅ | ✅ | ✅ | 🔜 | ✅ |
 
 
 Notations:
@@ -148,7 +148,7 @@ Notations:
 - [x] High-Performance Kernels
 - [x] Llama Modelling
 - [x] User Documentation
-- [ ] Speculative Decoding
+- [x] Speculative Decoding
 - [ ] Tensor Parallelism
 - [ ] Beam Search
 - [ ] Early stopping
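
Speculative decoding, now marked as supported above, follows a draft-then-verify loop: a small drafter proposes a few tokens, and the main model re-scores all of them in a single forward pass, keeping only the prefix it agrees with. The snippet below is a minimal, self-contained sketch of that greedy verification rule, not the library's implementation; `greedy_verify`, `draft`, and `logits` are illustrative names.

```python
import torch

def greedy_verify(draft_tokens: torch.Tensor, main_logits: torch.Tensor) -> torch.Tensor:
    """Accept the longest prefix of drafted tokens the main model agrees with.

    draft_tokens: (n,) token ids proposed by the drafter.
    main_logits:  (n + 1, vocab_size) logits from the main model's single
                  verification pass over the drafted tokens.
    Returns the accepted tokens plus one bonus token from the main model.
    """
    main_tokens = main_logits.argmax(dim=-1)                      # main model's greedy picks
    n = draft_tokens.numel()
    matches = (draft_tokens == main_tokens[:n]).int().cumprod(dim=0)
    n_accepted = int(matches.sum().item())                        # length of the agreed prefix
    # Keep the agreed prefix, then append the main model's token at the first
    # mismatch (or its extra token if every drafted token was accepted).
    return torch.cat([draft_tokens[:n_accepted], main_tokens[n_accepted : n_accepted + 1]])

# Example: the drafter guesses 5 tokens and the main model agrees with the first 3.
draft = torch.tensor([11, 12, 13, 99, 98])
logits = torch.zeros(6, 100)
logits[torch.arange(6), torch.tensor([11, 12, 13, 14, 15, 16])] = 1.0
print(greedy_verify(draft, logits))  # tensor([11, 12, 13, 14])
```

In the best case every drafted token is accepted and a step emits n + 1 tokens for a single main-model forward pass, which is where the speed-up comes from.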

colossalai/inference/batch_bucket.py

Lines changed: 60 additions & 2 deletions
@@ -42,6 +42,9 @@ def __init__(
         self.device = device or get_current_device()
         self.dtype = dtype
 
+        self._use_spec_dec = False
+        self._num_tokens_to_verify = None
+
         self._current_batch_size = 0
         self._sequences_dict = dict()
         self._sequences_indexes = dict()  # deque(maxlen=self.max_batch_size)
@@ -88,6 +91,27 @@ def is_compact(self):
             == torch.nonzero(self._block_tables[:, 0] >= 0).numel()
         )
 
+    @property
+    def use_spec_dec(self) -> bool:
+        return self._use_spec_dec
+
+    @property
+    def num_tokens_to_verify(self) -> int:
+        return self._num_tokens_to_verify
+
+    def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
+        """Set the batch bucket to use speculative decoding.
+        This adjusts the input lengths used during modeling and lets the main
+        model verify tokens in parallel.
+        """
+        self._use_spec_dec = True
+        self._num_tokens_to_verify = num_tokens_to_verify
+
+    def reset_use_spec_dec(self) -> None:
+        """Reset the usage of speculative decoding for the batch bucket."""
+        self._use_spec_dec = False
+        self._num_tokens_to_verify = None
+
     def _make_compact(self) -> None:
         # Clean and Compress the batch based on its sequences dict.
         # Namely, compress sequences to the front and clean the seq lengths and block tables tensors.
@@ -347,6 +371,23 @@ def append_batch_tokens(self, tokens: torch.Tensor) -> None:
             seq.check_finish()
         self._sequence_lengths[: self.current_batch_size] += 1
 
+    def revoke_batch_tokens(self, n_tokens: int, n_seqs: int = 1) -> None:
+        """Revoke the last n output tokens of the sequences in the batch.
+
+        Args:
+            n_tokens (int): The number of output tokens to revoke from each sequence.
+                Context (input) tokens are not counted.
+            n_seqs (int): Revoke tokens from the first n sequences. Defaults to 1.
+                For now, speculative decoding only supports batch size 1.
+        """
+        if n_tokens >= 1:
+            seqs_iter = iter(self._sequences_dict.items())
+            for _ in range(n_seqs):
+                seq_id, seq = next(seqs_iter)
+                assert seq.output_len >= n_tokens, "Revoking len exceeds the current output len of the sequence"
+                seq.output_token_id = seq.output_token_id[:-n_tokens]
+                self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens
+
     def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
         """Clear all the sequences in the batch.
 
@@ -401,6 +442,21 @@ def is_prompts(self) -> bool:
                 return True
         return False
 
+    def get_1D_inputs_spec_dec(self, n: int) -> torch.Tensor:
+        # Used for main-model verification in the **decoding stage**.
+        # `n` is the number of tokens to be verified, so the last `n` tokens
+        # of each sequence are prepared as the inputs.
+        assert len(self._sequences_dict) > 0, "No sequence in the batch"
+        assert all(
+            seq.output_len >= n for seq in self._sequences_dict.values()
+        ), "Sequence output tokens must be greater than or equal to the number of tokens to be verified."
+        out_li = []
+        seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
+        for seq_id in seq_ids:
+            seq: Sequence = self._sequences_dict[seq_id]
+            out_li.extend(seq.output_token_id[-n:])
+        return torch.tensor(out_li, dtype=torch.long, device=self.device)
+
     # For compatibility
     def get_1D_inputs(self) -> torch.Tensor:
         assert len(self._sequences_dict) > 0, "No sequence in the batch"
@@ -411,15 +467,17 @@ def get_1D_inputs(self) -> torch.Tensor:
                 seq.output_len == 0 for seq in self._sequences_dict.values()
             ), "Sequence stage (Prefill/Decoding) must be the same in the batch"
             out_li = []
-            num_tokens = torch.sum(self._sequence_lengths)
-            out = torch.empty([num_tokens], dtype=torch.long)
             seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
             for seq_id in seq_ids:
                 seq: Sequence = self._sequences_dict[seq_id]
                 out_li.extend(seq.input_token_id)
             return torch.tensor(out_li, dtype=torch.long, device=self.device)
         else:
             # Assume decoding stage
+            if self.use_spec_dec:
+                # For speculative decoding: feed the drafted tokens to be verified
+                # in parallel, plus the token accepted in the previous step.
+                return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1)
             assert all(
                 seq.output_len > 0 for seq in self._sequences_dict.values()
             ), "Sequence stage (Prefill/Decoding) must be the same in the batch"

colossalai/inference/config.py

Lines changed: 23 additions & 2 deletions
@@ -26,7 +26,7 @@
 
 _DEFAULT_PROMPT_TEMPLATES = {
     "llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
-    "vicuna": "USER: {input_text}\n\nASSISTANT: ",
+    "vicuna": "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user input. USER: {input_text}\nASSISTANT: ",
 }
 
 
@@ -46,6 +46,8 @@ class InputMetaData:
         head_dim (int, optional): Head dimension. Defaults to 32.
         high_precision (bool, optional): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, Defaults to False.
         dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32.
+        use_spec_dec (bool): Whether to use speculative decoding.
+        num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True.
     """
 
     block_tables: torch.Tensor = None
@@ -59,9 +61,22 @@ class InputMetaData:
     head_dim: int = 32
     high_precision: bool = False
     dtype: torch.dtype = torch.float32
+    use_spec_dec: bool = False
+    num_tokens_to_verify: int = 0
 
     def __repr__(self) -> str:
-        return f"InputMetaData(block_tables={self.block_tables}, sequence_lengths={self.sequence_lengths}, fd_inter_tensor={self.fd_inter_tensor}, batch_size={self.batch_size}, is_prompts={self.is_prompts}, use_cuda_graph={self.use_cuda_graph}, kv_seq_len={self.kv_seq_len}, head_dim={self.head_dim})"
+        return (
+            f"InputMetaData(block_tables={self.block_tables}, "
+            f"sequence_lengths={self.sequence_lengths}, "
+            f"fd_inter_tensor={self.fd_inter_tensor}, "
+            f"batch_size={self.batch_size}, "
+            f"is_prompts={self.is_prompts}, "
+            f"use_cuda_kernel={self.use_cuda_kernel}, "
+            f"use_cuda_graph={self.use_cuda_graph}, "
+            f"kv_seq_len={self.kv_seq_len}, "
+            f"use_spec_dec={self.use_spec_dec}, "
+            f"num_tokens_to_verify={self.num_tokens_to_verify})"
+        )
 
 
 @dataclass
@@ -84,6 +99,8 @@ class InferenceConfig:
         top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
         top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
         min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
+        max_n_spec_tokens (int): The maximum number of tokens to speculate per step, defaults to 5.
+        glimpse_large_kv (bool): Whether to use large KV in the drafter model, defaults to False.
         block_size (int): The number of blocks in a logical block, defaults to 16.
         tp_size (int): Tensor parallel size, defaults to 1.
         pp_size (int): Pipeline parallel size, defaults to 1.
@@ -118,6 +135,10 @@ class InferenceConfig:
     top_p: Optional[float] = None
     min_p: Optional[float] = None
 
+    # speculative decoding configs
+    max_n_spec_tokens: int = 5
+    glimpse_large_kv: bool = False
+
     # paged attention configs
     block_size: int = 16
 
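
The two new dataclass fields plug into the rest of the stack roughly as follows: `max_n_spec_tokens` bounds how far the drafter runs ahead, while `InputMetaData.use_spec_dec` and `num_tokens_to_verify` tell a decoding-stage forward pass to verify drafted tokens in parallel. Below is a hedged construction sketch, assuming both classes accept keyword arguments and that fields not shown in this diff keep their defaults.

```python
from colossalai.inference.config import InferenceConfig, InputMetaData

# Engine-level knobs added by this PR: draft at most 5 tokens per step and keep
# the drafter's KV behaviour at its default (glimpse_large_kv=False).
inference_config = InferenceConfig(
    max_n_spec_tokens=5,
    glimpse_large_kv=False,
    block_size=16,
)

# Per-forward metadata for a verification pass in the decoding stage: the main
# model re-scores `num_tokens_to_verify` drafted tokens in parallel.
input_meta = InputMetaData(
    batch_size=1,
    is_prompts=False,
    use_spec_dec=True,
    num_tokens_to_verify=inference_config.max_n_spec_tokens,
)
print(input_meta)  # the new __repr__ reports use_spec_dec and num_tokens_to_verify
```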
