CUDA: quantized KV support for FA vec (ggerganov#7527)
* CUDA: quantized KV support for FA vec

* try CI fix

* fix commented-out kernel variants

* add q8_0 q4_0 tests

* fix nwarps > batch size

* split fattn compile via extern templates

* fix flake8

* fix metal tests

* fix cmake

* make generate_cu_files.py executable

* add autogenerated .cu files

* fix AMD

* error if type_v != FP16 and not flash_attn

* remove obsolete code
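
The "split fattn compile via extern templates" step is the structural core of the change: each K/V type combination of the FlashAttention kernels gets its own translation unit under ggml-cuda/template-instances/, so the variants compile in parallel and unneeded combinations can be left out of the build entirely. A minimal sketch of the pattern, using hypothetical names rather than the tree's actual symbols:

    // shared header: the template body is visible to every including .cu file,
    // but the extern declarations suppress implicit instantiation there
    template <int D, ggml_type type_K, ggml_type type_V>
    void launch_fattn_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
        // ... choose launch parameters and run the fused attention kernel
        //     for head size D and KV cache types type_K/type_V ...
    }

    extern template void launch_fattn_vec<128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0>(ggml_backend_cuda_context &, ggml_tensor *);
    extern template void launch_fattn_vec<128, GGML_TYPE_F16,  GGML_TYPE_F16>(ggml_backend_cuda_context &, ggml_tensor *);

    // one autogenerated template-instances/*.cu per combination,
    // containing only the explicit instantiation:
    template void launch_fattn_vec<128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0>(ggml_backend_cuda_context &, ggml_tensor *);
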
JohannesGaessler committed Jun 1, 2024
1 parent a323ec6 commit 9b59641
Showing 110 changed files with 2,649 additions and 1,152 deletions.
30 changes: 30 additions & 0 deletions CMakeLists.txt
@@ -106,6 +106,7 @@ set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
"llama: max. batch size for using peer access")
option(LLAMA_CUDA_NO_PEER_COPY "llama: do not use peer to peer copies" OFF)
option(LLAMA_CUDA_NO_VMM "llama: do not try to use CUDA VMM" OFF)
option(LLAMA_CUDA_FA_ALL_QUANTS "llama: compile all quants for FlashAttention" OFF)

option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
@@ -402,6 +403,8 @@ if (LLAMA_CUDA)

file(GLOB GGML_SOURCES_CUDA "ggml-cuda/*.cu")
list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu")
file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})

add_compile_definitions(GGML_USE_CUDA)
add_compile_definitions(GGML_CUDA_USE_GRAPHS)
@@ -427,6 +430,18 @@ if (LLAMA_CUDA)
if (LLAMA_CUDA_NO_PEER_COPY)
add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
endif()
if (LLAMA_CUDA_FA_ALL_QUANTS)
file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
else()
file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
endif()

if (LLAMA_STATIC)
if (WIN32)
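
On the source side, the GGML_CUDA_FA_ALL_QUANTS define plausibly gates which of those extern-template declarations (and hence which runtime dispatch cases) exist at all. A sketch with a hypothetical DECL macro, mirroring the file globs above:

    #ifdef GGML_CUDA_FA_ALL_QUANTS
    // declare every K/V quantization combination, e.g.:
    DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
    // ... all remaining combinations ...
    #else
    // only the matching-type pairs that are always compiled:
    DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
    DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
    DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16);
    #endif // GGML_CUDA_FA_ALL_QUANTS
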
@@ -571,6 +586,8 @@ if (LLAMA_HIPBLAS)

file(GLOB GGML_SOURCES_ROCM "ggml-cuda/*.cu")
list(APPEND GGML_SOURCES_ROCM "ggml-cuda.cu")
file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})

add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUDA)

@@ -590,6 +607,19 @@ if (LLAMA_HIPBLAS)
add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
endif()

if (LLAMA_CUDA_FA_ALL_QUANTS)
file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
else()
file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
endif()

add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
21 changes: 18 additions & 3 deletions Makefile
@@ -421,6 +421,15 @@ ifdef LLAMA_CUBLAS
LLAMA_CUDA := 1
endif

OBJS_CUDA_TEMP_INST = $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-wmma*.cu))
ifdef LLAMA_CUDA_FA_ALL_QUANTS
OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*.cu))
else
OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu))
OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu))
OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*f16-f16.cu))
endif # LLAMA_CUDA_FA_ALL_QUANTS

ifdef LLAMA_CUDA
ifneq ('', '$(wildcard /opt/cuda)')
CUDA_PATH ?= /opt/cuda
@@ -431,6 +440,7 @@ ifdef LLAMA_CUDA
MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
OBJS += ggml-cuda.o
OBJS += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
OBJS += $(OBJS_CUDA_TEMP_INST)
MK_NVCCFLAGS += -use_fast_math
ifdef LLAMA_FATAL_WARNINGS
MK_NVCCFLAGS += -Werror all-warnings
@@ -493,7 +503,10 @@ ifdef LLAMA_CUDA_NO_PEER_COPY
endif # LLAMA_CUDA_NO_PEER_COPY
ifdef LLAMA_CUDA_CCBIN
MK_NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
endif
endif # LLAMA_CUDA_CCBIN
ifdef LLAMA_CUDA_FA_ALL_QUANTS
MK_NVCCFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
endif # LLAMA_CUDA_FA_ALL_QUANTS

ifdef JETSON_EOL_MODULE_DETECT
define NVCC_COMPILE
@@ -505,7 +518,7 @@ define NVCC_COMPILE
endef # NVCC_COMPILE
endif # JETSON_EOL_MODULE_DETECT

ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/common.cuh
ggml-cuda/%.o: ggml-cuda/%.cu ggml.h ggml-common.h ggml-cuda/common.cuh
$(NVCC_COMPILE)

ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
@@ -585,11 +598,12 @@ ifdef LLAMA_CUDA_NO_PEER_COPY
endif # LLAMA_CUDA_NO_PEER_COPY
OBJS += ggml-cuda.o
OBJS += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
OBJS += $(OBJS_CUDA_TEMP_INST)

ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<

ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/common.cuh
ggml-cuda/%.o: ggml-cuda/%.cu ggml.h ggml-common.h ggml-cuda/common.cuh
$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<

endif # LLAMA_HIPBLAS
@@ -749,6 +763,7 @@ libllama.a: llama.o ggml.o $(OBJS) $(COMMON_DEPS)
clean:
rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
rm -vrf ggml-cuda/*.o
rm -vrf ggml-cuda/template-instances/*.o
find examples pocs -type f -name "*.o" -delete

#
1 change: 1 addition & 0 deletions README.md
@@ -501,6 +501,7 @@ Building the program with BLAS support may lead to some performance improvements
| LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
| LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
| LLAMA_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type combinations for the FlashAttention CUDA kernels. Gives more fine-grained control over the KV cache size, but compilation takes much longer. |
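
To opt in (assuming this tree's usual build entry points), pass `LLAMA_CUDA_FA_ALL_QUANTS=1` alongside `LLAMA_CUDA=1` when building with the Makefile, or configure with `-DLLAMA_CUDA=ON -DLLAMA_CUDA_FA_ALL_QUANTS=ON` when using CMake.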

- #### hipBLAS

8 changes: 6 additions & 2 deletions ggml-cuda.cu
@@ -2905,10 +2905,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
#else
if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
if (op->src[0]->ne[0] == 128) {
return true;
}
return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA;
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
return true;
}
return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
default:
return false;
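
Read together, the non-AMD branch now admits FlashAttention for head size 128 with any supported K/V type, head size 64 only with F16 K, and everything else only on Volta or newer with F16 K and V (src[1] is K, src[2] is V). A compact restatement as a hypothetical helper, not code from the tree:

    static bool fattn_ext_supported(int64_t head_size, ggml_type type_K, ggml_type type_V, int cc) {
        if (head_size == 128) {
            return true; // quantized or F16 KV cache, any supported arch
        }
        if (head_size == 64 && type_K == GGML_TYPE_F16) {
            return true;
        }
        return cc >= CC_VOLTA && type_K == GGML_TYPE_F16 && type_V == GGML_TYPE_F16;
    }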