|
37 | 37 | GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
38 | 38 | GGML_METAL_KERNEL_TYPE_SCALE,
|
39 | 39 | GGML_METAL_KERNEL_TYPE_SCALE_4,
|
| 40 | + GGML_METAL_KERNEL_TYPE_CLAMP, |
40 | 41 | GGML_METAL_KERNEL_TYPE_TANH,
|
41 | 42 | GGML_METAL_KERNEL_TYPE_RELU,
|
42 | 43 | GGML_METAL_KERNEL_TYPE_GELU,
|
@@ -468,6 +469,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
468 | 469 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
469 | 470 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
470 | 471 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
| 472 | + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); |
471 | 473 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
472 | 474 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
473 | 475 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
@@ -713,6 +715,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
713 | 715 | case GGML_OP_MUL:
|
714 | 716 | case GGML_OP_DIV:
|
715 | 717 | case GGML_OP_SCALE:
|
| 718 | + case GGML_OP_CLAMP: |
716 | 719 | case GGML_OP_SQR:
|
717 | 720 | case GGML_OP_SUM_ROWS:
|
718 | 721 | return true;
|
@@ -1154,6 +1157,25 @@ static enum ggml_status ggml_metal_graph_compute(
|
1154 | 1157 |
|
1155 | 1158 | [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1156 | 1159 | } break;
|
case GGML_OP_CLAMP:
    {
        // Encode the clamp kernel: reads src0, writes dst, with the two
        // bounds passed to the kernel as inline constants.
        id<MTLComputePipelineState> clamp_pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;

        // The bounds live in the first two 32-bit slots of op_params as raw
        // float bits; copy them out bytewise instead of dereferencing the
        // storage through a float pointer.
        const int32_t * clamp_params = (const int32_t *) dst->op_params;

        float clamp_min;
        float clamp_max;
        memcpy(&clamp_min, clamp_params + 0, sizeof(float));
        memcpy(&clamp_max, clamp_params + 1, sizeof(float));

        // Argument slots: 0 = source buffer, 1 = destination buffer,
        // 2 = lower bound, 3 = upper bound.
        // NOTE(review): presumably mirrors the clamp kernel's argument
        // order in the Metal shader source - confirm against the kernel.
        [encoder setComputePipelineState:clamp_pipeline];
        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
        [encoder setBytes:&clamp_min length:sizeof(clamp_min) atIndex:2];
        [encoder setBytes:&clamp_max length:sizeof(clamp_max) atIndex:3];

        // Launch one single-thread threadgroup per element of dst.
        const int64_t n_elements = ggml_nelements(dst);

        [encoder dispatchThreadgroups:MTLSizeMake(n_elements, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
    } break;
1157 | 1179 | case GGML_OP_UNARY:
|
1158 | 1180 | switch (ggml_get_unary_op(gf->nodes[i])) {
|
1159 | 1181 | case GGML_UNARY_OP_TANH:
|
|
0 commit comments