Skip to content

Commit

Permalink
ggml : add llamafile sgemm (ggerganov#6414)
Browse files Browse the repository at this point in the history
This change upstreams llamafile's cpu matrix multiplication kernels
which improve image and prompt evaluation speed. For starters, Q4_0
and Q8_0 weights should go ~40% faster on CPU. The biggest benefits
are with data types like f16 / f32, which process prompts 2x faster
thus making them faster than quantized data types for prompt evals.

This change also introduces bona fide AVX512 support since tinyBLAS
is able to exploit the larger register file. For example, on my CPU
llama.cpp llava-cli processes an image prompt at 305 tokens/second,
using the Q4_K and Q4_0 types, which has always been faster than if
we used f16 LLaVA weights, which at HEAD go 188 tokens/second. With
this change, f16 LLaVA performance leap frogs to 464 tokens/second.

On Intel Core i9-14900K this change improves F16 prompt perf by 5x.
For example, using llama.cpp at HEAD with Mistral 7b f16 to process
a 215 token prompt will go 13 tok/sec. This change has fixes making
it go 52 tok/sec. It's mostly thanks to my vectorized outer product
kernels but also because I added support for correctly counting the
number of cores on Alderlake, so the default thread count discounts
Intel's new efficiency cores. Only Linux right now can count cores.

This work was sponsored by Mozilla who's given permission to change
the license of this code from Apache 2.0 to MIT. To read more about
what's improved, and how it works, see: https://justine.lol/matmul/
  • Loading branch information
jart authored Apr 16, 2024
1 parent dbceec8 commit 8cc91dc
Show file tree
Hide file tree
Showing 12 changed files with 1,312 additions and 12 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1151,6 +1151,8 @@ add_library(ggml OBJECT
ggml-backend.h
ggml-quants.c
ggml-quants.h
sgemm.cpp
sgemm.h
${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA}
${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL}
${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL}
Expand Down
10 changes: 9 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ ifdef LLAMA_DISABLE_LOGS
MK_CPPFLAGS += -DLOG_DISABLE_LOGS
endif # LLAMA_DISABLE_LOGS

# disable ggml.c's use of sgemm.cpp
ifdef LLAMA_NO_LLAMAFILE
MK_CPPFLAGS += -DGGML_USE_LLAMAFILE=0
endif

# warnings
WARN_FLAGS = -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function
MK_CFLAGS += $(WARN_FLAGS) -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Werror=implicit-int \
Expand Down Expand Up @@ -676,13 +681,16 @@ ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h
ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h ggml-common.h
$(CC) $(CFLAGS) -c $< -o $@

sgemm.o: sgemm.cpp sgemm.h ggml.h
$(CXX) $(CXXFLAGS) -c $< -o $@

unicode.o: unicode.cpp unicode.h
$(CXX) $(CXXFLAGS) -c $< -o $@

unicode-data.o: unicode-data.cpp unicode-data.h
$(CXX) $(CXXFLAGS) -c $< -o $@

OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o unicode.o unicode-data.o
OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o unicode.o unicode-data.o sgemm.o

llama.o: llama.cpp unicode.h ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h
$(CXX) $(CXXFLAGS) -c $< -o $@
Expand Down
1 change: 1 addition & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import PackageDescription

var sources = [
"ggml.c",
"sgemm.cpp",
"llama.cpp",
"unicode.cpp",
"unicode-data.cpp",
Expand Down
15 changes: 8 additions & 7 deletions build.zig
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ pub fn build(b: *std.build.Builder) !void {
make.enable_lto = b.option(bool, "lto", "Enable LTO optimization, (default: false)") orelse false;

const ggml = make.obj("ggml", "ggml.c");
const sgemm = make.obj("sgemm", "sgemm.cpp");
const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c");
const ggml_backend = make.obj("ggml-backend", "ggml-backend.c");
const ggml_quants = make.obj("ggml-quants", "ggml-quants.c");
Expand All @@ -128,14 +129,14 @@ pub fn build(b: *std.build.Builder) !void {
const clip = make.obj("clip", "examples/llava/clip.cpp");
const llava = make.obj("llava", "examples/llava/llava.cpp");

_ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, console, grammar_parser });
_ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
_ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
_ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
_ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });
_ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });
_ = make.exe("main", "examples/main/main.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, console, grammar_parser });
_ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
_ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
_ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
_ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });
_ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });

const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, grammar_parser, clip, llava });
const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, grammar_parser, clip, llava });
if (server.target.isWindows()) {
server.linkSystemLibrary("ws2_32");
}
Expand Down
73 changes: 73 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,79 @@ int32_t get_num_physical_cores() {
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
}

#if defined(__x86_64__) && defined(__linux__)
#include <pthread.h>

static void cpuid(unsigned leaf, unsigned subleaf,
unsigned *eax, unsigned *ebx, unsigned *ecx, unsigned *edx) {
__asm__("movq\t%%rbx,%%rsi\n\t"
"cpuid\n\t"
"xchgq\t%%rbx,%%rsi"
: "=a"(*eax), "=S"(*ebx), "=c"(*ecx), "=d"(*edx)
: "0"(leaf), "2"(subleaf));
}

static int pin_cpu(int cpu) {
cpu_set_t mask;
CPU_ZERO(&mask);
CPU_SET(cpu, &mask);
return pthread_setaffinity_np(pthread_self(), sizeof(mask), &mask);
}

static bool is_hybrid_cpu(void) {
unsigned eax, ebx, ecx, edx;
cpuid(7, 0, &eax, &ebx, &ecx, &edx);
return !!(edx & (1u << 15));
}

