diff --git a/CMakeLists.txt b/CMakeLists.txt index 158174c2094d25..2cc0df3fb2e671 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -88,6 +88,7 @@ endif() # 3rd party libs option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) option(LLAMA_BLAS "llama: use BLAS" OFF) +option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM" ON) set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") option(LLAMA_CUDA "llama: use CUDA" OFF) option(LLAMA_CUBLAS "llama: use CUDA (deprecated, use LLAMA_CUDA)" OFF) @@ -286,6 +287,7 @@ if (LLAMA_METAL) ${METALKIT_FRAMEWORK} ) endif() + if (LLAMA_BLAS) if (LLAMA_STATIC) set(BLA_STATIC ON) @@ -368,6 +370,10 @@ if (LLAMA_BLAS) endif() endif() +if (LLAMA_LLAMAFILE) + add_compile_definitions(GGML_USE_LLAMAFILE) +endif() + if (LLAMA_QKK_64) add_compile_definitions(GGML_QKK_64) endif() diff --git a/Makefile b/Makefile index 81bd5281d1546d..d0abe93a1f92b2 100644 --- a/Makefile +++ b/Makefile @@ -222,6 +222,8 @@ endif # LLAMA_DISABLE_LOGS # disable ggml.c's use of sgemm.cpp ifdef LLAMA_NO_LLAMAFILE MK_CPPFLAGS += -DGGML_USE_LLAMAFILE=0 +else + MK_CPPFLAGS += -DGGML_USE_LLAMAFILE=1 endif # warnings diff --git a/ggml.c b/ggml.c index 119686be68d9ce..593c603f493be3 100644 --- a/ggml.c +++ b/ggml.c @@ -33,12 +33,8 @@ #include #endif -#ifndef GGML_USE_LLAMAFILE #ifdef __ARM_FEATURE_MATMUL_INT8 -#define GGML_USE_LLAMAFILE 0 -#else -#define GGML_USE_LLAMAFILE 1 -#endif +#undef GGML_USE_LLAMAFILE #endif #if defined(_MSC_VER) @@ -10879,8 +10875,9 @@ UseGgmlGemm1:; if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), - (const char *)wdata + (nb12/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i12 + - nb13/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i13), + (const char *)wdata + ggml_row_size(vec_dot_type, + nb12/ggml_type_size(src1->type)*i12 + + nb13/ggml_type_size(src1->type)*i13), row_size/ggml_type_size(vec_dot_type), (char *)dst->data + i12*nb2 + i13*nb3, nb1/ggml_type_size(dst->type),