
Commit

Fixup:
- Remove logs
- Disable VLMs (they do not work)
- Disable prefix caching when the user wants prefill logprobs (see the sketch below).
Narsil committed Aug 16, 2024
1 parent 72f6b2c commit 043c1d1
Showing 3 changed files with 70 additions and 34 deletions.
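
The queue.rs change below implements the third bullet: with prefix caching, prefill skips recomputing tokens that are already in the KV cache, so the logits (and hence prefill logprobs) for those tokens are never produced. A minimal standalone illustration of why the two features conflict (not part of the commit):

```python
# Illustration only (not TGI code): a cached prefix removes the logits needed
# for prefill logprobs, so such requests must bypass the prefix cache.
def prefill_logprob_coverage(prompt_len: int, cached_prefix_len: int) -> int:
    """Number of prompt tokens for which prefill actually produces logits."""
    return prompt_len - cached_prefix_len

assert prefill_logprob_coverage(10, 0) == 10  # no cache reuse: logits for every token
assert prefill_logprob_coverage(10, 6) == 4   # 6 cached tokens: their logprobs are missing
```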
13 changes: 9 additions & 4 deletions backends/v3/src/queue.rs
@@ -316,10 +316,15 @@ impl State {
                 + self.speculate
                 - 1;

-            match block_allocator
-                .allocate(tokens, entry.request.input_ids.clone())
-                .await
-            {
+            // If users wants the prefill logprobs, we cannot reuse the cache.
+            // So no input_ids for the radix tree.
+            let input_ids = if entry.request.decoder_input_details {
+                None
+            } else {
+                entry.request.input_ids.clone()
+            };
+
+            match block_allocator.allocate(tokens, input_ids).await {
                 None => {
                     // Entry is over budget
                     // Add it back to the front
2 changes: 0 additions & 2 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -468,7 +468,6 @@ def from_tokenized(
         block_tables_tensor = block_tables_tensor.to(device)
         prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)

-        log_master(logger.info, f"Block tables {block_tables}")
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -1915,7 +1914,6 @@ def _forward_context(
         # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)

         if cu_seqlen_prefill is not None:
-            log_master(logger.info, f"Prefix lens {prefix_lens}")
             return use_prefill_with_paged_kv_state(
                 state=(
                     state if state is not None else self.prefill_with_paged_kv_state
89 changes: 61 additions & 28 deletions server/text_generation_server/models/vlm_causal_lm.py
@@ -11,7 +11,9 @@
 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
+    block_tables_to_ragged,
 )
+from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen
@@ -254,6 +256,8 @@ def __init__(
         trust_remote_code: bool,
         **kwargs,
     ):
+        if PREFIX_CACHING:
+            raise NotImplementedError("Vlm do not work with prefix caching yet")
         if processor_kwargs is None:
             processor_kwargs = {}
         self.processor = processor_class.from_pretrained(
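
The new guard depends on the module-level PREFIX_CACHING flag imported from text_generation_server.models.globals. A rough sketch of how such a flag-plus-guard pattern behaves; the environment variable name and parsing below are assumptions for illustration, not taken from this commit:

```python
import os

# Assumed wiring of the flag for illustration; the real globals.py may read a
# different environment variable or default.
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"}


class VlmCausalLMSketch:
    """Sketch of the constructor guard added in this commit (not the real class)."""

    def __init__(self, *args, **kwargs):
        if PREFIX_CACHING:
            # Fail fast at model-load time instead of serving wrong attention results.
            raise NotImplementedError("Vlm do not work with prefix caching yet")
```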
Expand Down Expand Up @@ -310,6 +314,9 @@ def forward(
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
+            prefix_lens_tensor = (
+                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            ).reshape(-1)

             # Add Copy the block tables for all members
             block_tables = (
@@ -330,6 +337,7 @@
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
+            prefix_lens_tensor = batch.prefix_lens_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices

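In the speculative-decoding branch above, each request's prefix length is repeated once per accepted-plus-speculative token so it lines up with the flattened input_lengths. A small standalone illustration of that expansion (example values are made up):

```python
import torch

# Illustration only: broadcast per-request prefix lengths across speculative slots.
prefix_lens_tensor = torch.tensor([3, 0], dtype=torch.int32)  # B = 2 requests
B, new_length = prefix_lens_tensor.shape[0], 4  # 1 accepted token + 3 speculative ones

expanded = prefix_lens_tensor.unsqueeze(-1).expand(B, new_length).reshape(-1)
print(expanded)  # tensor([3, 3, 3, 3, 0, 0, 0, 0], dtype=torch.int32)
```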
@@ -349,43 +357,68 @@
         else:
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            input_lengths = Seqlen(input_lengths=input_lengths)
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=cu_seqlen_prefill,
-                kv_cache=kv_cache,
+            input_lengths = input_lengths + prefix_lens_tensor
+            if PREFIX_CACHING:
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
+            with self._forward_context(
                 block_tables=block_tables,
-                slots=slots,
-                input_lengths=input_lengths,
-                max_s=max_s,
-                prefill_cache_indices=batch.prefill_cache_indices,
-                lm_head_indices=lm_head_indices,
-                pixel_values=batch.pixel_values,
-                pixel_attention_mask=batch.pixel_attention_mask,
-                image_sizes=batch.image_sizes,
-            )
-            if batch.prefill_cache_indices is not None:
-                batch.prefill_cache_indices = None
-            if batch.pixel_values is not None:
-                batch.pixel_values = None
-            if batch.pixel_attention_mask is not None:
-                batch.pixel_attention_mask = None
-            if batch.image_sizes is not None:
-                batch.image_sizes = None
-            return logits, speculative_logits
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                input_lengths=batch.input_lengths,
+                input_lengths_tensor=input_lengths,
+                prefix_lens=batch.prefix_lens,
+                prefix_lens_tensor=prefix_lens_tensor,
+            ):
+                input_lengths = Seqlen(input_lengths=input_lengths)
+                logits, speculative_logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    input_lengths=input_lengths,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                    pixel_values=batch.pixel_values,
+                    pixel_attention_mask=batch.pixel_attention_mask,
+                    image_sizes=batch.image_sizes,
+                )
+                if batch.prefill_cache_indices is not None:
+                    batch.prefill_cache_indices = None
+                if batch.pixel_values is not None:
+                    batch.pixel_values = None
+                if batch.pixel_attention_mask is not None:
+                    batch.pixel_attention_mask = None
+                if batch.image_sizes is not None:
+                    batch.image_sizes = None
+                return logits, speculative_logits

         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
         cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
         cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
-        cuda_graph["block_tables"][
-            : block_tables.shape[0], : block_tables.shape[1]
-        ] = block_tables
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=batch.input_lengths,
+                prefix_lens=batch.prefix_lens,
+            )
+            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+        else:
+            cuda_graph["block_tables"][
+                : block_tables.shape[0], : block_tables.shape[1]
+            ] = block_tables
         cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+            input_lengths + prefix_lens_tensor
+        )

         # Replay the graph
         cuda_graph["graph"].replay()
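
On the flashinfer path, block tables are handed over in a ragged (flattened) layout instead of a padded [batch, max_blocks] matrix. A rough standalone sketch of what a conversion like block_tables_to_ragged can look like; the real helper in flash_causal_lm.py and its block-size handling may differ:

```python
import math
import torch

BLOCK_SIZE = 16  # assumed KV-cache page size; the server's actual value may differ

def ragged_block_tables(block_tables, input_lengths, prefix_lens):
    """Flatten a padded [batch, max_blocks] block table into one 1D tensor,
    keeping only the blocks each sequence actually needs (prefix + new tokens)."""
    flat = []
    for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
        total_tokens = input_length + prefix_len
        needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
        flat.append(block_tables[i, :needed_blocks])
    return torch.cat(flat)

# Example: two sequences with a padded table of up to 3 blocks each.
bt = torch.tensor([[7, 8, 0], [9, 10, 11]], dtype=torch.int32)
print(ragged_block_tables(bt, input_lengths=[5, 33], prefix_lens=[16, 0]))
# -> tensor([ 7,  8,  9, 10, 11], dtype=torch.int32)
```

The ragged form avoids padding entirely, which is why the CUDA-graph buffer above is indexed with a single flat slice when ATTENTION == "flashinfer".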
