7 changes: 3 additions & 4 deletions python/sglang/srt/layers/radix_attention.py
@@ -203,7 +203,6 @@ def forward(self, q, k, v, input_metadata: InputMetadata):
         return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
-        v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        k_cache[input_metadata.out_cache_loc] = cache_k
-        v_cache[input_metadata.out_cache_loc] = cache_v
+        input_metadata.token_to_kv_pool.set_kv_buffer(
+            self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
+        )
90 changes: 82 additions & 8 deletions python/sglang/srt/mem_cache/memory_pool.py
@@ -16,7 +16,8 @@
 """Memory pool."""
 
 import logging
-from typing import List, Union
+from abc import ABC, abstractmethod
+from typing import List, Tuple, Union
 
 import torch
 
@@ -52,14 +53,21 @@ def clear(self):
         self.free_slots = list(range(self.size))
 
 
-class BaseTokenToKVPool:
+class BaseTokenToKVPool(ABC):
     """A memory pool that maps a token to its kv cache locations"""
 
     def __init__(
         self,
         size: int,
+        dtype: torch.dtype,
     ):
         self.size = size
+        self.dtype = dtype
+        if dtype == torch.float8_e5m2:
+            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
 
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
@@ -112,6 +120,28 @@ def clear(self):
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state[0] = False
 
+    @abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+
 
 class MHATokenToKVPool(BaseTokenToKVPool):
 
@@ -123,26 +153,52 @@ def __init__(
         head_dim: int,
         layer_num: int,
     ):
-        super().__init__(size)
+        super().__init__(size, dtype)
 
         # [size, head_num, head_dim] for each layer
         self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
             for _ in range(layer_num)
         ]
         self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
             for _ in range(layer_num)
         ]
 
     def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.k_buffer[layer_id].view(self.dtype)
         return self.k_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.v_buffer[layer_id].view(self.dtype)
         return self.v_buffer[layer_id]
 
     def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if cache_v.dtype != self.dtype:
+            cache_v = cache_v.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+            self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
+        else:
+            self.k_buffer[layer_id][loc] = cache_k
+            self.v_buffer[layer_id][loc] = cache_v
 
 
 class MLATokenToKVPool(BaseTokenToKVPool):

Review comment (Member), attached to the cache_k.view(self.store_dtype) / cache_v.view(self.store_dtype) stores in set_kv_buffer: workaround for float8_e5m2. Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2.
@@ -155,23 +211,41 @@ def __init__(
         qk_rope_head_dim: int,
         layer_num: int,
     ):
-        super().__init__(size)
+        super().__init__(size, dtype)
 
         self.kv_lora_rank = kv_lora_rank
         self.kv_buffer = [
             torch.empty(
                 (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=dtype,
+                dtype=self.store_dtype,
                 device="cuda",
             )
             for _ in range(layer_num)
         ]
 
     def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id].view(self.dtype)
         return self.kv_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
         return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
 
     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+        else:
+            self.kv_buffer[layer_id][loc] = cache_k
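
Note: the uint8 round-trip above is easy to miss in the flattened diff. The standalone sketch below is not part of the PR; the tensor shapes, the CUDA device, and the assumption that index_put is still unimplemented for torch.float8_e5m2 (as the NOTE comment states) are illustrative. It shows why the pool scatters fp8 values through a uint8 view and hands readers an fp8 view back.

import torch

# Buffer is allocated in the store_dtype (uint8), writes reinterpret the fp8 bytes.
store = torch.empty((8, 2, 4), dtype=torch.uint8, device="cuda")        # store_dtype buffer
cache_k = torch.randn((3, 2, 4), device="cuda").to(torch.float8_e5m2)   # logical dtype
loc = torch.tensor([0, 3, 5], device="cuda")

# store.view(torch.float8_e5m2)[loc] = cache_k  # would fail: index_put not implemented for fp8
store[loc] = cache_k.view(torch.uint8)           # reinterpret the bytes, scatter as uint8

# Readers reinterpret the same storage back to fp8, mirroring get_key_buffer().
k_fp8 = store.view(torch.float8_e5m2)
assert k_fp8.dtype == torch.float8_e5m2 and k_fp8.shape == store.shape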
4 changes: 4 additions & 0 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -315,6 +315,8 @@ def update_flashinfer_indices(
                 num_kv_heads,
                 head_dim,
                 1,
+                data_type=model_runner.kv_cache_dtype,
+                q_data_type=model_runner.dtype,
             )
         else:
             # extend part
@@ -393,6 +395,8 @@ def update_flashinfer_indices(
                 num_kv_heads,
                 head_dim,
                 1,
+                data_type=model_runner.kv_cache_dtype,
+                q_data_type=model_runner.dtype,
             )
         else:
             # extend part
23 changes: 19 additions & 4 deletions python/sglang/srt/model_executor/model_runner.py
@@ -312,15 +312,15 @@ def profile_max_num_token(self, total_gpu_memory: int):
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                 * self.model_config.num_hidden_layers
-                * torch._utils._element_size(self.dtype)
+                * torch._utils._element_size(self.kv_cache_dtype)
             )
         else:
             cell_size = (
                 self.model_config.get_num_kv_heads(self.tp_size)
                 * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
-                * torch._utils._element_size(self.dtype)
+                * torch._utils._element_size(self.kv_cache_dtype)
             )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
@@ -334,6 +334,21 @@ def init_memory_pool(
         max_num_reqs: int = None,
         max_total_tokens: int = None,
     ):
+        if self.server_args.kv_cache_dtype == "auto":
+            self.kv_cache_dtype = self.dtype
+        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
+            if self.server_args.disable_flashinfer or self.server_args.enable_mla:
+                logger.warning(
+                    "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
+                )
+                self.kv_cache_dtype = self.dtype
+            else:
+                self.kv_cache_dtype = torch.float8_e5m2
+        else:
+            raise ValueError(
+                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
+            )
+
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:

Review comment (Member), attached to the disable_flashinfer / enable_mla check: currently only FlashInfer is supported, not Triton, due to insufficient shared memory (smem). This needs to be fixed in another PR.
@@ -370,7 +385,7 @@
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
-                dtype=self.dtype,
+                dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
@@ -381,7 +396,7 @@
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
-                dtype=self.dtype,
+                dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
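Note: since profile_max_num_token now sizes the pool with kv_cache_dtype, the fp8 path roughly doubles the number of KV tokens that fit in the same memory budget. A quick back-of-envelope check of the MHA cell-size formula follows; the head count, head dimension, and layer count are made-up example values, not taken from this PR.

import torch

# Per-token KV-cache cell size, mirroring the MHA branch of profile_max_num_token.
num_kv_heads, head_dim, num_layers = 8, 128, 32

def cell_size(kv_cache_dtype: torch.dtype) -> int:
    # K and V buffers across all layers account for the factor of 2.
    return num_kv_heads * head_dim * num_layers * 2 * torch._utils._element_size(kv_cache_dtype)

print(cell_size(torch.float16))       # 131072 bytes per token
print(cell_size(torch.float8_e5m2))   # 65536 bytes per token, i.e. roughly 2x more cached tokens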
8 changes: 8 additions & 0 deletions python/sglang/srt/server_args.py
@@ -33,6 +33,7 @@ class ServerArgs:
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
+    kv_cache_dtype: str = "auto"
    trust_remote_code: bool = True
     context_length: Optional[int] = None
     quantization: Optional[str] = None
@@ -196,6 +197,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
+        parser.add_argument(
+            "--kv-cache-dtype",
+            type=str,
+            default=ServerArgs.kv_cache_dtype,
+            choices=["auto", "fp8_e5m2"],
+            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+        )
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
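Note: a hedged usage sketch of the new flag. The launch command and model path below are illustrative assumptions, not taken from this PR; the PR itself only adds the --kv-cache-dtype option with choices "auto" and "fp8_e5m2".

# Command-line usage (illustrative; entry point and model path are assumptions):
#   python -m sglang.launch_server --model-path <MODEL> --kv-cache-dtype fp8_e5m2
#
# Programmatic equivalent through the dataclass field added above:
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(model_path="<MODEL>", kv_cache_dtype="fp8_e5m2")
print(server_args.kv_cache_dtype)  # "fp8_e5m2"; model_runner resolves it to torch.float8_e5m2 when FlashInfer is enabled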