[WIP] Llama 4: Hybrid KV buffer (disable radix attention) #5853

Draft: wants to merge 10 commits into main.

1 change: 1 addition & 0 deletions log.txt
@@ -0,0 +1 @@
self.is_hybrid:None, self.local_max_num_tokens:None, self.global_max_num_tokens:10486626.self.is_hybrid:None, self.local_max_num_tokens:None, self.global_max_num_tokens:10486626.self.is_hybrid:None, self.local_max_num_tokens:None, self.global_max_num_tokens:10486626.self.is_hybrid:None, self.local_max_num_tokens:None, self.global_max_num_tokens:10486626.
1 change: 1 addition & 0 deletions python/sglang/bench_one_batch.py
@@ -145,6 +145,7 @@ def load_model(server_args, port_args, tp_rank):
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
enable_multimodal=server_args.enable_multimodal,
enable_hybrid_kvcache=server_args.enable_hybrid_kvcache,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
23 changes: 23 additions & 0 deletions python/sglang/srt/configs/model_config.py
@@ -44,6 +44,7 @@ def __init__(
model_override_args: Optional[str] = None,
is_embedding: Optional[bool] = None,
enable_multimodal: Optional[bool] = None,
enable_hybrid_kvcache: Optional[float] = None,
dtype: str = "auto",
quantization: Optional[str] = None,
override_config_file: Optional[str] = None,
@@ -86,6 +87,12 @@ def __init__(
enable_multimodal = True

# Check model type
self.is_hybrid = is_hybrid_model(
self.hf_config.architectures,
enable_hybrid_kvcache=enable_hybrid_kvcache,
context_length=context_length,
attention_chunk_size=self.attention_chunk_size
)
self.is_generation = is_generation_model(
self.hf_config.architectures, is_embedding
)
@@ -525,6 +532,22 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
"CLIPModel",
]

def is_hybrid_model(
model_architectures: List[str],
enable_hybrid_kvcache: Optional[float],
context_length: Optional[int],
attention_chunk_size: Optional[int]
):

if enable_hybrid_kvcache is None:
return None
elif (enable_hybrid_kvcache > 0
and model_architectures[0] == "Llama4ForConditionalGeneration"
and context_length > attention_chunk_size
):
return enable_hybrid_kvcache
else:
return None

def is_multimodal_model(model_architectures: List[str]):
if any(
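The gate added in `is_hybrid_model` is the crux of the config change: the hybrid KV cache is only activated when the flag is set, the first architecture is `Llama4ForConditionalGeneration`, and the configured context length exceeds the chunked-attention window. A minimal standalone sketch of that gate with illustrative values (the 131072/8192 numbers are assumptions, not taken from this PR):

```python
from typing import List, Optional


def is_hybrid_model(
    model_architectures: List[str],
    enable_hybrid_kvcache: Optional[float],
    context_length: Optional[int],
    attention_chunk_size: Optional[int],
) -> Optional[float]:
    """Return the hybrid KV cache value if every condition holds, else None."""
    if enable_hybrid_kvcache is None:
        return None
    if (
        enable_hybrid_kvcache > 0
        and model_architectures[0] == "Llama4ForConditionalGeneration"
        and context_length > attention_chunk_size
    ):
        return enable_hybrid_kvcache
    return None


# Hybrid mode only applies to Llama 4 when the context exceeds the chunk size.
print(is_hybrid_model(["Llama4ForConditionalGeneration"], 1.0, 131072, 8192))  # 1.0
print(is_hybrid_model(["LlamaForCausalLM"], 1.0, 131072, 8192))                # None
print(is_hybrid_model(["Llama4ForConditionalGeneration"], 1.0, 4096, 8192))    # None
```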
63 changes: 48 additions & 15 deletions python/sglang/srt/layers/attention/flashattention_backend.py
@@ -42,6 +42,8 @@ class FlashAttentionMetadata:
window_size: tuple = (-1, -1)
# Page table, the index of KV Cache Tables/Blocks
page_table: torch.Tensor = None
# Page table, the index of KV Cache Tables/Blocks for local attention
page_table_local: torch.Tensor = None

# Encoder metadata
# Cumulative sequence lengths for encoder key
@@ -315,6 +317,8 @@ def __init__(
self.decode_cuda_graph_metadata = {}
self.target_verify_metadata = {}
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.req_to_token_local = model_runner.req_to_token_pool.req_to_token_local
self.is_hybrid = model_runner.is_hybrid
self.kv_cache_dtype = model_runner.kv_cache_dtype
self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
self.page_size = model_runner.page_size
@@ -427,6 +431,10 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
if self.is_hybrid is not None:
metadata.page_table_local = forward_batch.req_to_token_pool.req_to_token_local[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
# TODO: we need to test this part for llama 4 eagle case
self._init_local_attn_metadata(metadata, device)
elif forward_batch.forward_mode.is_target_verify():
@@ -562,6 +570,10 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
if self.is_hybrid is not None:
metadata.page_table_local = forward_batch.req_to_token_pool.req_to_token_local[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]

if (
any(forward_batch.extend_prefix_lens_cpu)
@@ -627,14 +639,22 @@ def forward_extend(
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
):
use_hybrid_loc = (
self.is_hybrid is not None
and (hasattr(layer, "use_irope") and layer.use_irope)
)
if k is not None:
assert v is not None
if save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not use_hybrid_loc:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
else:
cache_loc = forward_batch.out_cache_loc_local
# TODO enable cross attention
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
@@ -649,7 +669,7 @@

# Use precomputed metadata across all layers
metadata = self.forward_metadata

# Calculate window size (can be moved to metadata if layer properties don't change)
# we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
# here is two side inclusive
@@ -667,14 +687,13 @@ def forward_extend(
v_descale = layer.v_scale.expand(descale_shape)
q = q.to(self.kv_cache_dtype)
causal = not layer.is_cross_attention

# Check if we should use local attention
use_local_attn = (
self.attention_chunk_size is not None
and metadata.local_attn_metadata is not None
and (hasattr(layer, "use_irope") and layer.use_irope)
)

# We do cascade attention for Target Verify with topk > 1
use_cascade_attn = (
forward_batch.forward_mode.is_target_verify() and self.topk > 1
@@ -890,14 +909,24 @@ def forward_decode(
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
) -> torch.Tensor:

use_hybrid_loc = (
self.is_hybrid is not None
and (hasattr(layer, "use_irope") and layer.use_irope)
)

if k is not None:
assert v is not None
if save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not use_hybrid_loc:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
else:
cache_loc = forward_batch.out_cache_loc_local
# TODO enable cross attention
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
@@ -909,12 +938,13 @@
k,
k_rope,
)

# Use precomputed metadata across all layers
metadata = self.forward_metadata
local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
use_local_attention = (
self.attention_chunk_size is not None and local_attn_metadata is not None
and (hasattr(layer, "use_irope") and layer.use_irope)
)
# We do cascade attention for Draft Decode with topk > 1
use_cascade_attn = self.topk > 1
@@ -1757,7 +1787,10 @@ def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):

cu_seqlens_q = metadata.cu_seqlens_q
cache_seqlens_int32 = metadata.cache_seqlens_int32
page_table = metadata.page_table
if self.is_hybrid:
page_table = metadata.page_table_local
else:
page_table = metadata.page_table
if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
metadata.local_attn_metadata = None
return
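The attention backend now tracks two page tables (`page_table` and `page_table_local`) and, for layers flagged with `use_irope` (Llama 4's chunked local attention), writes newly computed K/V entries to the local buffer's slot indices instead of the global ones. A minimal sketch of that routing decision from `forward_extend`/`forward_decode`, using simplified stand-in `Layer`/`Batch` classes rather than the real SGLang types:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Layer:
    # Stand-in for an attention layer; only the two attributes the routing
    # decision inspects are modeled here.
    use_irope: bool = False
    is_cross_attention: bool = False


@dataclass
class Batch:
    # Stand-in for ForwardBatch; the strings represent slot-index tensors.
    out_cache_loc: str = "global_slots"
    out_cache_loc_local: str = "local_slots"
    encoder_out_cache_loc: str = "encoder_slots"


def select_cache_loc(is_hybrid: Optional[float], layer: Layer, batch: Batch) -> str:
    """Pick which KV buffer a layer's new K/V entries are written into."""
    use_hybrid_loc = is_hybrid is not None and getattr(layer, "use_irope", False)
    if use_hybrid_loc:
        # iRoPE layers only attend within a chunk, so their KV entries go to
        # the smaller local buffer addressed by out_cache_loc_local.
        return batch.out_cache_loc_local
    return (
        batch.out_cache_loc
        if not layer.is_cross_attention
        else batch.encoder_out_cache_loc
    )


batch = Batch()
print(select_cache_loc(1.0, Layer(use_irope=True), batch))    # local_slots
print(select_cache_loc(1.0, Layer(use_irope=False), batch))   # global_slots
print(select_cache_loc(None, Layer(use_irope=True), batch))   # global_slots
```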
49 changes: 48 additions & 1 deletion python/sglang/srt/managers/schedule_batch.py
@@ -714,6 +714,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
reqs: List[Req]
req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
token_to_kv_pool_allocator_local: TokenToKVPoolAllocator = None
tree_cache: BasePrefixCache = None

# Batch configs
@@ -739,6 +740,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
seq_lens: torch.Tensor = None # shape: [b], int64
# The output locations of the KV cache
out_cache_loc: torch.Tensor = None # shape: [b], int64
out_cache_loc_local: torch.Tensor = None
output_ids: torch.Tensor = None # shape: [b], int64

# The sum of all sequence lengths
@@ -803,13 +805,15 @@ def init_new(
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
token_to_kv_pool_allocator_local: TokenToKVPoolAllocator = None,
):
return_logprob = any(req.return_logprob for req in reqs)

return cls(
reqs=reqs,
req_to_token_pool=req_to_token_pool,
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
token_to_kv_pool_allocator_local=token_to_kv_pool_allocator_local,
tree_cache=tree_cache,
model_config=model_config,
enable_overlap=enable_overlap,
@@ -838,6 +842,12 @@ def alloc_req_slots(self, num_reqs: int):
f"{num_reqs=}, "
)
return req_pool_indices

def alloc_token_slots_local(self, num_tokens: int, backup_state: bool = False):
# TODO: support backup_state and handle allocation failure (the local allocator returning None), mirroring alloc_token_slots below.
out_cache_loc_local = self.token_to_kv_pool_allocator_local.alloc(num_tokens)
return out_cache_loc_local

def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
if self.token_to_kv_pool_allocator.available_size() < num_tokens:
@@ -1053,6 +1063,10 @@ def prepare_for_extend(self):
self.req_to_token_pool.write(
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
)
if self.token_to_kv_pool_allocator_local is not None:
self.req_to_token_pool.write_local(
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
)

# If input_embeds are available, store them
if req.input_embeds is not None:
@@ -1121,7 +1135,10 @@ def prepare_for_extend(self):
# Allocate memory
if self.token_to_kv_pool_allocator.page_size == 1:
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
if self.token_to_kv_pool_allocator_local is not None:
out_cache_loc_local = self.alloc_token_slots_local(extend_num_tokens)
else:
# TODO: page_size != 1
last_loc = get_last_loc(
self.req_to_token_pool.req_to_token,
req_pool_indices_tensor,
@@ -1136,6 +1153,8 @@ def prepare_for_extend(self):
self.req_pool_indices = req_pool_indices_tensor
self.seq_lens = seq_lens_tensor
self.out_cache_loc = out_cache_loc
if self.token_to_kv_pool_allocator_local is not None:
self.out_cache_loc_local = out_cache_loc_local
self.input_embeds = (
torch.tensor(input_embeds).to(self.device, non_blocking=True)
if input_embeds
@@ -1166,15 +1185,30 @@ def prepare_for_extend(self):
out_cache_loc,
self.req_to_token_pool.req_to_token.shape[1],
)
if self.token_to_kv_pool_allocator_local is not None:
write_req_to_token_pool_triton[(bs,)](
self.req_to_token_pool.req_to_token_local,
req_pool_indices_tensor,
prefix_lens_tensor,
seq_lens_tensor,
extend_lens_tensor,
out_cache_loc_local,
self.req_to_token_pool.req_to_token_local.shape[1],
)

else:
pt = 0
for i in range(bs):
self.req_to_token_pool.write(
(req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
out_cache_loc[pt : pt + extend_lens[i]],
)
if self.token_to_kv_pool_allocator_local is not None:
self.req_to_token_pool.write_local(
(req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
out_cache_loc_local[pt : pt + extend_lens[i]],
)
pt += extend_lens[i]

if self.model_config.is_encoder_decoder:
self.prepare_encoder_info_extend(input_ids, seq_lens)

@@ -1338,6 +1372,7 @@ def prepare_for_idle(self):
self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
self.out_cache_loc_local = torch.empty(0, dtype=torch.int64, device=self.device)
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
self.seq_lens_sum = 0
self.extend_num_tokens = 0
@@ -1399,17 +1434,25 @@ def prepare_for_decode(self):
# Allocate memory
if self.token_to_kv_pool_allocator.page_size == 1:
self.out_cache_loc = self.alloc_token_slots(bs)
if self.token_to_kv_pool_allocator_local is not None:
self.out_cache_loc_local = self.alloc_token_slots_local(bs)
else:
# TODO: page_size != 1
last_loc = self.req_to_token_pool.req_to_token[
self.req_pool_indices, self.seq_lens - 2
]
self.out_cache_loc = self.alloc_paged_token_slots_decode(
self.seq_lens, last_loc
)

# TODO: write out_cache_loc_local to req_to_token_pool
self.req_to_token_pool.write(
(self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
)
if self.token_to_kv_pool_allocator_local is not None:
self.req_to_token_pool.write_local(
(self.req_pool_indices, locs), self.out_cache_loc_local.to(torch.int32)
)

def filter_batch(
self,
@@ -1536,6 +1579,7 @@ def get_model_worker_batch(self) -> ModelWorkerBatch:
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens,
out_cache_loc=self.out_cache_loc,
out_cache_loc_local=self.out_cache_loc_local,
seq_lens_sum=self.seq_lens_sum,
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
@@ -1580,6 +1624,7 @@ def copy(self):
model_config=self.model_config,
forward_mode=self.forward_mode,
out_cache_loc=self.out_cache_loc,
out_cache_loc_local=self.out_cache_loc_local,
return_logprob=self.return_logprob,
decoding_reqs=self.decoding_reqs,
spec_algorithm=self.spec_algorithm,
@@ -1608,6 +1653,8 @@ class ModelWorkerBatch:
seq_lens_cpu: Optional[torch.Tensor]
# The indices of output tokens in the token_to_kv_pool_allocator
out_cache_loc: torch.Tensor
# The indices of output tokens in the token_to_kv_pool_allocator_local
out_cache_loc_local: torch.Tensor

# The sum of all sequence lengths
seq_lens_sum: int
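On the scheduler side, every extend/decode step (page_size == 1 path) now allocates slot indices twice: once from the global `token_to_kv_pool_allocator` and, when the hybrid local allocator is configured, once more from `token_to_kv_pool_allocator_local`; the resulting `out_cache_loc_local` is mirrored into `req_to_token_local` and carried through `ModelWorkerBatch`. A minimal sketch of the dual allocation, using a hypothetical free-list allocator rather than SGLang's `TokenToKVPoolAllocator`:

```python
from typing import List, Optional, Tuple


class FreeListAllocator:
    """Hypothetical stand-in for a token-slot allocator with page_size == 1."""

    def __init__(self, size: int):
        self.free = list(range(size))

    def available_size(self) -> int:
        return len(self.free)

    def alloc(self, num_tokens: int) -> Optional[List[int]]:
        if len(self.free) < num_tokens:
            return None  # allocation failure
        out, self.free = self.free[:num_tokens], self.free[num_tokens:]
        return out


def alloc_for_step(
    num_tokens: int,
    global_alloc: FreeListAllocator,
    local_alloc: Optional[FreeListAllocator],
) -> Tuple[Optional[List[int]], Optional[List[int]]]:
    """Allocate global slots and, in hybrid mode, a parallel set of local slots."""
    out_cache_loc = global_alloc.alloc(num_tokens)
    out_cache_loc_local = None
    if local_alloc is not None:
        # The local buffer only needs to cover the chunked-attention window, so
        # it is smaller and can run out before the global pool does; the PR
        # still carries TODOs for failure/backup handling on this path.
        out_cache_loc_local = local_alloc.alloc(num_tokens)
    return out_cache_loc, out_cache_loc_local


global_alloc = FreeListAllocator(1024)
local_alloc = FreeListAllocator(256)
print(alloc_for_step(4, global_alloc, local_alloc))
# -> ([0, 1, 2, 3], [0, 1, 2, 3]); the two index sets are then written into
#    req_to_token and req_to_token_local respectively.
```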