@@ -370,9 +370,11 @@ void single_query_cached_kv_attention_launcher(
   dim3 block(NUM_THREADS);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
-    case 32:
-      LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
-      break;
+    // NOTE(woosuk): To reduce the compilation time, we omitted head sizes
+    // 32, 160, 192, 256.
+    // case 32:
+    //   LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
+    //   break;
     case 64:
       LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
       break;
@@ -385,15 +387,15 @@ void single_query_cached_kv_attention_launcher(
     case 128:
       LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
       break;
-    case 160:
-      LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
-      break;
-    case 192:
-      LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
-      break;
-    case 256:
-      LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
-      break;
+    // case 160:
+    //   LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
+    //   break;
+    // case 192:
+    //   LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
+    //   break;
+    // case 256:
+    //   LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
+    //   break;
     default:
       TORCH_CHECK(false, "Unsupported head size: ", head_size);
       break;
@@ -411,17 +413,19 @@ void single_query_cached_kv_attention_launcher(
     context_lens,                                                  \
     max_context_len);
 
+// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
 #define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T)                         \
   switch (block_size) {                                            \
-    case 1:                                                        \
-      CALL_KERNEL_LAUNCHER(T, 1);                                  \
-      break;                                                       \
-    case 2:                                                        \
-      CALL_KERNEL_LAUNCHER(T, 2);                                  \
-      break;                                                       \
-    case 4:                                                        \
-      CALL_KERNEL_LAUNCHER(T, 4);                                  \
-      break;                                                       \
+    /* case 1: */                                                  \
+    /*   CALL_KERNEL_LAUNCHER(T, 1); */                            \
+    /*   break; */                                                 \
+    /* case 2: */                                                  \
+    /*   CALL_KERNEL_LAUNCHER(T, 2); */                            \
+    /*   break; */                                                 \
+    /* case 4: */                                                  \
+    /*   CALL_KERNEL_LAUNCHER(T, 4); */                            \
+    /*   break; */                                                 \
     case 8:                                                        \
       CALL_KERNEL_LAUNCHER(T, 8);                                  \
       break;                                                       \
@@ -431,15 +435,15 @@ void single_query_cached_kv_attention_launcher(
     case 32:                                                       \
       CALL_KERNEL_LAUNCHER(T, 32);                                 \
       break;                                                       \
-    case 64:                                                       \
-      CALL_KERNEL_LAUNCHER(T, 64);                                 \
-      break;                                                       \
-    case 128:                                                      \
-      CALL_KERNEL_LAUNCHER(T, 128);                                \
-      break;                                                       \
-    case 256:                                                      \
-      CALL_KERNEL_LAUNCHER(T, 256);                                \
-      break;                                                       \
+    /* case 64: */                                                 \
+    /*   CALL_KERNEL_LAUNCHER(T, 64); */                           \
+    /*   break; */                                                 \
+    /* case 128: */                                                \
+    /*   CALL_KERNEL_LAUNCHER(T, 128); */                          \
+    /*   break; */                                                 \
+    /* case 256: */                                                \
+    /*   CALL_KERNEL_LAUNCHER(T, 256); */                          \
+    /*   break; */                                                 \
     default:                                                       \
       TORCH_CHECK(false, "Unsupported block size: ", block_size);  \
       break;                                                       \
@@ -455,8 +459,9 @@ void single_query_cached_kv_attention(
   torch::Tensor& context_lens,  // [num_seqs]
   int block_size,
   int max_context_len) {
-  // TODO(woosuk): Support FP32.
-  if (query.dtype() == at::ScalarType::Half) {
+  if (query.dtype() == at::ScalarType::Float) {
+    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
+  } else if (query.dtype() == at::ScalarType::Half) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
   } else if (query.dtype() == at::ScalarType::BFloat16) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
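
Context for the NOTE comments in this patch (not part of the diff itself): `LAUNCH_ATTENTION_KERNEL` and `CALL_KERNEL_LAUNCHER` expand to template instantiations, so the build compiles one kernel for every (dtype, head size, block size) combination whose switch arm is still enabled; commenting out arms is what shortens compilation. The sketch below only illustrates that fan-out with made-up names (`toy_attention_kernel`, `TOY_LAUNCH`, `toy_launcher`); it is not the vLLM kernel or its launch parameters.

```cpp
// Minimal sketch, assuming nothing about the real kernel: each distinct
// (T, HEAD_SIZE, BLOCK_SIZE) triple reached through the switch becomes a
// separate template instantiation, so every enabled case adds compile work.
#include <cstdint>
#include <cstdio>

template <typename T, int HEAD_SIZE, int BLOCK_SIZE>
void toy_attention_kernel() {
  // Stand-in for the real __global__ kernel body.
  std::printf("instantiated head_size=%d block_size=%d\n", HEAD_SIZE, BLOCK_SIZE);
}

#define TOY_LAUNCH(T, HEAD_SIZE, BLOCK_SIZE) \
  toy_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE>();

template <typename T, int BLOCK_SIZE>
void toy_launcher(int head_size) {
  switch (head_size) {
    // Each case kept here multiplies the number of kernels built per
    // (T, BLOCK_SIZE) pair; a commented-out case costs nothing to compile.
    case 64:  TOY_LAUNCH(T, 64,  BLOCK_SIZE); break;
    case 128: TOY_LAUNCH(T, 128, BLOCK_SIZE); break;
    default:  std::printf("unsupported head size: %d\n", head_size); break;
  }
}

int main() {
  // Mirrors the dtype dispatch at the end of the patch: float for the new
  // FP32 path, uint16_t standing in for half.
  toy_launcher<float, 16>(64);
  toy_launcher<uint16_t, 16>(128);
  return 0;
}
```

With every arm re-enabled, the instantiation count would be dtypes × head sizes × block sizes, which is why the patch prunes both the head-size and block-size axes.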