Skip to content

Fix Metal API validation errors #4374

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 50 additions & 50 deletions ggml-metal.m
Original file line number | Diff line number | Diff line change
Expand Up @@ -964,9 +964,9 @@ void ggml_metal_graph_compute(
const int64_t nb = ne00;

[encoder setComputePipelineState:ctx->pipeline_concat];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
Expand Down Expand Up @@ -1029,9 +1029,9 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false);
}
}
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
Expand Down Expand Up @@ -1083,8 +1083,8 @@ void ggml_metal_graph_compute(
[encoder setComputePipelineState:ctx->pipeline_scale];
}

- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
Expand All @@ -1094,8 +1094,8 @@ void ggml_metal_graph_compute(
case GGML_UNARY_OP_SILU:
{
[encoder setComputePipelineState:ctx->pipeline_silu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

const int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 4 == 0);
Expand All @@ -1105,8 +1105,8 @@ void ggml_metal_graph_compute(
case GGML_UNARY_OP_RELU:
{
[encoder setComputePipelineState:ctx->pipeline_relu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

const int64_t n = ggml_nelements(dst);

Expand All @@ -1115,8 +1115,8 @@ void ggml_metal_graph_compute(
case GGML_UNARY_OP_GELU:
{
[encoder setComputePipelineState:ctx->pipeline_gelu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

const int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 4 == 0);
Expand All @@ -1134,8 +1134,8 @@ void ggml_metal_graph_compute(
GGML_ASSERT(ggml_is_contiguous(src0));

[encoder setComputePipelineState:ctx->pipeline_sqr];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
Expand All @@ -1145,8 +1145,8 @@ void ggml_metal_graph_compute(
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));

[encoder setComputePipelineState:ctx->pipeline_sum_rows];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
Expand Down Expand Up @@ -1192,9 +1192,9 @@ void ggml_metal_graph_compute(

const float scale = ((float *) dst->op_params)[0];

- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
Expand All @@ -1212,8 +1212,8 @@ void ggml_metal_graph_compute(
} else {
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
}
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
Expand Down Expand Up @@ -1286,9 +1286,9 @@ void ggml_metal_graph_compute(
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
}
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
Expand Down Expand Up @@ -1403,9 +1403,9 @@ void ggml_metal_graph_compute(
}
};

- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
Expand Down Expand Up @@ -1511,9 +1511,9 @@ void ggml_metal_graph_compute(
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
}
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
Expand Down Expand Up @@ -1559,9 +1559,9 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false && "not implemented");
}

- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
Expand All @@ -1584,8 +1584,8 @@ void ggml_metal_graph_compute(
}

[encoder setComputePipelineState:ctx->pipeline_rms_norm];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
Expand All @@ -1603,8 +1603,8 @@ void ggml_metal_graph_compute(
const int nth = MIN(256, ne00);

[encoder setComputePipelineState:ctx->pipeline_norm];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
Expand All @@ -1630,8 +1630,8 @@ void ggml_metal_graph_compute(
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

[encoder setComputePipelineState:ctx->pipeline_alibi_f32];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
Expand Down Expand Up @@ -1680,9 +1680,9 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false);
};

- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
Expand Down Expand Up @@ -1748,8 +1748,8 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false);
};

- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
Expand Down Expand Up @@ -1779,8 +1779,8 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false);
};

- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];

[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
Expand Down Expand Up @@ -1820,8 +1820,8 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false && "not implemented");
}

- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
Expand Down