static bool is_running_on_efficiency_core(void) {
unsigned eax, ebx, ecx, edx;
cpuid(0x1a, 0, &eax, &ebx, &ecx, &edx);
int intel_atom = 0x20;
int core_type = (eax & 0xff000000u) >> 24;
return core_type == intel_atom;
}

static int count_math_cpus(int cpu_count) {
int result = 0;
for (int cpu = 0; cpu < cpu_count; ++cpu) {
if (pin_cpu(cpu)) {
return -1;
}
if (is_running_on_efficiency_core()) {
continue; // efficiency cores harm lockstep threading
}
++cpu; // hyperthreading isn't useful for linear algebra
++result;
}
return result;
}

#endif // __x86_64__ && __linux__

/**
* Returns number of CPUs on system that are useful for math.
*/
int get_math_cpu_count() {
#if defined(__x86_64__) && defined(__linux__)
int cpu_count = sysconf(_SC_NPROCESSORS_ONLN);
if (cpu_count < 1) {
return get_num_physical_cores();
}
if (is_hybrid_cpu()) {
cpu_set_t affinity;
if (!pthread_getaffinity_np(pthread_self(), sizeof(affinity), &affinity)) {
int result = count_math_cpus(cpu_count);
pthread_setaffinity_np(pthread_self(), sizeof(affinity), &affinity);
if (result > 0) {
return result;
}
}
}
#endif
return get_num_physical_cores();
}

void process_escapes(std::string & input) {
std::size_t input_len = input.length();
std::size_t output_idx = 0;
Expand Down
3 changes: 2 additions & 1 deletion common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ extern char const *LLAMA_BUILD_TARGET;

struct llama_control_vector_load_info;

int get_math_cpu_count();
int32_t get_num_physical_cores();

//
Expand All @@ -48,7 +49,7 @@ int32_t get_num_physical_cores();
struct gpt_params {
uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed

int32_t n_threads = get_num_physical_cores();
int32_t n_threads = get_math_cpu_count();
int32_t n_threads_draft = -1;
int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
int32_t n_threads_batch_draft = -1;
Expand Down
2 changes: 1 addition & 1 deletion examples/llama-bench/llama-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ static const cmd_params cmd_params_defaults = {
/* n_ubatch */ {512},
/* type_k */ {GGML_TYPE_F16},
/* type_v */ {GGML_TYPE_F16},
/* n_threads */ {get_num_physical_cores()},
/* n_threads */ {get_math_cpu_count()},
/* n_gpu_layers */ {99},
/* split_mode */ {LLAMA_SPLIT_MODE_LAYER},
/* main_gpu */ {0},
Expand Down
2 changes: 1 addition & 1 deletion ggml-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ typedef uint16_t ggml_fp16_internal_t;
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <intrin.h>
#else
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
#if !defined(__riscv)
#include <immintrin.h>
#endif
Expand Down
2 changes: 1 addition & 1 deletion ggml-quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
}

static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
#if defined(__AVXVNNI__) || defined(__AVX512VNNI__)
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
const __m256i zero = _mm256_setzero_si256();
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
return _mm256_cvtepi32_ps(summed_pairs);
Expand Down
54 changes: 54 additions & 0 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "ggml-impl.h"
#include "ggml-quants.h"
#include "ggml.h"
#include "sgemm.h"

#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
Expand Down Expand Up @@ -32,6 +33,14 @@
#include <unistd.h>
#endif

#ifndef GGML_USE_LLAMAFILE
#ifdef __ARM_FEATURE_MATMUL_INT8
#define GGML_USE_LLAMAFILE 0
#else
#define GGML_USE_LLAMAFILE 1
#endif
#endif

#if defined(_MSC_VER)
// disable "possible loss of data" to avoid hundreds of casts
// we should just be careful :)
Expand Down Expand Up @@ -10810,6 +10819,28 @@ static void ggml_compute_forward_mul_mat(
}
#endif

#if GGML_USE_LLAMAFILE
if (nb10 == ggml_type_size(src1->type)) {
for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++)
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
nb01/ggml_type_size(src0->type),
(const char *)src1->data + i12*nb12 + i13*nb13,
nb11/ggml_type_size(src1->type),
(char *)dst->data + i12*nb2 + i13*nb3,
nb1/ggml_type_size(dst->type),
ith, nth,
params->type,
src0->type,
src1->type,
dst->type))
goto UseGgmlGemm1;
return;
}
UseGgmlGemm1:;
#endif

if (params->type == GGML_TASK_TYPE_INIT) {
if (ith != 0) {
return;
Expand Down Expand Up @@ -10841,6 +10872,29 @@ static void ggml_compute_forward_mul_mat(
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);

#if GGML_USE_LLAMAFILE
if (nb10 == ggml_type_size(src1->type) || src1->type != vec_dot_type) {
for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++)
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
nb01/ggml_type_size(src0->type),
(const char *)wdata + (nb12/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i12 +
nb13/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i13),
row_size/ggml_type_size(vec_dot_type),
(char *)dst->data + i12*nb2 + i13*nb3,
nb1/ggml_type_size(dst->type),
ith, nth,
params->type,
src0->type,
vec_dot_type,
dst->type))
goto UseGgmlGemm2;
return;
}
UseGgmlGemm2:;
#endif

const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = ne1*ne12*ne13; // src1 rows

Expand Down
Loading

0 comments on commit 8cc91dc

Please sign in to comment.