llama : Metal inference #1642

Merged Jun 4, 2023 (49 commits)

Changes from 1 commit

Commits
f85020b
mtl : export the LLaMA computation graph
ggerganov May 29, 2023
98c267f
ci : disable temporary
ggerganov May 29, 2023
b23fe8c
mtl : adapt the MNIST example as starter
ggerganov May 29, 2023
a792cbd
mtl : no need for mtl-export tool, add cli arg for main instead
ggerganov May 29, 2023
897d6d8
mtl : export just a small part of the graph for now to make it easier
ggerganov May 29, 2023
248a8c3
mtl : move MSL code into separate file for easy editing
ggerganov May 29, 2023
a8fd9dc
mtl : initial get_rows_q4_0 kernel
ggerganov May 29, 2023
794704e
mtl : confirmed get_rows_q4_0 is working correctly
ggerganov May 30, 2023
72256eb
mtl : add rms_norm kernel + confirm working
ggerganov May 30, 2023
64afc0b
mtl : add mul kernel + confirm working
ggerganov May 30, 2023
2a24994
mtl : initial mul_mat Q4 kernel (wrong results)
ggerganov May 30, 2023
96d0052
mtl : mul_mat fixes (still wrong)
ggerganov May 30, 2023
29bec00
mtl : another mul_mat Q4 (still does not work)
ggerganov May 30, 2023
b2fd06c
mtl : working mul_mat q4
ggerganov May 30, 2023
6af6a05
ggml : fix handling of "view" ops in ggml_graph_import()
ggerganov May 31, 2023
1213af7
mtl : add rope kernel
ggerganov May 31, 2023
7ca81e9
mtl : add reshape and transpose handling
ggerganov May 31, 2023
94ea9e7
ggml : store offset as opt arg for ggml_view_xd() operators
ggerganov Jun 1, 2023
948fcfd
mtl : add cpy kernel + handle view ops
ggerganov Jun 1, 2023
51efb59
mtl : confirm f16 x f32 attention mul mat
ggerganov Jun 1, 2023
0f1c580
mtl : add scale kernel
ggerganov Jun 1, 2023
17a7036
mtl : add diag_mask_inf kernel
ggerganov Jun 1, 2023
17930fb
mtl : fix soft_max kernel
ggerganov Jun 1, 2023
f67c2d8
ggml : update ggml_nbytes() to handle non-contiguous tensors
ggerganov Jun 1, 2023
a266c26
mtl : verify V tensor contents
ggerganov Jun 1, 2023
a0cc3de
mtl : add f32 -> f32 cpy kernel
ggerganov Jun 1, 2023
42dca40
mtl : add silu kernel
ggerganov Jun 1, 2023
fbd3f62
mtl : add non-broadcast mul kernel
ggerganov Jun 1, 2023
9665429
mtl : full GPU inference of the computation graph
ggerganov Jun 1, 2023
f0196a7
mtl : optimize rms_norm and soft_max kernels
ggerganov Jun 1, 2023
e55f7b0
mtl : add f16 mat x f32 vec multiplication kernel
ggerganov Jun 1, 2023
3367146
mtl : fix bug in f16 x f32 mul mat + speed-up computation
ggerganov Jun 2, 2023
847bbfe
mtl : faster mul_mat_q4_0_f32 kernel
ggerganov Jun 2, 2023
70c3387
mtl : fix kernel signature + roll inner loop
ggerganov Jun 2, 2023
b088e14
mtl : more threads for rms_norm + better timing
ggerganov Jun 2, 2023
6276057
mtl : remove printfs from inner loop
ggerganov Jun 2, 2023
03c2d72
mtl : simplify implementation
ggerganov Jun 2, 2023
640a889
mtl : add save/load vocab to ggml file
ggerganov Jun 2, 2023
2f4e9d1
mtl : plug Metal inference into llama.cpp (very quick-n-dirty)
ggerganov Jun 2, 2023
4df2ef3
mtl : make it work with main example
ggerganov Jun 3, 2023
18e482a
mtl : preparing for merge
ggerganov Jun 4, 2023
e4b5222
mtl : clean-up ggml mtl interface + suport scratch / inplace
ggerganov Jun 4, 2023
e26cd6b
mtl : remove temp / debug code
ggerganov Jun 4, 2023
a7fb899
metal : final refactoring and simplification
ggerganov Jun 4, 2023
d8a7486
Revert "ci : disable temporary"
ggerganov Jun 4, 2023
b252acb
metal : add comments
ggerganov Jun 4, 2023
db3db9e
metal : clean-up stuff, fix typos
ggerganov Jun 4, 2023
e33002d
readme : add Metal instructions
ggerganov Jun 4, 2023
324e823
readme : add example for main
ggerganov Jun 4, 2023
mtl : add rope kernel
ggerganov committed May 31, 2023
commit 1213af76ceae9e839e1da440f95604c0a013d68d
75 changes: 74 additions & 1 deletion examples/mtl/mtl.m
@@ -41,6 +41,9 @@

id<MTLFunction> function_mul_mat_q4_0;
id<MTLComputePipelineState> pipeline_mul_mat_q4_0;

id<MTLFunction> function_rope;
id<MTLComputePipelineState> pipeline_rope;
};

// MSL code
@@ -148,6 +151,10 @@
ctx->function_mul_mat_q4_0 = [ctx->library newFunctionWithName:@"kernel_mul_mat_q4_0"];
ctx->pipeline_mul_mat_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_mat_q4_0 error:nil];
fprintf(stderr, "%s: loaded kernel_mul_mat_q4_0: %p\n", __func__, (void *) ctx->pipeline_mul_mat_q4_0);

ctx->function_rope = [ctx->library newFunctionWithName:@"kernel_rope"];
ctx->pipeline_rope = [ctx->device newComputePipelineStateWithFunction:ctx->function_rope error:nil];
fprintf(stderr, "%s: loaded kernel_rope: %p\n", __func__, (void *) ctx->pipeline_rope);
}

// MTLBuffer approach
@@ -250,6 +257,10 @@ int llama_mtl_eval(
fprintf(stderr, "%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

switch (gf->nodes[i]->op) {
case GGML_OP_RESHAPE:
{
// noop
} break;
case GGML_OP_ADD:
{
if (encoder == nil) {
@@ -453,6 +464,68 @@ int llama_mtl_eval(

[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_ROPE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}

id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);

const int64_t ne00 = gf->nodes[i]->src0->ne[0];
const int64_t ne01 = gf->nodes[i]->src0->ne[1];
const int64_t ne02 = gf->nodes[i]->src0->ne[2];
const int64_t ne03 = gf->nodes[i]->src0->ne[3];

const uint64_t nb00 = gf->nodes[i]->src0->nb[0];
const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
const uint64_t nb02 = gf->nodes[i]->src0->nb[2];
const uint64_t nb03 = gf->nodes[i]->src0->nb[3];

const int64_t ne0 = gf->nodes[i]->ne[0];
const int64_t ne1 = gf->nodes[i]->ne[1];
const int64_t ne2 = gf->nodes[i]->ne[2];
const int64_t ne3 = gf->nodes[i]->ne[3];

const uint64_t nb0 = gf->nodes[i]->nb[0];
const uint64_t nb1 = gf->nodes[i]->nb[1];
const uint64_t nb2 = gf->nodes[i]->nb[2];
const uint64_t nb3 = gf->nodes[i]->nb[3];

const int n_past = ((int32_t *) gf->nodes[i]->src1->data)[0]; // TODO: TMP !!!!!
const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1];
const int mode = ((int32_t *) gf->nodes[i]->src1->data)[2];

printf("rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
printf("rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
printf("rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);

[encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
[encoder setBytes:&mode length:sizeof( int) atIndex:20];

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default:
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
GGML_ASSERT(false);
@@ -486,7 +559,7 @@ int llama_mtl_eval(

{
const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
fprintf(stderr, "%s: time elapsed = %f\n", __func__, time_elapsed);
fprintf(stderr, "%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
}

// TODO
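For reference (an added note, not part of the PR diff): the host code above binds src0/dst plus all of the shape (ne*), stride (nb*) and RoPE parameters (n_past, n_dims, mode) as kernel arguments, then dispatches an (ne01, ne02, ne03) grid with a single thread per threadgroup, so each thread rotates one row of ne0 values. For (mode & 1) == 0 the rotation applied to each consecutive pair of values at token position p = n_past + i2 is the standard RoPE transform:

$$
\theta_i = p \cdot 10000^{-2i/n_{\text{dims}}}, \qquad
\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix}
=
\begin{pmatrix} \cos\theta_i & -\sin\theta_i \\ \sin\theta_i & \cos\theta_i \end{pmatrix}
\begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}
$$

which corresponds to theta *= theta_scale with theta_scale = 10000^(-2/n_dims) in the kernel below.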
55 changes: 55 additions & 0 deletions examples/mtl/mtl.metal
@@ -210,3 +210,58 @@ kernel void kernel_mul_mat_q4_0(
dst[r1*ne0 + r0] = sum[0];
}
}

kernel void kernel_rope(
device const void * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int & n_past,
constant int & n_dims,
constant int & mode,
uint3 tpig[[thread_position_in_grid]]) {
const int64_t i3 = tpig[2];
const int64_t i2 = tpig[1];
const int64_t i1 = tpig[0];

const bool is_neox = mode & 2;
const float theta_scale = pow(10000.0, -2.0f/n_dims);

const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

float theta = (float)p;

if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);

theta *= theta_scale;

device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

const float x0 = src[0];
const float x1 = src[1];

dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
}
} else {
// TODO: implement
}
}
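A minimal CPU-side sketch of the non-NeoX branch above, useful when cross-checking the Metal output against ggml. This is not part of the PR; it assumes a contiguous f32 tensor (nb00 == sizeof(float)), and the helper name rope_ref_f32 is hypothetical:

```cpp
// Hypothetical reference mirroring the non-neox path of kernel_rope,
// for a contiguous f32 tensor of shape [ne0, ne1, ne2, ne3].
#include <cmath>
#include <cstdint>

static void rope_ref_f32(const float * src, float * dst,
                         int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
                         int n_past, int n_dims, int mode) {
    const float theta_scale = powf(10000.0f, -2.0f/n_dims);

    for (int64_t i3 = 0; i3 < ne3; ++i3) {
        for (int64_t i2 = 0; i2 < ne2; ++i2) {
            // token position, as in the kernel
            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
            for (int64_t i1 = 0; i1 < ne1; ++i1) {
                float theta = (float) p; // reset per row
                for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                    const float cos_theta = cosf(theta);
                    const float sin_theta = sinf(theta);
                    theta *= theta_scale;

                    // index of the pair within the contiguous buffer
                    const int64_t idx = ((i3*ne2 + i2)*ne1 + i1)*ne0 + i0;
                    const float x0 = src[idx + 0];
                    const float x1 = src[idx + 1];

                    dst[idx + 0] = x0*cos_theta - x1*sin_theta;
                    dst[idx + 1] = x0*sin_theta + x1*cos_theta;
                }
            }
        }
    }
}
```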
24 changes: 16 additions & 8 deletions llama.cpp
@@ -1270,19 +1270,20 @@ static bool llama_eval_internal(

// self-attention
{
auto * x = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
// TODO: TMP !!!!
if (il == 0) {
ggml_set_name(x, "mtl-check");
}
//auto * x = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
//struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);

// compute Q and K and RoPE them
//struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
ggml_set_name(Qcur, "Qcur");
ggml_set_name(Kcur, "Kcur");

// TODO: TMP !!!!
if (il == 0) {
ggml_set_name(Qcur, "mtl-check");
}

// store key and value to memory
{
// compute the transposed [N, n_embd] V matrix
@@ -1437,7 +1438,14 @@ static bool llama_eval_internal(
//ggml_graph_compute (ctx0, &gf);

// lets export a smaller graph to get things rolling -- baby steps first
ggml_build_forward_expand(&gf_export, ggml_get_tensor(ctx0, "mtl-check"));
{
struct ggml_tensor * t = ggml_get_tensor(ctx0, "mtl-check");
if (!t) {
fprintf(stderr, "%s: failed to find tensor 'mtl-check'\n", __func__);
exit(1);
}
ggml_build_forward_expand(&gf_export, t);
}

// print
{