Fix prefix caching + speculative decoding (#2711)
tgaddair authored Nov 4, 2024
1 parent a5593ba commit aadc9cb
Showing 1 changed file with 13 additions and 6 deletions.
server/text_generation_server/models/flash_causal_lm.py
@@ -887,11 +887,12 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             fsm_grammar_states=fsm_grammar_states,
         )
 
-        speculative_ids = (
-            torch.cat([b.speculative_ids for b in batches], dim=0)
-            if batches[0].speculative_ids is not None
-            else None
-        )
+        # We skip computing the speculative_ids when the batch size is too large, so
+        # we must check that all batches have them, otherwise they must be discarded
+        if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches):
+            speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
+        else:
+            speculative_ids = None
 
         if adapter_segment_builder is not None:
             adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
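A minimal standalone sketch of the all-or-nothing rule introduced in this hunk (the merge_speculative_ids helper and the tensor shapes are illustrative, not TGI's actual API): if any batch being merged skipped speculation, torch.cat would either crash on the None or misalign rows, so the ids must be discarded for the whole merged batch.

from typing import List, Optional

import torch


def merge_speculative_ids(
    ids_per_batch: List[Optional[torch.Tensor]], speculate: int
) -> Optional[torch.Tensor]:
    # Hypothetical helper mirroring the guard in concatenate(): keep the
    # speculative ids only when speculation is on and every batch has them.
    if speculate > 0 and all(ids is not None for ids in ids_per_batch):
        return torch.cat(ids_per_batch, dim=0)
    return None


# One batch computed ids for 2 requests x 3 speculative tokens; the other
# skipped speculation (e.g., because it was too large).
a = torch.randint(0, 32000, (2, 3))
print(merge_speculative_ids([a, None], speculate=3))        # None: discarded
print(merge_speculative_ids([a, a], speculate=3).shape)     # torch.Size([4, 3])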
@@ -1724,7 +1725,13 @@ def forward(
             new_position_ids = (
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
-            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+            # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
+            # then update the slots with the additional indices to ensure we're grabbing the ones that have been
+            # allocated
+            slot_indices = (batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            slots = batch.slots[slot_indices]
+
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)