Support e4m3fn KV cache #2655

Merged: 2 commits, Oct 17, 2024
4 changes: 2 additions & 2 deletions docs/source/reference/launcher.md
@@ -93,10 +93,10 @@ Options:
 ## KV_CACHE_DTYPE
 ```shell
 --kv-cache-dtype <KV_CACHE_DTYPE>
-    Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA
+    Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported values are `fp8_e4m3fn` and `fp8_e5m2` on CUDA
 
     [env: KV_CACHE_DTYPE=]
-    [possible values: fp8_e5m2]
+    [possible values: fp8_e4m3fn, fp8_e5m2]
 
 ```
 ## TRUST_REMOTE_CODE
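For context on the new option values: the two FP8 formats make opposite trade-offs between precision and dynamic range. A minimal sketch, assuming PyTorch 2.1 or later, where `torch.finfo` accepts the float8 dtypes:

```python
import torch

# e4m3fn spends its bits on the mantissa (finer precision, max 448, no
# infinities); e5m2 spends them on the exponent (coarser precision, max 57344).
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    print(f"{dtype}: max={info.max}, eps={info.eps}")
```

Roughly speaking, e4m3fn suits values that fit a narrow range, while e5m2 tolerates outliers better.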
8 changes: 7 additions & 1 deletion launcher/src/main.rs
@@ -303,13 +303,19 @@ impl std::fmt::Display for Dtype {

 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum KVCacheDtype {
+    #[clap(name = "fp8_e4m3fn")]
+    Fp8e4m3fn,
+
     #[clap(name = "fp8_e5m2")]
     Fp8e5m2,
 }
 
 impl std::fmt::Display for KVCacheDtype {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
+            KVCacheDtype::Fp8e4m3fn => {
+                write!(f, "fp8_e4m3fn")
+            }
             KVCacheDtype::Fp8e5m2 => {
                 write!(f, "fp8_e5m2")
             }
@@ -420,7 +426,7 @@ struct Args {

     /// Specify the dtype for the key-value cache. When this option is not provided,
     /// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
-    /// the only supported value is `fp8_e5m2` on CUDA.
+    /// the only supported values are `fp8_e4m3fn` and `fp8_e5m2` on CUDA.
     #[clap(long, env, value_enum)]
     kv_cache_dtype: Option<KVCacheDtype>,

1 change: 1 addition & 0 deletions server/text_generation_server/cli.py
@@ -31,6 +31,7 @@ class Dtype(str, Enum):


 class KVCacheDtype(str, Enum):
+    fp8_e4m3fn = "fp8_e4m3fn"
     fp8_e5m2 = "fp8_e5m2"


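Because `KVCacheDtype` subclasses `str`, each member compares equal to its raw value, which is what the string dispatch in `models/__init__.py` (below) relies on. A standalone illustration:

```python
from enum import Enum

class KVCacheDtype(str, Enum):
    fp8_e4m3fn = "fp8_e4m3fn"
    fp8_e5m2 = "fp8_e5m2"

# str subclassing makes members interchangeable with their wire values.
assert KVCacheDtype.fp8_e4m3fn == "fp8_e4m3fn"
assert KVCacheDtype("fp8_e5m2") is KVCacheDtype.fp8_e5m2
```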
8 changes: 4 additions & 4 deletions server/text_generation_server/layers/attention/kv_cache.py
@@ -24,11 +24,11 @@ def __init__(
     ):
         """Construct the key-value cache for a layer."""
 
-        if dtype == torch.float8_e5m2 and (
+        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
             ATTENTION != "flashinfer" or SYSTEM != "cuda"
         ):
             raise ValueError(
-                "float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
+                "FP8 KV cache is currently only supported for flashinfer on CUDA"
             )
 
         element_size = torch.tensor([], dtype=dtype).element_size()
@@ -105,8 +105,8 @@ def store(
         # TODO: add scale
         key = key.to(key_cache.dtype)
         value = value.to(value_cache.dtype)
-        if key_cache.dtype == torch.float8_e5m2:
-            # Torch index_put does not support float8_e5m2 yet, so
+        if key_cache.dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
+            # Torch index_put does not support float8_{e5m2,e4m3fn} yet, so
             # put as raw data instead.
             key_cache = key_cache.view(torch.uint8)
             value_cache = value_cache.view(torch.uint8)
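The `store()` change works around the missing float8 `index_put` kernels by reinterpreting both sides as raw bytes. A minimal sketch of the trick, assuming a PyTorch build with float8 support; the shapes here are illustrative, not the real cache layout:

```python
import torch

cache = torch.zeros(16, 8, dtype=torch.float8_e4m3fn)
update = torch.randn(4, 8).to(torch.float8_e4m3fn)
slots = torch.tensor([0, 3, 7, 11])

# view(torch.uint8) reinterprets the same 1-byte-per-element storage without
# copying, so scattering through the uint8 alias writes the float8 cache.
cache.view(torch.uint8)[slots] = update.view(torch.uint8)
assert cache[3].view(torch.uint8).equal(update[1].view(torch.uint8))
```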
2 changes: 2 additions & 0 deletions server/text_generation_server/models/__init__.py
@@ -406,6 +406,8 @@ def get_model(

     if kv_cache_dtype is None:
         kv_cache_dtype = dtype
+    elif kv_cache_dtype == "fp8_e4m3fn":
+        kv_cache_dtype = torch.float8_e4m3fn
     elif kv_cache_dtype == "fp8_e5m2":
         kv_cache_dtype = torch.float8_e5m2
     else:
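Pulling the dispatch together, a hedged standalone sketch of the resolution flow; `resolve_kv_cache_dtype` and its default are illustrative names, not the repo's API:

```python
import torch

_FP8_KV_DTYPES = {
    "fp8_e4m3fn": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
}

def resolve_kv_cache_dtype(kv_cache_dtype, model_dtype=torch.float16):
    # None falls back to the model dtype; known strings map to torch float8
    # dtypes; anything else is rejected, mirroring the branch in get_model().
    if kv_cache_dtype is None:
        return model_dtype
    try:
        return _FP8_KV_DTYPES[kv_cache_dtype]
    except KeyError:
        raise ValueError(f"Unsupported KV cache dtype: {kv_cache_dtype}")
```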