2 changes: 1 addition & 1 deletion colossalai/inference/config.py
@@ -26,7 +26,7 @@

_DEFAULT_PROMPT_TEMPLATES = {
"llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
"vicuna": "USER: {input_text}\n\nASSISTANT: ",
"vicuna": "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user input. USER: {input_text}\nASSISTANT: ",
}


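The templates above are plain Python format strings keyed by model family, each with an `{input_text}` placeholder. A minimal sketch of how such a template could be applied, assuming `str.format()` is all that is needed (the actual engine may wire prompts differently):

```python
# Minimal usage sketch; the real engine may apply templates elsewhere in the pipeline.
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES

prompt = _DEFAULT_PROMPT_TEMPLATES["vicuna"].format(
    input_text="What is speculative decoding?"
)
print(prompt)
```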
56 changes: 53 additions & 3 deletions colossalai/inference/core/engine.py
@@ -9,7 +9,7 @@
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import InferenceConfig
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.spec import Drafter
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -66,6 +66,7 @@ def __init__(
self.use_spec_dec = False
self.drafter_model = None
self.drafter = None
self.use_glide = False
self.n_spec_tokens = self.inference_config.max_n_spec_tokens

if model_policy is None:
@@ -141,14 +142,21 @@ def _shardformer(
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model

def enable_spec_dec(self, drafter_model: nn.Module = None, n_spec_tokens: int = None) -> None:
def enable_spec_dec(
self,
drafter_model: nn.Module = None,
n_spec_tokens: int = None,
use_glide_drafter: bool = False,
) -> None:
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.

Args:
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
If provided, the previous drafter and drafter model, if they exist, will be overwritten.
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
use_glide_drafter (bool): Whether to use a GLIDE model for speculative decoding. Defaults to False.
If True, the provided drafter model is expected to follow the GLIDE structure; otherwise the engine falls back to the default (non-GLIDE) drafter.

```python
...
@@ -181,6 +189,22 @@ def enable_spec_dec(self, drafter_model: nn.Module = None, n_spec_tokens: int =
device=self.device,
dtype=self.dtype,
)

# check if the provided drafter model is compatible with GLIDE structure
# when `use_glide_drafter` is set to True
if (
use_glide_drafter
and hasattr(drafter_model, "model")
and hasattr(drafter_model.model, "layers")
and hasattr(drafter_model.model.layers[0], "cross_attn")
):
self.use_glide = use_glide_drafter
elif use_glide_drafter:
self.logger.warning(
f"`use_glide_drafter` is provided as {use_glide_drafter}, "
f"but the provided drafter model is not compatible with GLIDE structure."
f"Falling back to use the default drafter model (non-GLIDE)."
)
self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
# using speculative decoding for subsequent generations
self.use_spec_dec = True
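A hypothetical call sequence for the new flag; `engine` and `glide_drafter` are placeholders and are not defined in this diff:

```python
# Hypothetical usage sketch for `use_glide_drafter`.
engine.enable_spec_dec(
    drafter_model=glide_drafter,   # a drafter whose layers expose `cross_attn`
    n_spec_tokens=5,
    use_glide_drafter=True,
)
# If the drafter lacks `model.layers[*].cross_attn`, a warning is logged and the
# engine falls back to the default (non-GLIDE) drafting path.
outputs = engine.generate()        # subsequent generations use speculative decoding
engine.disable_spec_dec()          # resets n_spec_tokens and use_glide
```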
@@ -190,6 +214,7 @@ def disable_spec_dec(self) -> None:
self.request_handler.unset_spec_dec_mode()
# set back to the maximum number of tokens to speculate
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self.use_glide = False
self.use_spec_dec = False

def clear_spec_dec(self) -> None:
@@ -200,6 +225,7 @@ def clear_spec_dec(self) -> None:
self.drafter_model = None
self.drafter = None
torch.cuda.empty_cache()
self.use_glide = False
self.use_spec_dec = False

def steps_spec_dec(self) -> List[Sequence]:
@@ -216,6 +242,7 @@ def steps_spec_dec(self) -> List[Sequence]:
input_ids = batch.get_1D_inputs() # bsz 1 for drafter model

# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
# NOTE For glide drafter models, glide is not actually applied during the prefill stage
drafter_out = self.drafter.speculate(input_ids, 1, None)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
@@ -238,7 +265,21 @@ def steps_spec_dec(self) -> List[Sequence]:
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."

# 3. Decoding - Drafter model speculates `n` tokens
drafter_out = self.drafter.speculate(input_ids, self.n_spec_tokens, drafter_past_key_values)
glide_input = None
if self.use_glide:
glide_input = GlideInput(
batch.get_block_table_tensor(),
self.k_cahce[-1], # use the kv caches of the last layer of the main model
self.v_cache[-1],
batch.get_sequence_lengths(),
)

drafter_out = self.drafter.speculate(
input_ids,
self.n_spec_tokens,
drafter_past_key_values,
glide_input=glide_input,
)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
drafter_spec_length = drafter_out.speculated_length
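From the positional arguments above, `GlideInput` bundles the running batch's block table, the main model's last-layer KV caches, and the per-sequence lengths. A rough guess at its shape, inferred only from the call above; the field names are assumptions, not the actual definition in `colossalai.inference.spec`:

```python
# Illustrative guess at what GlideInput carries; names/types are assumptions.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class GlideInputSketch:
    block_tables: Optional[torch.Tensor] = None      # block table of the running batch
    large_k_cache: Optional[torch.Tensor] = None     # last-layer key cache of the main model
    large_v_cache: Optional[torch.Tensor] = None     # last-layer value cache of the main model
    sequence_lengths: Optional[torch.Tensor] = None  # current lengths of the batched sequences
```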
@@ -251,6 +292,8 @@ def steps_spec_dec(self) -> List[Sequence]:
already_allocated_kv_len = cur_length

# 4. Decoding - Main model verifies `n` tokens in parallel
if drafter_spec_length < batch.num_tokens_to_verify:
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
logits = self.model(batch, self.k_cahce, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)

@@ -260,13 +303,15 @@ def steps_spec_dec(self) -> List[Sequence]:

# revoke appended tokens for each Sequence in the current batch
batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens

# append the last correct token generated by the main model
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))

# trim past key values of the drafter model
drafter_past_key_values = Drafter.trim_kv_cache(
drafter_past_key_values, drafter_spec_length - n_matches - 1
)

# prepare inputs for the next round of speculation
n = 1 if n_matches < drafter_spec_length else 2
input_ids = batch.get_1D_inputs_spec_dec(n)
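A worked example of the bookkeeping in this hunk, with illustrative numbers only:

```python
# Illustrative numbers: drafter proposed 5 tokens, main model accepted 3 of them.
drafter_spec_length, n_matches = 5, 3
revoked = drafter_spec_length - n_matches              # 2 drafted tokens rolled back from the batch
trimmed = drafter_spec_length - n_matches - 1          # 1 stale entry dropped from the drafter's KV cache
n_next = 1 if n_matches < drafter_spec_length else 2   # tokens fed to the next speculation round
print(revoked, trimmed, n_next)  # -> 2 1 1
```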
@@ -276,6 +321,11 @@ def steps_spec_dec(self) -> List[Sequence]:
if len(finished_sequences) > 0:
break

# Reset the batch's number of tokens to verify back to the engine setting.
# This handles the last round of speculation, where the drafter may have speculated
# fewer tokens than the number configured for the engine.
batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)

return finished_sequences

def generate(