Skip to content

Commit

Permalink
metal : support bcast add & dup & cont op (#2323)
Browse files Browse the repository at this point in the history
  • Loading branch information
li-plus authored Jul 23, 2023
1 parent d2a4366 commit 83a00ce
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
12 changes: 11 additions & 1 deletion ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
id<MTLComputePipelineState> pipeline_##name

GGML_METAL_DECL_KERNEL(add);
GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
GGML_METAL_DECL_KERNEL(mul);
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
GGML_METAL_DECL_KERNEL(scale);
Expand Down Expand Up @@ -157,6 +158,7 @@ @implementation GGMLMetalClass
fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);

GGML_METAL_ADD_KERNEL(add);
GGML_METAL_ADD_KERNEL(add_row);
GGML_METAL_ADD_KERNEL(mul);
GGML_METAL_ADD_KERNEL(mul_row);
GGML_METAL_ADD_KERNEL(scale);
Expand Down Expand Up @@ -464,10 +466,16 @@ void ggml_metal_graph_compute(
encoder = [command_buffer computeCommandEncoder];
}

[encoder setComputePipelineState:ctx->pipeline_add];
if (ggml_nelements(src1) == ne10) {
// src1 is a row
[encoder setComputePipelineState:ctx->pipeline_add_row];
} else {
[encoder setComputePipelineState:ctx->pipeline_add];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];

const int64_t n = ggml_nelements(dst);

Expand Down Expand Up @@ -919,7 +927,9 @@ void ggml_metal_graph_compute(

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
Expand Down
11 changes: 11 additions & 0 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@ kernel void kernel_add(
dst[tpig] = src0[tpig] + src1[tpig];
}

// Row-broadcast add: dst[i] = src0[i] + src1[i mod ne00].
// Precondition (not checked here): src1 is a single row of ne00 floats, so the
// same src1 values are reused for every consecutive run of ne00 elements.
// Each thread produces exactly one output element, selected by its flat
// position in the dispatch grid.
kernel void kernel_add_row(
        device const float * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        uint tpig[[thread_position_in_grid]]) {
    // position of this element inside its row; picks the matching src1 entry
    const int64_t col = tpig % ne00;

    dst[tpig] = src0[tpig] + src1[col];
}

kernel void kernel_mul(
device const float * src0,
device const float * src1,
Expand Down

0 comments on commit 83a00ce

Please sign in to comment.