Skip to content

Commit 83798e4

Browse files
dave-fl
authored and committed
Added support for GGML_OP_CLAMP in Metal (ggml-org#6662)
* Added support for GGML_OP_CLAMP in Metal
* Corrected size

---------

Co-authored-by: dave-fl <dave@Davids-MacBook-Pro.local>
1 parent 4780ea1 commit 83798e4

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

ggml-metal.m

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
GGML_METAL_KERNEL_TYPE_DIV_ROW,
3838
GGML_METAL_KERNEL_TYPE_SCALE,
3939
GGML_METAL_KERNEL_TYPE_SCALE_4,
40+
GGML_METAL_KERNEL_TYPE_CLAMP,
4041
GGML_METAL_KERNEL_TYPE_TANH,
4142
GGML_METAL_KERNEL_TYPE_RELU,
4243
GGML_METAL_KERNEL_TYPE_GELU,
@@ -468,6 +469,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
468469
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
469470
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
470471
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
472+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
471473
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
472474
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
473475
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
@@ -713,6 +715,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
713715
case GGML_OP_MUL:
714716
case GGML_OP_DIV:
715717
case GGML_OP_SCALE:
718+
case GGML_OP_CLAMP:
716719
case GGML_OP_SQR:
717720
case GGML_OP_SUM_ROWS:
718721
return true;
@@ -1154,6 +1157,25 @@ static enum ggml_status ggml_metal_graph_compute(
11541157

11551158
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
11561159
} break;
1160+
case GGML_OP_CLAMP:
1161+
{
1162+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
1163+
1164+
float min;
1165+
float max;
1166+
memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
1167+
memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
1168+
1169+
[encoder setComputePipelineState:pipeline];
1170+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1171+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1172+
[encoder setBytes:&min length:sizeof(min) atIndex:2];
1173+
[encoder setBytes:&max length:sizeof(max) atIndex:3];
1174+
1175+
const int64_t n = ggml_nelements(dst);
1176+
1177+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1178+
} break;
11571179
case GGML_OP_UNARY:
11581180
switch (ggml_get_unary_op(gf->nodes[i])) {
11591181
case GGML_UNARY_OP_TANH:

ggml-metal.metal

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,15 @@ kernel void kernel_scale_4(
213213
dst[tpig] = src0[tpig] * scale;
214214
}
215215

216+
// Element-wise clamp: each thread reads one float from src0 and stores it in
// dst after limiting it to the bounds given by min and max.
//   src0 - input values, indexed by this thread's position in the grid
//   dst  - output values, written at the same index
//   min  - lower bound; an input below it is replaced by min
//   max  - upper bound; an input above it (and not below min) is replaced by max
// The lower bound is tested first, so when an input is below min the result is
// min regardless of max.
kernel void kernel_clamp(
        device const float * src0,
        device       float * dst,
        constant     float & min,
        constant     float & max,
        uint tpig[[thread_position_in_grid]]) {
    // Load the input once and reuse it for both comparisons.
    const float x = src0[tpig];

    float y = x;
    if (x < min) {
        y = min;
    } else if (x > max) {
        y = max;
    }

    dst[tpig] = y;
}
224+
216225
kernel void kernel_relu(
217226
device const float * src0,
218227
device float * dst,

0 commit comments

Comments
 (0)