Skip to content

Commit e60d430

Browse files
committed
[Fix] resolve conflicts of rebasing feat/speculative-decoding (#5557)
- resolve conflicts of rebasing feat/speculative-decoding
1 parent e1acb58 commit e60d430

File tree

6 files changed

+47
-35
lines changed

6 files changed

+47
-35
lines changed

colossalai/inference/batch_bucket.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ def use_spec_dec(self) -> bool:
9797

9898
@property
9999
def num_tokens_to_verify(self) -> int:
100-
assert self.use_spec_dec and self._num_tokens_to_verify is not None
101100
return self._num_tokens_to_verify
102101

103102
def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:

colossalai/inference/config.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class InputMetaData:
4646
head_dim (int, optional): Head dimension. Defaults to 32.
4747
high_precision(bool, optional): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, Defaults to False.
4848
dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32.
49+
use_spec_dec (bool): Indicate whether to use speculative decoding.
50+
num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True.
4951
"""
5052

5153
block_tables: torch.Tensor = None
@@ -59,9 +61,22 @@ class InputMetaData:
5961
head_dim: int = 32
6062
high_precision: bool = False
6163
dtype: torch.dtype = torch.float32
64+
use_spec_dec: bool = False
65+
num_tokens_to_verify: int = 0
6266

6367
def __repr__(self) -> str:
64-
return f"InputMetaData(block_tables={self.block_tables}, sequence_lengths={self.sequence_lengths}, fd_inter_tensor={self.fd_inter_tensor}, batch_size={self.batch_size}, is_prompts={self.is_prompts}, use_cuda_graph={self.use_cuda_graph}, kv_seq_len={self.kv_seq_len}, head_dim={self.head_dim})"
68+
return (
69+
f"InputMetaData(block_tables={self.block_tables}, "
70+
f"sequence_lengths={self.sequence_lengths}, "
71+
f"fd_inter_tensor={self.fd_inter_tensor}, "
72+
f"batch_size={self.batch_size}, "
73+
f"is_prompts={self.is_prompts}, "
74+
f"use_cuda_kernel={self.use_cuda_kernel}, "
75+
f"use_cuda_graph={self.use_cuda_graph}, "
76+
f"kv_seq_len={self.kv_seq_len}, "
77+
f"use_spec_dec={self.use_spec_dec}, "
78+
f"num_tokens_to_verify={self.num_tokens_to_verify})"
79+
)
6580

6681

6782
@dataclass

colossalai/inference/core/engine.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -325,24 +325,29 @@ def steps_spec_dec(self) -> List[Sequence]:
325325
List[Sequence]: finished sequences generated by one step.
326326
"""
327327
batch = self.request_handler.schedule() # prefill batch
328-
329328
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
330-
input_ids = batch.get_1D_inputs() # bsz 1 for drafter model
329+
330+
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
331+
332+
if input_meta_data.use_cuda_graph:
333+
model_executable = self.graph_runners[input_meta_data.batch_size]
334+
else:
335+
model_executable = self.model
331336

332337
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
333338
# NOTE For glide drafter models, we won't actually apply glide during prefill stage
334-
drafter_out = self.drafter.speculate(input_ids, 1, None)
339+
drafter_out = self.drafter.speculate(input_token_ids, 1, None)
335340
next_token_ids_spec = drafter_out.next_tokens
336341
drafter_past_key_values = drafter_out.past_key_values
337342

338343
# 2. Prefill main model (Verifier) - fill past kv cache for main model
339-
logits = self.model(batch, self.k_cahce, self.v_cache)
344+
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
340345
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
341346
# append new inputs to the batch, temporarily
342347
batch.append_batch_tokens(next_tokens)
343348
self.request_handler.allocate_batch_spec_dec(batch, 1)
344349
already_allocated_kv_len = batch.seq_lengths[0].item()
345-
input_ids = batch.get_1D_inputs_spec_dec(1)
350+
input_token_ids = batch.get_1D_inputs_spec_dec(1)
346351

347352
finished_sequences = self.request_handler.update()
348353

@@ -357,13 +362,13 @@ def steps_spec_dec(self) -> List[Sequence]:
357362
if self.use_glide:
358363
glide_input = GlideInput(
359364
batch.get_block_table_tensor(),
360-
self.k_cahce[-1], # use kv cahces of the last layer
365+
self.k_cache[-1], # use kv cahces of the last layer
361366
self.v_cache[-1],
362367
batch.get_sequence_lengths(),
363368
)
364369

365370
drafter_out = self.drafter.speculate(
366-
input_ids,
371+
input_token_ids,
367372
self.n_spec_tokens,
368373
drafter_past_key_values,
369374
glide_input=glide_input,
@@ -382,7 +387,9 @@ def steps_spec_dec(self) -> List[Sequence]:
382387
# 4. Decoding - Main model verifies `n` tokens in parallel
383388
if drafter_spec_length < batch.num_tokens_to_verify:
384389
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
385-
logits = self.model(batch, self.k_cahce, self.v_cache)
390+
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
391+
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
392+
386393
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
387394

388395
# 5. Compare and process the results
@@ -402,7 +409,7 @@ def steps_spec_dec(self) -> List[Sequence]:
402409

403410
# prepare inputs for the next round of speculation
404411
n = 1 if n_matches < drafter_spec_length else 2
405-
input_ids = batch.get_1D_inputs_spec_dec(n)
412+
input_token_ids = batch.get_1D_inputs_spec_dec(n)
406413

407414
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
408415
finished_sequences = self.request_handler.update()
@@ -564,18 +571,19 @@ def add_request(
564571

565572
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
566573
input_ids = batch.get_1D_inputs()
567-
568574
sequence_lengths = batch.get_sequence_lengths()
575+
569576
if batch.is_prompts:
570-
output_tensor = torch.zeros(
571-
(sequence_lengths.sum().item(), batch.num_heads * batch.head_dim),
572-
dtype=batch.dtype,
573-
device=batch.device,
574-
)
577+
n_tokens = sequence_lengths.sum().item()
575578
else:
576-
output_tensor = torch.zeros(
577-
(batch.current_batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
578-
)
579+
n_tokens = batch.current_batch_size
580+
if batch.use_spec_dec:
581+
n_tokens = batch.num_tokens_to_verify + 1
582+
assert n_tokens == input_ids.size(0)
583+
n_tokens = n_tokens * batch.current_batch_size
584+
output_tensor = torch.zeros(
585+
(n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
586+
)
579587

580588
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
581589
use_cuda_graph = False
@@ -594,6 +602,8 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor,
594602
kv_seq_len=sequence_lengths.max().item(),
595603
head_dim=batch.head_dim,
596604
dtype=batch.dtype,
605+
use_spec_dec=batch.use_spec_dec,
606+
num_tokens_to_verify=batch.num_tokens_to_verify,
597607
)
598608

599609
return input_ids, output_tensor, input_meta_data

colossalai/inference/modeling/models/nopadding_llama.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,11 @@ def llama_model_forward(
109109
# For speculative-decoding Prefill and Verifying Stage
110110
if inputmetadata.is_prompts:
111111
# output tensor shape is the same as normal Prefill Stage
112-
o_tensor_size = (sequence_lengths.sum().item(), inputmetadata.num_heads * inputmetadata.head_dim)
113112
rotary_indexes = [torch.arange(0, length) for length in sequence_lengths]
114113
else:
115114
# the number of tokens to be verified in parallel plus the correct token in the last step
116115
n_tokens = inputmetadata.num_tokens_to_verify + 1
117116
assert n_tokens == hidden_states.size(0)
118-
o_tensor_size = (batch_size * n_tokens, inputmetadata.num_heads * inputmetadata.head_dim)
119117
rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths]
120118
rotary_indexes = torch.cat(rotary_indexes, dim=-1)
121119
cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])
@@ -135,15 +133,6 @@ def llama_model_forward(
135133
else:
136134
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
137135

138-
# TODO (yuanheng-zhao): revise the logic here
139-
# if batch.is_prompts:
140-
# output_tensor = torch.zeros(
141-
# (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
142-
# )
143-
# else:
144-
# output_tensor = torch.zeros(
145-
# (batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
146-
# )
147136
sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
148137

149138
norm_output = torch.empty_like(hidden_states)
@@ -239,7 +228,6 @@ def llama_decoder_layer_forward(
239228
sequence_lengths=sequence_lengths,
240229
cos_sin=cos_sin,
241230
fd_inter_tensor=fd_inter_tensor,
242-
is_prompts=is_prompts,
243231
kv_seq_len=kv_seq_len,
244232
output_tensor=output_tensor,
245233
sm_scale=sm_scale,

tests/test_infer/test_ops/triton/test_decoding_attn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,5 +138,6 @@ def test_flash_decoding(
138138
assert out_torch.shape == out_triton.shape
139139
assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)
140140

141+
141142
if __name__ == "__main__":
142143
test_flash_decoding(16, 32, 32, 16, 1, True)

tests/test_infer/test_ops/triton/test_kvcache_copy.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import torch
33
from packaging import version
44

5-
from colossalai.inference.modeling.layers.attention import copy_to_cache
65
from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
76
from colossalai.utils import get_current_device
87
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
@@ -28,8 +27,8 @@ def prepare_data(
2827
max_num_blocks_per_seq,
2928
same_context_len,
3029
max_seq_len,
31-
n,
32-
device,
30+
n=1,
31+
device="cuda",
3332
dtype=torch.float16,
3433
):
3534
assert max_seq_len > n, "max_seq_len must be greater than n"

0 commit comments

Comments
 (0)