sync : ggml (Metal F32 support + reduce ggml-alloc size) (#3192)
* sync : ggml (Metal F32 support + reduce ggml-alloc size)

ggml-ci

* llama-bench : fix ggml_cpu_has_metal() duplicate function

ggml-ci
ggerganov authored Sep 15, 2023
1 parent 7e50d34 commit 8c00b7a
Showing 6 changed files with 194 additions and 91 deletions.
8 changes: 0 additions & 8 deletions examples/llama-bench/llama-bench.cpp
@@ -74,14 +74,6 @@ static T stdev(const std::vector<T> & v) {
return stdev;
}

-static bool ggml_cpu_has_metal() {
-#if defined(GGML_USE_METAL)
-    return true;
-#else
-    return false;
-#endif
-}
-
static std::string get_cpu_info() {
std::string id;
#ifdef __linux__
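Note: the predicate deleted above is no longer needed locally: this sync adds it to the ggml library itself, which is what caused the duplicate-function build error. A minimal sketch of the declaration llama-bench now relies on instead (assumed from ggml.h at this sync; the ggml versions of these predicates return int rather than bool):

    // in ggml.h (assumption): backend predicates exported by the core library
    GGML_API int ggml_cpu_has_metal(void);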
12 changes: 6 additions & 6 deletions ggml-alloc.c
@@ -131,6 +131,10 @@ static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_ten
return ptr >= alloc->data && (char *)ptr < (char *)alloc->data + alloc->max_size;
}

+static bool ggml_is_view(struct ggml_tensor * t) {
+    return t->view_src != NULL;
+}
+
void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
#ifdef GGML_ALLOCATOR_DEBUG
GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources
@@ -338,8 +342,8 @@ static void free_vmem(void * base_addr, size_t size) {

// allocate uncommitted virtual memory to measure the size of the graph
static void alloc_measure_vmem(void ** base_addr, size_t * size) {
-    // 1TB for 64-bit, 1GB for 32-bit
-    *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<40;
+    // 128GB for 64-bit, 1GB for 32-bit
+    *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<37;
do {
*base_addr = alloc_vmem(*size);
if (*base_addr != NULL) {
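Aside: the 64-bit reservation for the measuring allocator drops from 2^40 to 2^37 bytes. A quick standalone check of the constants (not part of the diff):

    #include <stdio.h>

    int main(void) {
        printf("%llu\n", 1ULL << 30); // 1073741824    bytes =   1 GiB (32-bit path)
        printf("%llu\n", 1ULL << 37); // 137438953472  bytes = 128 GiB (new 64-bit size)
        printf("%llu\n", 1ULL << 40); // 1099511627776 bytes =   1 TiB (old 64-bit size)
        return 0;
    }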
@@ -399,10 +403,6 @@ bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {

//////////// compute graph allocator

-static bool ggml_is_view(struct ggml_tensor * t) {
-    return t->view_src != NULL;
-}
-
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
if (a->type != b->type) {
return false;
12 changes: 12 additions & 0 deletions ggml-metal.m
@@ -78,6 +78,7 @@
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(norm);
+GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
@@ -89,6 +90,7 @@
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
@@ -237,6 +239,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(norm);
+GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
@@ -248,6 +251,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
+GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
@@ -309,6 +313,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(norm);
+GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
@@ -320,6 +325,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
+GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
@@ -885,6 +891,7 @@ void ggml_metal_graph_compute(
ne00%32 == 0 &&
ne11 > 1) {
switch (src0->type) {
+case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
@@ -919,6 +926,11 @@ void ggml_metal_graph_compute(

// use custom matrix x vector kernel
switch (src0t) {
+case GGML_TYPE_F32:
+    {
+        [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
+        nrows = 4;
+    } break;
case GGML_TYPE_F16:
{
nth0 = 32;
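The nrows = 4 set in the new F32 case pairs with N_F32_F32 = 4 in the kernel below: the host shrinks the second grid dimension so each threadgroup covers four src1 rows. Roughly, with assumed variable names (the literal dispatch call in ggml-metal.m may differ):

    // threadgroups along y = ceil(ne11 / nrows); each one handles nrows rows
    const int ty = (ne11 + nrows - 1) / nrows;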
89 changes: 81 additions & 8 deletions ggml-metal.metal
@@ -523,6 +523,79 @@ kernel void kernel_mul_mat_q8_0_f32(
}
}

+#define N_F32_F32 4
+
+kernel void kernel_mul_mat_f32_f32(
+        device const char * src0,
+        device const char * src1,
+        device      float * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F32_F32;
+    const int64_t im = tgpig.z;
+
+    device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (float) x[i] * (float) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const float4 * x4 = (device const float4 *)x;
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+            device const float4 * y4 = (device const float4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
kernel void kernel_mul_mat_f16_f32_1row(
device const char * src0,
device const char * src1,
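For orientation (not part of the diff): each threadgroup of the new kernel owns one row r0 of src0 and up to N_F32_F32 = 4 consecutive rows of src1; the 32 lanes of a simdgroup stride over the shared dimension and reduce with simd_sum, and a float4 path kicks in once ne00 >= 128. A plain-C sketch of the same arithmetic, assuming contiguous row-major tensors and hypothetical names:

    // Reference for what one threadgroup of kernel_mul_mat_f32_f32 computes:
    // dst[r1*ne0 + r0] = dot(row r0 of src0, row r1 of src1), for 4 rows r1.
    static void mul_mat_f32_f32_ref(const float *x,   // row r0 of src0 (ne00 floats)
                                    const float *y,   // src1: ne11 rows x ne00 floats
                                    float       *dst, // output: ne11 rows x ne0 floats
                                    int ne00, int ne0, int ne11,
                                    int r0, int rb) { // rb: first row of the 4-row block
        for (int row = 0; row < 4; ++row) {           // N_F32_F32 == 4
            const int r1 = rb + row;
            if (r1 >= ne11) break;                    // partial block at the boundary
            float sum = 0.0f;
            for (int i = 0; i < ne00; ++i) {          // the GPU strides this loop
                sum += x[i] * y[r1*ne00 + i];         // across the 32 simd lanes
            }
            dst[r1*ne0 + r0] = sum;
        }
    }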
@@ -1399,13 +1472,13 @@ kernel void kernel_mul_mat_q4_K_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
-        constant int64_t & ne01[[buffer(4)]],
-        constant int64_t & ne02[[buffer(5)]],
-        constant int64_t & ne10[[buffer(9)]],
-        constant int64_t & ne12[[buffer(11)]],
-        constant int64_t & ne0[[buffer(15)]],
-        constant int64_t & ne1[[buffer(16)]],
-        constant uint & gqa[[buffer(17)]],
+        constant int64_t & ne01 [[buffer(4)]],
+        constant int64_t & ne02 [[buffer(5)]],
+        constant int64_t & ne10 [[buffer(9)]],
+        constant int64_t & ne12 [[buffer(11)]],
+        constant int64_t & ne0 [[buffer(15)]],
+        constant int64_t & ne1 [[buffer(16)]],
+        constant uint & gqa [[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2012,7 +2085,6 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
for (int i = 0; i < 16; ++i) {
    reg[i/4][i%4] = dl * (q[i] & mask) - ml;
}
-
}

template <typename type4x4>
@@ -2269,6 +2341,7 @@ typedef void (mat_mm_t)(
constant uint & gqa,
threadgroup uchar *, uint3, uint, uint);

template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
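The new matrix-matrix entry reuses the generic kernel_mul_mm template; for F32 the "dequantizer" has nothing to undo, so it reduces to a copy of a 4x4 tile. A plain-C sketch of that identity step (hypothetical helper, mirroring the per-tile functors in this file):

    // For f32, the per-tile "dequantize" is an identity copy of 16 floats;
    // no scales, mins, or bit masks, unlike the q4/q5/q8 variants above.
    static void dequantize_f32_ref(const float src[16], float reg[16]) {
        for (int i = 0; i < 16; ++i) {
            reg[i] = src[i];
        }
    }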