
Commit 742af99

Support FP32 (vllm-project#141)
1 parent 0e0d818 commit 742af99

8 files changed, 65 additions and 54 deletions

cacheflow/config.py

Lines changed: 3 additions & 4 deletions
@@ -164,7 +164,7 @@ def _get_and_verify_dtype(
         config_dtype = torch.float32
 
     dtype = dtype.lower()
-    if dtype == "default":
+    if dtype == "auto":
         if config_dtype == torch.float32:
             # Following the common practice, we use float16 for float32 models.
             torch_dtype = torch.float16
@@ -184,9 +184,8 @@ def _get_and_verify_dtype(
             # Downcasting from float32 to float16 or bfloat16 is allowed.
             pass
         else:
-            # Casting between float16 and bfloat16 is not allowed.
-            raise ValueError(
-                f"Cannot use {torch_dtype} for {config_dtype} model.")
+            # Casting between float16 and bfloat16 is allowed with a warning.
+            logger.warn(f"Casting {config_dtype} to {torch_dtype}.")
 
     # Check if the GPU supports the dtype.
     if torch_dtype == torch.bfloat16:
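For orientation, here is a minimal Python sketch of the resolution rules the hunks above implement. The helper name `resolve_dtype` and the string-to-dtype table are illustrative assumptions, not the actual `_get_and_verify_dtype` code; only the "auto" rule and the new warn-instead-of-raise behavior are taken from the diff.

    import logging
    import torch

    logger = logging.getLogger(__name__)

    # Hypothetical string-to-dtype table for the sketch.
    _STR_TO_DTYPE = {
        "half": torch.float16, "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float": torch.float32, "float32": torch.float32,
    }

    def resolve_dtype(dtype: str, config_dtype: torch.dtype) -> torch.dtype:
        dtype = dtype.lower()
        if dtype == "auto":
            # Following the common practice, float32 checkpoints run in float16.
            return torch.float16 if config_dtype == torch.float32 else config_dtype
        torch_dtype = _STR_TO_DTYPE[dtype]
        if torch_dtype != config_dtype and config_dtype != torch.float32:
            # Casting between float16 and bfloat16 now only warns instead of raising.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
        return torch_dtype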

cacheflow/entrypoints/llm.py

Lines changed: 5 additions & 4 deletions
@@ -28,17 +28,18 @@ class LLM:
         tensor_parallel_size: The number of GPUs to use for distributed
             execution with tensor parallelism.
         dtype: The data type for the model weights and activations. Currently,
-            we support `float16` and `bfloat16`. If `default`, we use the
-            `torch_dtype` attribute of the model config. If the `torch_dtype`
-            is `float32`, we use `float16` instead.
+            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
+            the `torch_dtype` attribute specified in the model config file.
+            However, if the `torch_dtype` in the config is `float32`, we will
+            use `float16` instead.
         seed: The seed to initialize the random number generator for sampling.
     """
 
     def __init__(
         self,
         model: str,
         tensor_parallel_size: int = 1,
-        dtype: str = "default",
+        dtype: str = "auto",
         seed: int = 0,
         **kwargs,
     ) -> None:
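A hedged usage sketch of the updated entrypoint. The top-level `from cacheflow import LLM` import and the model name are assumptions for illustration; only the `dtype="auto"` default and its semantics come from the docstring above.

    from cacheflow import LLM  # assumed re-export of cacheflow.entrypoints.llm.LLM

    # "auto" (the new default) keeps the checkpoint's torch_dtype, except that
    # float32 checkpoints are downcast to float16. Per the updated docstring,
    # float32 can now also be requested explicitly.
    llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1, dtype="auto", seed=0)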

cacheflow/model_executor/layers/attention.py

Lines changed: 3 additions & 5 deletions
@@ -10,7 +10,7 @@
 from cacheflow import pos_encoding_ops
 from cacheflow.model_executor.input_metadata import InputMetadata
 
-_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]
+_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]
 
 
 class GPTCacheFlowAttention(nn.Module):
@@ -49,10 +49,8 @@ def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
         self.attn_op = xops.fmha.cutlass.FwOp()
 
         if self.head_size not in _SUPPORTED_HEAD_SIZES:
-            raise ValueError(f'head_size ({self.head_size}) is not supported by '
-                             'the single_query_cached_kv_attention kernel. '
-                             'Use one of the following head sizes: '
-                             f'{_SUPPORTED_HEAD_SIZES}.')
+            raise ValueError(f"head_size ({self.head_size}) is not supported. "
+                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
 
     def multi_query_kv_attention(
         self,
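As a small illustration of the tightened check, constructing the layer with a head size outside the new list now raises the shorter error. This is a sketch only: the constructor signature is taken from the hunk header above, the numbers are arbitrary, and the other constructor arguments and imports (e.g. xformers) must be available for it to run.

    from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention

    # head_size=32 was accepted before this commit but is no longer in
    # _SUPPORTED_HEAD_SIZES, so __init__ raises ValueError.
    try:
        GPTCacheFlowAttention(num_heads=12, head_size=32, scale=32 ** -0.5)
    except ValueError as err:
        print(err)  # head_size (32) is not supported. Supported head sizes: [64, 80, 96, 128].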

cacheflow/server/arg_utils.py

Lines changed: 4 additions & 4 deletions
@@ -13,7 +13,7 @@ class ServerArgs:
     download_dir: Optional[str] = None
     use_np_weights: bool = False
     use_dummy_weights: bool = False
-    dtype: str = "default"
+    dtype: str = "auto"
     seed: int = 0
     worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
@@ -49,9 +49,9 @@ def add_cli_args(
                             help='use dummy values for model weights')
         # TODO(woosuk): Support FP32.
         parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
-                            choices=['default', 'half', 'bfloat16'],
+                            choices=['auto', 'half', 'bfloat16', 'float'],
                             help='data type for model weights and activations. '
-                                 'The "default" option will use FP16 precision '
+                                 'The "auto" option will use FP16 precision '
                                  'for FP32 and FP16 models, and BF16 precision '
                                  'for BF16 models.')
         # Parallel arguments
@@ -67,7 +67,7 @@ def add_cli_args(
         # KV cache arguments
         parser.add_argument('--block-size', type=int,
                             default=ServerArgs.block_size,
-                            choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
+                            choices=[8, 16, 32],
                            help='token block size')
         # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
         parser.add_argument('--seed', type=int, default=ServerArgs.seed,
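The standalone argparse sketch below mirrors the two updated arguments so the accepted values are easy to see; the real parser is assembled by `ServerArgs.add_cli_args`, and the omitted help strings and the explicit default here are only for illustration.

    import argparse

    parser = argparse.ArgumentParser()
    # Mirrors the --dtype argument after this commit ("float" enables FP32).
    parser.add_argument('--dtype', type=str, default='auto',
                        choices=['auto', 'half', 'bfloat16', 'float'])
    # Mirrors the narrowed --block-size choices.
    parser.add_argument('--block-size', type=int, choices=[8, 16, 32])

    args = parser.parse_args(['--dtype', 'float', '--block-size', '16'])
    assert args.dtype == 'float' and args.block_size == 16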

csrc/attention/attention_kernels.cu

Lines changed: 37 additions & 32 deletions
@@ -370,9 +370,11 @@ void single_query_cached_kv_attention_launcher(
   dim3 block(NUM_THREADS);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
-    case 32:
-      LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
-      break;
+    // NOTE(woosuk): To reduce the compilation time, we omitted head sizes
+    // 32, 160, 192, 256.
+    // case 32:
+    //   LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
+    //   break;
     case 64:
       LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
       break;
@@ -385,15 +387,15 @@ void single_query_cached_kv_attention_launcher(
     case 128:
       LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
       break;
-    case 160:
-      LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
-      break;
-    case 192:
-      LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
-      break;
-    case 256:
-      LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
-      break;
+    // case 160:
+    //   LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
+    //   break;
+    // case 192:
+    //   LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
+    //   break;
+    // case 256:
+    //   LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
+    //   break;
     default:
       TORCH_CHECK(false, "Unsupported head size: ", head_size);
       break;
@@ -411,17 +413,19 @@ void single_query_cached_kv_attention_launcher(
     context_lens,                               \
     max_context_len);
 
+// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
 #define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T)      \
   switch (block_size) {                         \
-    case 1:                                     \
-      CALL_KERNEL_LAUNCHER(T, 1);               \
-      break;                                    \
-    case 2:                                     \
-      CALL_KERNEL_LAUNCHER(T, 2);               \
-      break;                                    \
-    case 4:                                     \
-      CALL_KERNEL_LAUNCHER(T, 4);               \
-      break;                                    \
+    /* case 1: */                               \
+    /*   CALL_KERNEL_LAUNCHER(T, 1); */         \
+    /*   break; */                              \
+    /* case 2: */                               \
+    /*   CALL_KERNEL_LAUNCHER(T, 2); */         \
+    /*   break; */                              \
+    /* case 4: */                               \
+    /*   CALL_KERNEL_LAUNCHER(T, 4); */         \
+    /*   break; */                              \
     case 8:                                     \
       CALL_KERNEL_LAUNCHER(T, 8);               \
       break;                                    \
@@ -431,15 +435,15 @@ void single_query_cached_kv_attention_launcher(
     case 32:                                    \
       CALL_KERNEL_LAUNCHER(T, 32);              \
       break;                                    \
-    case 64:                                    \
-      CALL_KERNEL_LAUNCHER(T, 64);              \
-      break;                                    \
-    case 128:                                   \
-      CALL_KERNEL_LAUNCHER(T, 128);             \
-      break;                                    \
-    case 256:                                   \
-      CALL_KERNEL_LAUNCHER(T, 256);             \
-      break;                                    \
+    /* case 64: */                              \
+    /*   CALL_KERNEL_LAUNCHER(T, 64); */        \
+    /*   break; */                              \
+    /* case 128: */                             \
+    /*   CALL_KERNEL_LAUNCHER(T, 128); */       \
+    /*   break; */                              \
+    /* case 256: */                             \
+    /*   CALL_KERNEL_LAUNCHER(T, 256); */       \
+    /*   break; */                              \
     default:                                    \
       TORCH_CHECK(false, "Unsupported block size: ", block_size);   \
       break;                                    \
@@ -455,8 +459,9 @@ void single_query_cached_kv_attention(
   torch::Tensor& context_lens,   // [num_seqs]
   int block_size,
   int max_context_len) {
-  // TODO(woosuk): Support FP32.
-  if (query.dtype() == at::ScalarType::Half) {
+  if (query.dtype() == at::ScalarType::Float) {
+    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
+  } else if (query.dtype() == at::ScalarType::Half) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
   } else if (query.dtype() == at::ScalarType::BFloat16) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
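After these edits, the launcher only instantiates the kernel for the combinations below; anything else falls through to the `TORCH_CHECK(false, ...)` branches. The Python-side summary here is purely illustrative (the names are invented for the sketch); the real dispatch is the switch statements above.

    import torch

    # Combinations compiled by single_query_cached_kv_attention after this commit.
    SUPPORTED_DTYPES = (torch.float32, torch.float16, torch.bfloat16)
    SUPPORTED_BLOCK_SIZES = (8, 16, 32)
    SUPPORTED_HEAD_SIZES = (64, 80, 96, 128)

    def is_dispatchable(dtype: torch.dtype, block_size: int, head_size: int) -> bool:
        return (dtype in SUPPORTED_DTYPES
                and block_size in SUPPORTED_BLOCK_SIZES
                and head_size in SUPPORTED_HEAD_SIZES)

    assert is_dispatchable(torch.float32, 16, 128)      # newly supported FP32 path
    assert not is_dispatchable(torch.float16, 64, 128)  # block size 64 is now commented out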

docs/source/getting_started/installation.rst

Lines changed: 3 additions & 0 deletions
@@ -18,8 +18,11 @@ CacheFlow can run on systems that meet the following requirements:
 
 .. code-block:: console
 
+    $ # Pull the Docker image with CUDA 11.8.
     $ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
 
+Inside the Docker container, please execute :code:`pip uninstall torch` before installing CacheFlow.
+
 Install with pip
 ----------------

setup.py

Lines changed: 5 additions & 0 deletions
@@ -66,6 +66,11 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     raise RuntimeError(
         "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
 
+# Use NVCC threads to parallelize the build.
+if nvcc_cuda_version >= Version("11.2"):
+    num_threads = min(os.cpu_count(), 8)
+    NVCC_FLAGS += ["--threads", str(num_threads)]
+
 ext_modules = []
 
 # Cache operations.

tests/kernels/test_attention.py

Lines changed: 5 additions & 5 deletions
@@ -270,9 +270,9 @@ def run_multi_query_kv_attention(
 def test_single_query_cached_kv_attention() -> None:
     torch.random.manual_seed(TEST_SEED)
     torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16]:
-        for block_size in [8, 16, 32, 64]:
-            for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for block_size in [8, 16, 32]:
+            for head_size in [64, 80, 96, 128]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
                       f'head_size={head_size}')
@@ -289,8 +289,8 @@ def test_single_query_cached_kv_attention() -> None:
 def test_multi_query_kv_attention() -> None:
     torch.random.manual_seed(TEST_SEED)
     torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16]:
-        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for head_size in [64, 80, 96, 128]:
             print(f'Testing multi_query_kv_attention with dtype={dtype}, '
                   f'head_size={head_size}')
             run_multi_query_kv_attention(
