llama : Metal inference #1642

Merged on Jun 4, 2023 (49 commits)

Changes shown below are from 1 commit. Full commit list:
f85020b
mtl : export the LLaMA computation graph
ggerganov May 29, 2023
98c267f
ci : disable temporary
ggerganov May 29, 2023
b23fe8c
mtl : adapt the MNIST example as starter
ggerganov May 29, 2023
a792cbd
mtl : no need for mtl-export tool, add cli arg for main instead
ggerganov May 29, 2023
897d6d8
mtl : export just a small part of the graph for now to make it easier
ggerganov May 29, 2023
248a8c3
mtl : move MSL code into separate file for easy editing
ggerganov May 29, 2023
a8fd9dc
mtl : initial get_rows_q4_0 kernel
ggerganov May 29, 2023
794704e
mtl : confirmed get_rows_q4_0 is working correctly
ggerganov May 30, 2023
72256eb
mtl : add rms_norm kernel + confirm working
ggerganov May 30, 2023
64afc0b
mtl : add mul kernel + confirm working
ggerganov May 30, 2023
2a24994
mtl : initial mul_mat Q4 kernel (wrong results)
ggerganov May 30, 2023
96d0052
mtl : mul_mat fixes (still wrong)
ggerganov May 30, 2023
29bec00
mtl : another mul_mat Q4 (still does not work)
ggerganov May 30, 2023
b2fd06c
mtl : working mul_mat q4
ggerganov May 30, 2023
6af6a05
ggml : fix handling of "view" ops in ggml_graph_import()
ggerganov May 31, 2023
1213af7
mtl : add rope kernel
ggerganov May 31, 2023
7ca81e9
mtl : add reshape and transpose handling
ggerganov May 31, 2023
94ea9e7
ggml : store offset as opt arg for ggml_view_xd() operators
ggerganov Jun 1, 2023
948fcfd
mtl : add cpy kernel + handle view ops
ggerganov Jun 1, 2023
51efb59
mtl : confirm f16 x f32 attention mul mat
ggerganov Jun 1, 2023
0f1c580
mtl : add scale kernel
ggerganov Jun 1, 2023
17a7036
mtl : add diag_mask_inf kernel
ggerganov Jun 1, 2023
17930fb
mtl : fix soft_max kernel
ggerganov Jun 1, 2023
f67c2d8
ggml : update ggml_nbytes() to handle non-contiguous tensors
ggerganov Jun 1, 2023
a266c26
mtl : verify V tensor contents
ggerganov Jun 1, 2023
a0cc3de
mtl : add f32 -> f32 cpy kernel
ggerganov Jun 1, 2023
42dca40
mtl : add silu kernel
ggerganov Jun 1, 2023
fbd3f62
mtl : add non-broadcast mul kernel
ggerganov Jun 1, 2023
9665429
mtl : full GPU inference of the computation graph
ggerganov Jun 1, 2023
f0196a7
mtl : optimize rms_norm and soft_max kernels
ggerganov Jun 1, 2023
e55f7b0
mtl : add f16 mat x f32 vec multiplication kernel
ggerganov Jun 1, 2023
3367146
mtl : fix bug in f16 x f32 mul mat + speed-up computation
ggerganov Jun 2, 2023
847bbfe
mtl : faster mul_mat_q4_0_f32 kernel
ggerganov Jun 2, 2023
70c3387
mtl : fix kernel signature + roll inner loop
ggerganov Jun 2, 2023
b088e14
mtl : more threads for rms_norm + better timing
ggerganov Jun 2, 2023
6276057
mtl : remove printfs from inner loop
ggerganov Jun 2, 2023
03c2d72
mtl : simplify implementation
ggerganov Jun 2, 2023
640a889
mtl : add save/load vocab to ggml file
ggerganov Jun 2, 2023
2f4e9d1
mtl : plug Metal inference into llama.cpp (very quick-n-dirty)
ggerganov Jun 2, 2023
4df2ef3
mtl : make it work with main example
ggerganov Jun 3, 2023
18e482a
mtl : preparing for merge
ggerganov Jun 4, 2023
e4b5222
mtl : clean-up ggml mtl interface + suport scratch / inplace
ggerganov Jun 4, 2023
e26cd6b
mtl : remove temp / debug code
ggerganov Jun 4, 2023
a7fb899
metal : final refactoring and simplification
ggerganov Jun 4, 2023
d8a7486
Revert "ci : disable temporary"
ggerganov Jun 4, 2023
b252acb
metal : add comments
ggerganov Jun 4, 2023
db3db9e
metal : clean-up stuff, fix typos
ggerganov Jun 4, 2023
e33002d
readme : add Metal instructions
ggerganov Jun 4, 2023
324e823
readme : add example for main
ggerganov Jun 4, 2023
mtl : add scale kernel
ggerganov committed Jun 1, 2023
commit 0f1c580860e2acbee7c095b113256f69e93869b5
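For orientation, ggml's GGML_OP_SCALE multiplies every element of src0 by a single scalar carried in a one-element f32 tensor (src1); that is the contract the new Metal kernel and encoder path below have to match. A minimal CPU-side sketch of those semantics (illustrative only, not code from this PR):

#include <stdint.h>

// Reference semantics of GGML_OP_SCALE (illustrative sketch, not part of the diff):
// multiply every element of src0 by the scalar stored in the 1-element f32 tensor src1.
static void scale_ref(const float * src0, const float * src1, float * dst, int64_t n) {
    const float scale = src1[0];      // the Metal path reads the same value via src1->data
    for (int64_t i = 0; i < n; ++i) {
        dst[i] = src0[i] * scale;     // kernel_scale computes this with one GPU thread per element
    }
}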
27 changes: 27 additions & 0 deletions examples/mtl/mtl.m
@@ -27,6 +27,9 @@
id<MTLFunction> function_mul;
id<MTLComputePipelineState> pipeline_mul;

+id<MTLFunction> function_scale;
+id<MTLComputePipelineState> pipeline_scale;
+
id<MTLFunction> function_relu;
id<MTLComputePipelineState> pipeline_relu;

@@ -135,6 +138,10 @@
ctx->pipeline_mul = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul error:nil];
fprintf(stderr, "%s: loaded kernel_mul: %p\n", __func__, (void *) ctx->pipeline_mul);

+ctx->function_scale = [ctx->library newFunctionWithName:@"kernel_scale"];
+ctx->pipeline_scale = [ctx->device newComputePipelineStateWithFunction:ctx->function_scale error:nil];
+fprintf(stderr, "%s: loaded kernel_scale: %p\n", __func__, (void *) ctx->pipeline_scale);
+
ctx->function_relu = [ctx->library newFunctionWithName:@"kernel_relu"];
ctx->pipeline_relu = [ctx->device newComputePipelineStateWithFunction:ctx->function_relu error:nil];
fprintf(stderr, "%s: loaded kernel_relu: %p\n", __func__, (void *) ctx->pipeline_relu);
@@ -310,6 +317,26 @@ int llama_mtl_eval(

const int64_t n = ggml_nelements(gf->nodes[i]);

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
+case GGML_OP_SCALE:
+{
+if (encoder == nil) {
+encoder = [command_buffer computeCommandEncoder];
+}
+
+id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
+
+const float scale = *(const float *) gf->nodes[i]->src1->data;
+
+[encoder setComputePipelineState:ctx->pipeline_scale];
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+
+const int64_t n = ggml_nelements(gf->nodes[i]);
+
+[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+} break;
case GGML_OP_RELU:
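A note on the dispatch above (commentary, not part of the diff): the encoder launches ggml_nelements(node) threadgroups of a single thread each, i.e. one GPU thread per tensor element, so the grid size is simply

    threadgroups = ne, threadsPerThreadgroup = 1

This is the simplest correct mapping for an element-wise op; later commits in this PR (e.g. "mtl : more threads for rms_norm + better timing") revisit the dispatch granularity for the heavier kernels.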
8 changes: 8 additions & 0 deletions examples/mtl/mtl.metal
@@ -53,6 +53,14 @@ kernel void kernel_mul(
dst[tpig] = src0[tpig] * src1[tpig % ne00];
}

+kernel void kernel_scale(
+device const float * src0,
+device float * dst,
+constant float & scale,
+uint tpig[[thread_position_in_grid]]) {
+dst[tpig] = src0[tpig] * scale;
+}
+
kernel void kernel_relu(
device const float * src0,
device float * dst,
8 changes: 4 additions & 4 deletions llama.cpp
@@ -1324,10 +1324,6 @@ static bool llama_eval_internal(
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
ggml_set_name(KQ, "KQ");
-// TODO: TMP !!!!
-if (il == 0) {
-ggml_set_name(KQ, "mtl-check");
-}

// KQ_scaled = KQ / sqrt(n_embd/n_head)
struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head));
@@ -1336,6 +1332,10 @@
// KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
ggml_set_name(KQ_scaled, "KQ_scaled");
+// TODO: TMP !!!!
+if (il == 0) {
+ggml_set_name(KQ_scaled, "mtl-check");
+}

// KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
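To make the scale factor concrete, a worked example assuming the LLaMA-7B hyperparameters (n_embd = 4096, n_head = 32):

    KQ_scale = 1 / sqrt(n_embd / n_head) = 1 / sqrt(4096 / 32) = 1 / sqrt(128) ≈ 0.0884

so the GGML_OP_SCALE node, and hence kernel_scale on the Metal side, multiplies every attention score by roughly 0.0884, the usual 1/sqrt(d_head) attention scaling.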