from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import InferenceConfig
from colossalai.inference.modeling.policy import model_policy_map
- from colossalai.inference.spec import Drafter
+ from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -66,6 +66,7 @@ def __init__(
        self.use_spec_dec = False
        self.drafter_model = None
        self.drafter = None
+         self.use_glide = False
        self.n_spec_tokens = self.inference_config.max_n_spec_tokens

        if model_policy is None:
@@ -141,14 +142,21 @@ def _shardformer(
        shard_model, _ = shardformer.optimize(model, model_policy)
        return shard_model

-     def enable_spec_dec(self, drafter_model: nn.Module = None, n_spec_tokens: int = None) -> None:
+     def enable_spec_dec(
+         self,
+         drafter_model: nn.Module = None,
+         n_spec_tokens: int = None,
+         use_glide_drafter: bool = False,
+     ) -> None:
145151 """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
146152
147153 Args:
148154 drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
149155 If provided, the previous drafter and drafter model, if exist, will be overwritten.
150156 n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
151157 If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
158+ use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
159+ If True, the drafter model will be replaced by a glide model.
152160
153161 ```python
154162 ...
@@ -181,6 +189,22 @@ def enable_spec_dec(self, drafter_model: nn.Module = None, n_spec_tokens: int =
            device=self.device,
            dtype=self.dtype,
        )
+
+         # check if the provided drafter model is compatible with the GLIDE structure
+         # when `use_glide_drafter` is set to True
+         if (
+             use_glide_drafter
+             and hasattr(drafter_model, "model")
+             and hasattr(drafter_model.model, "layers")
+             and hasattr(drafter_model.model.layers[0], "cross_attn")
+         ):
+             self.use_glide = use_glide_drafter
+         elif use_glide_drafter:
+             self.logger.warning(
+                 f"`use_glide_drafter` is provided as {use_glide_drafter}, "
+                 "but the provided drafter model is not compatible with the GLIDE structure. "
+                 "Falling back to the default (non-GLIDE) drafter model."
+             )
        self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
        # using speculative decoding for subsequent generations
        self.use_spec_dec = True
@@ -190,6 +214,7 @@ def disable_spec_dec(self) -> None:
        self.request_handler.unset_spec_dec_mode()
        # set back to the maximum number of tokens to speculate
        self.n_spec_tokens = self.inference_config.max_n_spec_tokens
+         self.use_glide = False
        self.use_spec_dec = False

    def clear_spec_dec(self) -> None:
@@ -200,6 +225,7 @@ def clear_spec_dec(self) -> None:
        self.drafter_model = None
        self.drafter = None
        torch.cuda.empty_cache()
+         self.use_glide = False
        self.use_spec_dec = False

    def steps_spec_dec(self) -> List[Sequence]:
@@ -216,6 +242,7 @@ def steps_spec_dec(self) -> List[Sequence]:
        input_ids = batch.get_1D_inputs()  # bsz 1 for drafter model

        # 1. Prefill small model (Drafter) - fill past kv cache for drafter model
+         # NOTE: For GLIDE drafter models, GLIDE is not actually applied during the prefill stage
        drafter_out = self.drafter.speculate(input_ids, 1, None)
        next_token_ids_spec = drafter_out.next_tokens
        drafter_past_key_values = drafter_out.past_key_values
@@ -238,7 +265,21 @@ def steps_spec_dec(self) -> List[Sequence]:
            assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."

            # 3. Decoding - Drafter model speculates `n` tokens
-             drafter_out = self.drafter.speculate(input_ids, self.n_spec_tokens, drafter_past_key_values)
+             glide_input = None
+             if self.use_glide:
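+                 # expose the main model's kv cache to the drafter's cross-attention layers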
+                 glide_input = GlideInput(
+                     batch.get_block_table_tensor(),
+                     self.k_cahce[-1],  # use the kv caches of the last layer
+                     self.v_cache[-1],
+                     batch.get_sequence_lengths(),
+                 )
+
+             drafter_out = self.drafter.speculate(
+                 input_ids,
+                 self.n_spec_tokens,
+                 drafter_past_key_values,
+                 glide_input=glide_input,
+             )
            next_token_ids_spec = drafter_out.next_tokens
            drafter_past_key_values = drafter_out.past_key_values
            drafter_spec_length = drafter_out.speculated_length
@@ -251,6 +292,8 @@ def steps_spec_dec(self) -> List[Sequence]:
                already_allocated_kv_len = cur_length

            # 4. Decoding - Main model verifies `n` tokens in parallel
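+             # the drafter may stop early (e.g. on max length), so shrink the verification
+             # window to the number of tokens actually speculated in this round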
+             if drafter_spec_length < batch.num_tokens_to_verify:
+                 batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
            logits = self.model(batch, self.k_cahce, self.v_cache)
            next_tokens = self.request_handler.search_tokens(self.generation_config, logits)

@@ -260,13 +303,15 @@ def steps_spec_dec(self) -> List[Sequence]:

            # revoke appended tokens for each Sequence in the current batch
            batch.revoke_batch_tokens(drafter_spec_length - n_matches)  # revoke drafted tokens
+
            # append the last correct token generated by the main model
            self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))

            # trim past key values of the drafter model
            drafter_past_key_values = Drafter.trim_kv_cache(
                drafter_past_key_values, drafter_spec_length - n_matches - 1
            )
+
            # prepare inputs for the next round of speculation
            n = 1 if n_matches < drafter_spec_length else 2
            input_ids = batch.get_1D_inputs_spec_dec(n)
@@ -276,6 +321,11 @@ def steps_spec_dec(self) -> List[Sequence]:
            if len(finished_sequences) > 0:
                break

+         # Reset the number of tokens to verify for the batch. This handles the last round of
+         # speculation, in which the drafter may speculate fewer tokens than the number
+         # configured on the engine.
+         batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
+
        return finished_sequences

    def generate(
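For context, a minimal usage sketch of the new `use_glide_drafter` flag. Only `enable_spec_dec`, `disable_spec_dec`, `clear_spec_dec`, and `InferenceConfig.max_n_spec_tokens` are confirmed by this diff; the `InferenceEngine` import path and constructor, the `InferenceConfig` kwarg, the `generate` signature, and the model checkpoints are assumptions for illustration.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine  # assumed module path

# Main (verifier) model and tokenizer; the checkpoint name is a placeholder.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# `max_n_spec_tokens` bounds how many tokens the drafter speculates per round.
inference_config = InferenceConfig(max_n_spec_tokens=5)  # assumed constructor kwarg
engine = InferenceEngine(model, tokenizer, inference_config)  # assumed constructor signature

# A GLIDE-compatible drafter: its decoder layers must expose a `cross_attn`
# module (see the hasattr checks in enable_spec_dec above).
drafter_model = ...  # load a GLIDE-structured drafter checkpoint here

engine.enable_spec_dec(drafter_model, n_spec_tokens=5, use_glide_drafter=True)
output = engine.generate(prompts=["What is speculative decoding?"])  # assumed signature

engine.disable_spec_dec()  # pause speculative decoding; the drafter stays loaded
engine.clear_spec_dec()    # drop the drafter model and free its cache
```

If the drafter lacks the GLIDE structure, `enable_spec_dec` logs a warning and falls back to plain speculative decoding rather than raising, so the call above is safe with any drafter.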