
Commit 23c03db

[Inference/SpecDec] Support GLIDE Drafter Model (#5455)
* add glide-llama policy and modeling
* update glide modeling, compatible with transformers 4.36.2
* revise glide llama modeling/usage
* fix issues of glimpsing large kv
* revise the way of re-loading params for glide drafter
* fix drafter and engine tests
* enable convert to glide strict=False
* revise glide llama modeling
* revise vicuna prompt template
* revise drafter and tests
* apply usage of glide model in engine
1 parent b8016a9 commit 23c03db

File tree: 10 files changed (+718 additions, -78 deletions)


colossalai/inference/config.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@

 _DEFAULT_PROMPT_TEMPLATES = {
     "llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
-    "vicuna": "USER: {input_text}\n\nASSISTANT: ",
+    "vicuna": "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user input. USER: {input_text}\nASSISTANT: ",
 }
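For context, a minimal sketch of how the revised vicuna template would be filled in; the `str.format` call below is an illustration only, not the engine's actual prompt-handling code:

```python
# Illustration only: substituting user input into the new vicuna prompt template.
# The template string matches the diff above; the call site is hypothetical.
vicuna_template = (
    "A chat between a curious user and an assistant. The assistant gives helpful, "
    "detailed, accurate, uncensored responses to the user input. "
    "USER: {input_text}\nASSISTANT: "
)

prompt = vicuna_template.format(input_text="Summarize speculative decoding in one sentence.")
print(prompt)
```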

colossalai/inference/core/engine.py

Lines changed: 53 additions & 3 deletions
@@ -9,7 +9,7 @@
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.modeling.policy import model_policy_map
-from colossalai.inference.spec import Drafter
+from colossalai.inference.spec import Drafter, GlideInput
 from colossalai.inference.struct import Sequence
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -66,6 +66,7 @@ def __init__(
         self.use_spec_dec = False
         self.drafter_model = None
         self.drafter = None
+        self.use_glide = False
         self.n_spec_tokens = self.inference_config.max_n_spec_tokens

         if model_policy is None:
@@ -141,14 +142,21 @@ def _shardformer(
         shard_model, _ = shardformer.optimize(model, model_policy)
         return shard_model

-    def enable_spec_dec(self, drafter_model: nn.Module = None, n_spec_tokens: int = None) -> None:
+    def enable_spec_dec(
+        self,
+        drafter_model: nn.Module = None,
+        n_spec_tokens: int = None,
+        use_glide_drafter: bool = False,
+    ) -> None:
         """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.

         Args:
             drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
                 If provided, the previous drafter and drafter model, if exist, will be overwritten.
             n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
                 If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
+            use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
+                If True, the drafter model will be replaced by a glide model.

         ```python
         ...
@@ -181,6 +189,22 @@ def enable_spec_dec(self, drafter_model: nn.Module = None, n_spec_tokens: int =
                 device=self.device,
                 dtype=self.dtype,
             )
+
+        # check if the provided drafter model is compatible with GLIDE structure
+        # when `use_glide_drafter` is set to True
+        if (
+            use_glide_drafter
+            and hasattr(drafter_model, "model")
+            and hasattr(drafter_model.model, "layers")
+            and hasattr(drafter_model.model.layers[0], "cross_attn")
+        ):
+            self.use_glide = use_glide_drafter
+        elif use_glide_drafter:
+            self.logger.warning(
+                f"`use_glide_drafter` is provided as {use_glide_drafter}, "
+                f"but the provided drafter model is not compatible with GLIDE structure."
+                f"Falling back to use the default drafter model (non-GLIDE)."
+            )
         self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
         # using speculative decoding for subsequent generations
         self.use_spec_dec = True
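For reference, a minimal usage sketch of the new flag. The engine and drafter construction details (checkpoint path, dtype) are assumptions for illustration; only `enable_spec_dec`, `disable_spec_dec`, `n_spec_tokens`, and `use_glide_drafter` come from this diff:

```python
# Hedged sketch: turning on GLIDE-based speculative decoding for an existing engine.
# `engine` is assumed to be an already constructed ColossalAI InferenceEngine, and the
# drafter checkpoint path below is a placeholder, not something shipped by this commit.
import torch
from transformers import AutoModelForCausalLM

drafter_model = AutoModelForCausalLM.from_pretrained(
    "path/to/glide-structured-drafter",  # hypothetical GLIDE drafter checkpoint
    torch_dtype=torch.float16,
)

# The engine keeps GLIDE on only if the drafter exposes model.layers[0].cross_attn;
# otherwise it logs a warning and falls back to the plain (non-GLIDE) drafter.
engine.enable_spec_dec(drafter_model=drafter_model, n_spec_tokens=5, use_glide_drafter=True)

# ... run generation as usual, then disable speculative decoding if no longer needed.
engine.disable_spec_dec()
```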
@@ -190,6 +214,7 @@ def disable_spec_dec(self) -> None:
         self.request_handler.unset_spec_dec_mode()
         # set back to the maximum number of tokens to speculate
         self.n_spec_tokens = self.inference_config.max_n_spec_tokens
+        self.use_glide = False
         self.use_spec_dec = False

     def clear_spec_dec(self) -> None:
@@ -200,6 +225,7 @@ def clear_spec_dec(self) -> None:
         self.drafter_model = None
         self.drafter = None
         torch.cuda.empty_cache()
+        self.use_glide = False
         self.use_spec_dec = False

     def steps_spec_dec(self) -> List[Sequence]:
@@ -216,6 +242,7 @@ def steps_spec_dec(self) -> List[Sequence]:
         input_ids = batch.get_1D_inputs()  # bsz 1 for drafter model

         # 1. Prefill small model (Drafter) - fill past kv cache for drafter model
+        # NOTE For glide drafter models, we won't actually apply glide during prefill stage
         drafter_out = self.drafter.speculate(input_ids, 1, None)
         next_token_ids_spec = drafter_out.next_tokens
         drafter_past_key_values = drafter_out.past_key_values
@@ -238,7 +265,21 @@ def steps_spec_dec(self) -> List[Sequence]:
             assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."

             # 3. Decoding - Drafter model speculates `n` tokens
-            drafter_out = self.drafter.speculate(input_ids, self.n_spec_tokens, drafter_past_key_values)
+            glide_input = None
+            if self.use_glide:
+                glide_input = GlideInput(
+                    batch.get_block_table_tensor(),
+                    self.k_cahce[-1],  # use kv cahces of the last layer
+                    self.v_cache[-1],
+                    batch.get_sequence_lengths(),
+                )
+
+            drafter_out = self.drafter.speculate(
+                input_ids,
+                self.n_spec_tokens,
+                drafter_past_key_values,
+                glide_input=glide_input,
+            )
             next_token_ids_spec = drafter_out.next_tokens
             drafter_past_key_values = drafter_out.past_key_values
             drafter_spec_length = drafter_out.speculated_length
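For orientation, a rough sketch of the container passed to the drafter above. The field names are assumptions (the constructor call in the diff is positional); the idea, per the GLIDE approach, is that the drafter glimpses the main model's last-layer KV cache:

```python
# Hedged sketch of a GlideInput-like container; the real class lives in
# colossalai.inference.spec and its field names may differ from these.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class GlideInputSketch:
    block_tables: Optional[torch.Tensor] = None       # paged-KV block table for the batch
    large_k_cache: Optional[torch.Tensor] = None      # main model's last-layer key cache
    large_v_cache: Optional[torch.Tensor] = None      # main model's last-layer value cache
    sequence_lengths: Optional[torch.Tensor] = None   # per-sequence lengths in the batch

    @property
    def glimpse_ready(self) -> bool:
        # The drafter can cross-attend to the large model's cache only when all pieces
        # are present; prefill passes glide_input=None, so glide is skipped there.
        return all(
            t is not None
            for t in (self.block_tables, self.large_k_cache, self.large_v_cache, self.sequence_lengths)
        )
```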
@@ -251,6 +292,8 @@ def steps_spec_dec(self) -> List[Sequence]:
             already_allocated_kv_len = cur_length

             # 4. Decoding - Main model verifies `n` tokens in parallel
+            if drafter_spec_length < batch.num_tokens_to_verify:
+                batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
             logits = self.model(batch, self.k_cahce, self.v_cache)
             next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
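To make the verify-and-accept arithmetic in the next hunk easier to follow, here is a schematic sketch (not the engine's code) of how the number of accepted drafted tokens is typically determined: the main model scores all drafted positions in parallel, and only the leading run of agreements survives.

```python
# Schematic sketch of speculative-decoding verification; names are illustrative.
import torch


def count_matches(verified_tokens: torch.Tensor, drafted_tokens: torch.Tensor) -> int:
    """verified_tokens holds the main model's next-token choice at every drafted
    position plus one bonus position, so it has one more entry than drafted_tokens."""
    mismatch = torch.nonzero(verified_tokens[:-1] != drafted_tokens)
    return drafted_tokens.numel() if mismatch.numel() == 0 else int(mismatch[0, 0])


# Example: the drafter proposed 5 tokens and the main model agrees with the first 3.
drafted = torch.tensor([11, 22, 33, 44, 55])
verified = torch.tensor([11, 22, 33, 99, 12, 7])
n_matches = count_matches(verified, drafted)  # -> 3
# The engine then revokes (5 - 3) drafted tokens, appends verified[3] as the next
# confirmed token, and trims the drafter's KV cache accordingly.
```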
@@ -260,13 +303,15 @@ def steps_spec_dec(self) -> List[Sequence]:

             # revoke appended tokens for each Sequence in the current batch
             batch.revoke_batch_tokens(drafter_spec_length - n_matches)  # revoke drafted tokens
+
             # append the last correct token generated by the main model
             self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))

             # trim past key values of the drafter model
             drafter_past_key_values = Drafter.trim_kv_cache(
                 drafter_past_key_values, drafter_spec_length - n_matches - 1
             )
+
             # prepare inputs for the next round of speculation
             n = 1 if n_matches < drafter_spec_length else 2
             input_ids = batch.get_1D_inputs_spec_dec(n)

@@ -276,6 +321,11 @@ def steps_spec_dec(self) -> List[Sequence]:
             if len(finished_sequences) > 0:
                 break

+        # Reset back the number of speculated tokens of the batch,
+        # this is used to handle the last round of speculation, in which case the number of speculated tokens
+        # by the drafter is less than the number of speculated tokens set to the engine.
+        batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
+
         return finished_sequences

     def generate(
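A brief illustrative note on the reset above (the variable names here are illustrative, not the engine's internals): in the final round the drafter may return fewer tokens than `n_spec_tokens`, so the batch's verify count is temporarily lowered for that round and must be restored before the next call.

```python
# Illustrative only: why the verify count is restored after the loop.
n_spec_tokens = 5            # engine-level setting
drafter_spec_length = 2      # last round: drafter stopped early (e.g. near max length)

num_tokens_to_verify = n_spec_tokens
if drafter_spec_length < num_tokens_to_verify:
    num_tokens_to_verify = drafter_spec_length   # temporarily lowered for this round

# ... main model verifies only 2 positions in that round ...

num_tokens_to_verify = n_spec_tokens             # restored so later calls start clean
assert num_tokens_to_verify == 5
```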