4 changes: 2 additions & 2 deletions colossalai/inference/README.md
@@ -133,7 +133,7 @@ We offer 3 main sampling strategies now (i.e. `greedy sample`, `multinomial samp

| Model | KV Cache | Paged Attention | Kernels | Tensor Parallelism | Speculative Decoding |
| - | - | - | - | - | - |
| Llama | ✅ | ✅ | ✅ | 🔜 | 🔜 |
| Llama | ✅ | ✅ | ✅ | 🔜 | ✅ |


Notations:
@@ -148,7 +148,7 @@ Notations:
- [x] High-Performance Kernels
- [x] Llama Modelling
- [x] User Documentation
- [ ] Speculative Decoding
- [x] Speculative Decoding
- [ ] Tensor Parallelism
- [ ] Beam Search
- [ ] Early stopping
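The roadmap above now marks Speculative Decoding as done. For readers new to the technique, below is a toy, self-contained sketch of the greedy-verification idea it relies on (a small drafter proposes a few tokens, the main model checks them in one parallel pass). The random tensors stand in for real model outputs; nothing here uses ColossalAI APIs.

```python
import torch

torch.manual_seed(0)

# Toy stand-ins: the drafter proposes k tokens cheaply; the main model scores
# the same positions in a single parallel forward pass. Both are random here.
vocab_size, k = 100, 5
drafter_tokens = torch.randint(0, vocab_size, (k,))
main_model_tokens = torch.randint(0, vocab_size, (k,))

# Greedy verification: accept the longest prefix on which both agree, then
# take the main model's token at the first disagreement as a bonus token.
n_accepted = 0
for d, m in zip(drafter_tokens.tolist(), main_model_tokens.tolist()):
    if d != m:
        break
    n_accepted += 1

new_tokens = drafter_tokens[:n_accepted].tolist()
if n_accepted < k:
    new_tokens.append(main_model_tokens[n_accepted].item())
print(f"accepted {n_accepted} draft tokens, emitted {len(new_tokens)} tokens total")
```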
62 changes: 60 additions & 2 deletions colossalai/inference/batch_bucket.py
@@ -42,6 +42,9 @@ def __init__(
self.device = device or get_current_device()
self.dtype = dtype

self._use_spec_dec = False
self._num_tokens_to_verify = None

self._current_batch_size = 0
self._sequences_dict = dict()
self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size)
@@ -88,6 +91,27 @@ def is_compact(self):
== torch.nonzero(self._block_tables[:, 0] >= 0).numel()
)

@property
def use_spec_dec(self) -> bool:
return self._use_spec_dec

@property
def num_tokens_to_verify(self) -> int:
return self._num_tokens_to_verify

def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
"""Set the batch bucket to use speculative decoding.
This will adjust the lengths of inputs during modeling
and let the main model verify tokens in parallel.
"""
self._use_spec_dec = True
self._num_tokens_to_verify = num_tokens_to_verify

def reset_use_spec_dec(self) -> None:
"""Reset the usage of speculative decoding for the batch bucket"""
self._use_spec_dec = False
self._num_tokens_to_verify = None

def _make_compact(self) -> None:
# Clean and Compress the batch based on its sequences dict.
# Namely,compress sequences to the front and clean the seq lengths and block tables tensors.
@@ -347,6 +371,23 @@ def append_batch_tokens(self, tokens: torch.Tensor) -> None:
seq.check_finish()
self._sequence_lengths[: self.current_batch_size] += 1

def revoke_batch_tokens(self, n_tokens: int, n_seqs: int = 1) -> None:
"""Revoke the last n output tokens of the sequences in the batch

Args:
n_tokens (int): The number of output tokens to revoke from each sequence.
Context (input) tokens are not counted.
n_seqs (int): Revoke tokens from the first `n_seqs` sequences in the batch. Defaults to 1.
For now, speculative decoding only supports batch size 1.
"""
if n_tokens >= 1:
seqs_iter = iter(self._sequences_dict.items())
for _ in range(n_seqs):
seq_id, seq = next(seqs_iter)
assert seq.output_len >= n_tokens, "Number of tokens to revoke exceeds the sequence's current output length"
seq.output_token_id = seq.output_token_id[:-n_tokens]
self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens

def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
"""Clear all the sequences in the batch.

@@ -401,6 +442,21 @@ def is_prompts(self) -> bool:
return True
return False

def get_1D_inputs_spec_dec(self, n: int) -> torch.Tensor:
# Used for main-model verification in the **Decoding Stage**.
# `n` is the number of tokens to be verified;
# the last `n` tokens of each sequence are prepared as the inputs.
assert len(self._sequences_dict) > 0, "No sequence in the batch"
assert all(
seq.output_len >= n for seq in self._sequences_dict.values()
), "Sequence output tokens must be greater than or equal to the number of tokens to be verified."
out_li = []
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
for seq_id in seq_ids:
seq: Sequence = self._sequences_dict[seq_id]
out_li.extend(seq.output_token_id[-n:])
return torch.tensor(out_li, dtype=torch.long, device=self.device)

# For compatibility
def get_1D_inputs(self) -> torch.Tensor:
assert len(self._sequences_dict) > 0, "No sequence in the batch"
@@ -411,15 +467,17 @@ def get_1D_inputs(self) -> torch.Tensor:
seq.output_len == 0 for seq in self._sequences_dict.values()
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
out_li = []
num_tokens = torch.sum(self._sequence_lengths)
out = torch.empty([num_tokens], dtype=torch.long)
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
for seq_id in seq_ids:
seq: Sequence = self._sequences_dict[seq_id]
out_li.extend(seq.input_token_id)
return torch.tensor(out_li, dtype=torch.long, device=self.device)
else:
# Assume decoding stage
if self.use_spec_dec:
# For speculative decoding: the inputs are the tokens to be verified in parallel
# plus the token accepted in the last step
return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1)
assert all(
seq.output_len > 0 for seq in self._sequences_dict.values()
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
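To make the new `BatchBucket` hooks above concrete, here is a plain-tensor sketch of what `set_use_spec_dec`, `get_1D_inputs_spec_dec`, and `revoke_batch_tokens` do to a single sequence's state. It deliberately avoids constructing a real `BatchBucket` (its full constructor is not shown in this diff); the token values and counts are illustrative only.

```python
import torch

# Single-sequence state mirroring what BatchBucket tracks (batch size 1,
# which is all speculative decoding supports for now). Values are illustrative.
context_len = 4                                # prompt (input) tokens
output_token_id = [11, 12, 13, 14, 15, 16]     # generated tokens so far
sequence_length = context_len + len(output_token_id)

num_tokens_to_verify = 5                       # as set via set_use_spec_dec(5)

# get_1D_inputs_spec_dec(n) with n = num_tokens_to_verify + 1: the last n output
# tokens become the main model's verification input (draft tokens plus the
# last accepted token).
n = num_tokens_to_verify + 1
verify_inputs = torch.tensor(output_token_id[-n:], dtype=torch.long)

# Suppose the main model rejects the last 3 draft tokens; revoke_batch_tokens(3)
# would trim them from the sequence and shrink the recorded length accordingly.
n_rejected = 3
output_token_id = output_token_id[:-n_rejected]
sequence_length -= n_rejected
print(verify_inputs, output_token_id, sequence_length)
```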
25 changes: 23 additions & 2 deletions colossalai/inference/config.py
@@ -26,7 +26,7 @@

_DEFAULT_PROMPT_TEMPLATES = {
"llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
"vicuna": "USER: {input_text}\n\nASSISTANT: ",
"vicuna": "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user input. USER: {input_text}\nASSISTANT: ",
}


@@ -46,6 +46,8 @@ class InputMetaData:
head_dim (int, optional): Head dimension. Defaults to 32.
high_precision(bool, optional): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, Defaults to False.
dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32.
use_spec_dec (bool): Indicate whether to use speculative decoding.
num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True.
"""

block_tables: torch.Tensor = None
@@ -59,9 +61,22 @@ class InputMetaData:
head_dim: int = 32
high_precision: bool = False
dtype: torch.dtype = torch.float32
use_spec_dec: bool = False
num_tokens_to_verify: int = 0

def __repr__(self) -> str:
return f"InputMetaData(block_tables={self.block_tables}, sequence_lengths={self.sequence_lengths}, fd_inter_tensor={self.fd_inter_tensor}, batch_size={self.batch_size}, is_prompts={self.is_prompts}, use_cuda_graph={self.use_cuda_graph}, kv_seq_len={self.kv_seq_len}, head_dim={self.head_dim})"
return (
f"InputMetaData(block_tables={self.block_tables}, "
f"sequence_lengths={self.sequence_lengths}, "
f"fd_inter_tensor={self.fd_inter_tensor}, "
f"batch_size={self.batch_size}, "
f"is_prompts={self.is_prompts}, "
f"use_cuda_kernel={self.use_cuda_kernel}, "
f"use_cuda_graph={self.use_cuda_graph}, "
f"kv_seq_len={self.kv_seq_len}, "
f"use_spec_dec={self.use_spec_dec}, "
f"num_tokens_to_verify={self.num_tokens_to_verify})"
)


@dataclass
Expand All @@ -84,6 +99,8 @@ class InferenceConfig:
top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
max_n_spec_tokens (int): The maximum number of speculative tokens, defaults to 5.
glimpse_large_kv (bool): Whether to use large KV in the drafter model, defaults to False.
block_size (int): The number of blocks in a logical block, defaults to 16.
tp_size (int): Tensor parallel size, defaults to 1.
pp_size (int): Pipeline parallel size, defaults to 1.
@@ -118,6 +135,10 @@ class InferenceConfig:
top_p: Optional[float] = None
min_p: Optional[float] = None

# speculative decoding configs
max_n_spec_tokens: int = 5
glimpse_large_kv: bool = False

# paged attention configs
block_size: int = 16

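A minimal sketch of how the new speculative-decoding fields in `InferenceConfig` and `InputMetaData` might be exercised. It assumes the remaining dataclass fields keep their defaults and that default construction passes validation; exact requirements may differ by ColossalAI version.

```python
# Hedged usage sketch, not part of this diff: exercise the new fields only.
import torch

from colossalai.inference.config import InferenceConfig, InputMetaData

# Assumes all other InferenceConfig fields have usable defaults.
inference_config = InferenceConfig(max_n_spec_tokens=5, glimpse_large_kv=False)

input_meta = InputMetaData(
    batch_size=1,
    is_prompts=False,
    dtype=torch.float16,
    use_spec_dec=True,
    num_tokens_to_verify=inference_config.max_n_spec_tokens,
)
print(input_meta)  # the updated __repr__ now reports the speculative-decoding fields
```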