Fix prefix caching + speculative decoding #2711

Merged
merged 1 commit into from
Nov 4, 2024
19 changes: 13 additions & 6 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -887,11 +887,12 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             fsm_grammar_states=fsm_grammar_states,
         )

-        speculative_ids = (
-            torch.cat([b.speculative_ids for b in batches], dim=0)
-            if batches[0].speculative_ids is not None
-            else None
-        )
+        # We skip computing the speculative_ids when the batch size is too large, so
+        # we must check that all batches have them, otherwise they must be discarded
+        if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches):
+            speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
+        else:
+            speculative_ids = None

         if adapter_segment_builder is not None:
             adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
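
The change to `concatenate` can be illustrated with a small sketch. It is not the project's code: plain Python lists stand in for tensors, and `merge_speculative_ids`, `old_check_would_concatenate`, and the sample data are hypothetical names introduced here. The sketch assumes that a batch which skipped speculation carries `None` in place of its speculative ids, as the diff comment describes.

```python
from typing import List, Optional

# Hypothetical stand-in: each entry is one batch's speculative ids
# (one row of candidate token ids per sequence), or None if they were skipped.
SpecIds = Optional[List[List[int]]]


def merge_speculative_ids(per_batch: List[SpecIds], speculate: int) -> SpecIds:
    # Keep the ids only if speculation is enabled and every batch has them;
    # otherwise discard them all, mirroring the condition added in the diff.
    if speculate > 0 and all(ids is not None for ids in per_batch):
        return [row for ids in per_batch for row in ids]
    return None


def old_check_would_concatenate(per_batch: List[SpecIds]) -> bool:
    # The pre-fix condition inspected the first batch only.
    return per_batch[0] is not None


mixed = [[[5, 6]], None]                  # second batch skipped speculation
complete = [[[5, 6]], [[7, 8], [9, 10]]]  # every batch has speculative ids

print(old_check_would_concatenate(mixed))   # True: old check misses the None
print(merge_speculative_ids(mixed, 2))      # None: ids are discarded
print(merge_speculative_ids(complete, 2))   # [[5, 6], [7, 8], [9, 10]]
```

With the old condition, the `mixed` case would go on to concatenate a list containing `None`; checking every batch avoids that at the cost of dropping the speculative ids for one step.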
@@ -1718,7 +1719,13 @@ def forward(
             new_position_ids = (
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
-            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+            # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
+            # then update the slots with the additional indices to ensure we're grabbing the ones that have been
+            # allocated
+            slot_indices = (batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            slots = batch.slots[slot_indices]
+
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
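
The difference between the removed and added slot computations can be seen in a minimal sketch. It uses plain Python lists instead of tensors, and the function names and slot values are hypothetical. It assumes `slots` is a flat list of every allocated slot for the batch and `slot_indices` gives each sequence's current position in that list, which is how the added lines appear to use `batch.slots` and `batch.slot_indices`.

```python
from typing import List


def expand_slots_naive(current_slots: List[int], new_length: int) -> List[int]:
    # Removed approach: add 0..new_length-1 to each sequence's current slot,
    # which is only correct if the following slots are consecutive integers.
    return [s + i for s in current_slots for i in range(new_length)]


def expand_slots_indexed(
    slots: List[int], slot_indices: List[int], new_length: int
) -> List[int]:
    # Added approach: advance the index into the allocated slot list, then
    # look up the slot stored there.
    return [slots[idx + i] for idx in slot_indices for i in range(new_length)]


# Hypothetical allocation: sequence 0 owns slots 10, 11, 40, 41 (discontiguous,
# as can happen when cached prefix blocks are reused); sequence 1 owns 20..23.
slots = [10, 11, 40, 41, 20, 21, 22, 23]
slot_indices = [1, 5]                              # current position per sequence
current_slots = [slots[i] for i in slot_indices]   # [11, 21]
new_length = 2                                     # one speculative token + 1

print(expand_slots_naive(current_slots, new_length))            # [11, 12, 21, 22]
print(expand_slots_indexed(slots, slot_indices, new_length))    # [11, 40, 21, 22]
```

In this example the naive expansion yields slot 12, which was never allocated to sequence 0, while the indexed lookup returns slot 40, the next slot the sequence actually owns.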