Skip to content

Commit 8def8ba

Browse files
committed
Fix Metal API validation errors
1 parent afce6fa commit 8def8ba

File tree

1 file changed

+50
-50
lines changed

1 file changed

+50
-50
lines changed

ggml-metal.m

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -964,9 +964,9 @@ void ggml_metal_graph_compute(
964964
const int64_t nb = ne00;
965965

966966
[encoder setComputePipelineState:ctx->pipeline_concat];
967-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
968-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
969-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
967+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
968+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
969+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
970970
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
971971
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
972972
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
@@ -1029,9 +1029,9 @@ void ggml_metal_graph_compute(
10291029
default: GGML_ASSERT(false);
10301030
}
10311031
}
1032-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1033-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1034-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1032+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1033+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1034+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
10351035
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
10361036
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
10371037
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
@@ -1083,8 +1083,8 @@ void ggml_metal_graph_compute(
10831083
[encoder setComputePipelineState:ctx->pipeline_scale];
10841084
}
10851085

1086-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1087-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1086+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1087+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
10881088
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
10891089

10901090
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
@@ -1094,8 +1094,8 @@ void ggml_metal_graph_compute(
10941094
case GGML_UNARY_OP_SILU:
10951095
{
10961096
[encoder setComputePipelineState:ctx->pipeline_silu];
1097-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1098-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1097+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1098+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
10991099

11001100
const int64_t n = ggml_nelements(dst);
11011101
GGML_ASSERT(n % 4 == 0);
@@ -1105,8 +1105,8 @@ void ggml_metal_graph_compute(
11051105
case GGML_UNARY_OP_RELU:
11061106
{
11071107
[encoder setComputePipelineState:ctx->pipeline_relu];
1108-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1109-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1108+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1109+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
11101110

11111111
const int64_t n = ggml_nelements(dst);
11121112

@@ -1115,8 +1115,8 @@ void ggml_metal_graph_compute(
11151115
case GGML_UNARY_OP_GELU:
11161116
{
11171117
[encoder setComputePipelineState:ctx->pipeline_gelu];
1118-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1119-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1118+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1119+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
11201120

11211121
const int64_t n = ggml_nelements(dst);
11221122
GGML_ASSERT(n % 4 == 0);
@@ -1134,8 +1134,8 @@ void ggml_metal_graph_compute(
11341134
GGML_ASSERT(ggml_is_contiguous(src0));
11351135

11361136
[encoder setComputePipelineState:ctx->pipeline_sqr];
1137-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1138-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1137+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1138+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
11391139

11401140
const int64_t n = ggml_nelements(dst);
11411141
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
@@ -1145,8 +1145,8 @@ void ggml_metal_graph_compute(
11451145
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
11461146

11471147
[encoder setComputePipelineState:ctx->pipeline_sum_rows];
1148-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1149-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1148+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1149+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
11501150
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
11511151
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
11521152
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
@@ -1192,9 +1192,9 @@ void ggml_metal_graph_compute(
11921192

11931193
const float scale = ((float *) dst->op_params)[0];
11941194

1195-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1196-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1197-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1195+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1196+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1197+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
11981198
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
11991199
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
12001200
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
@@ -1212,8 +1212,8 @@ void ggml_metal_graph_compute(
12121212
} else {
12131213
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
12141214
}
1215-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1216-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1215+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1216+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
12171217
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
12181218
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
12191219
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
@@ -1286,9 +1286,9 @@ void ggml_metal_graph_compute(
12861286
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
12871287
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
12881288
}
1289-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1290-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1291-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1289+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1290+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1291+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
12921292
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
12931293
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
12941294
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
@@ -1403,9 +1403,9 @@ void ggml_metal_graph_compute(
14031403
}
14041404
};
14051405

1406-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1407-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1408-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1406+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1407+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1408+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
14091409
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
14101410
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
14111411
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
@@ -1511,9 +1511,9 @@ void ggml_metal_graph_compute(
15111511
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
15121512
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
15131513
}
1514-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1515-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1516-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1514+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1515+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1516+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
15171517
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
15181518
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
15191519
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
@@ -1559,9 +1559,9 @@ void ggml_metal_graph_compute(
15591559
default: GGML_ASSERT(false && "not implemented");
15601560
}
15611561

1562-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1563-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1564-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1562+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1563+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1564+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
15651565
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
15661566
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
15671567
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
@@ -1584,8 +1584,8 @@ void ggml_metal_graph_compute(
15841584
}
15851585

15861586
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
1587-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1588-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1587+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1588+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
15891589
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
15901590
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
15911591
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
@@ -1603,8 +1603,8 @@ void ggml_metal_graph_compute(
16031603
const int nth = MIN(256, ne00);
16041604

16051605
[encoder setComputePipelineState:ctx->pipeline_norm];
1606-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1607-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1606+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1607+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
16081608
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
16091609
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
16101610
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
@@ -1630,8 +1630,8 @@ void ggml_metal_graph_compute(
16301630
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
16311631

16321632
[encoder setComputePipelineState:ctx->pipeline_alibi_f32];
1633-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1634-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1633+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1634+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
16351635
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
16361636
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
16371637
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
@@ -1680,9 +1680,9 @@ void ggml_metal_graph_compute(
16801680
default: GGML_ASSERT(false);
16811681
};
16821682

1683-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1684-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1685-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1683+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1684+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1685+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
16861686
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
16871687
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
16881688
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
@@ -1748,8 +1748,8 @@ void ggml_metal_graph_compute(
17481748
default: GGML_ASSERT(false);
17491749
};
17501750

1751-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
1752-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1751+
if (id_src1) [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
1752+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
17531753
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
17541754
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
17551755
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
@@ -1779,8 +1779,8 @@ void ggml_metal_graph_compute(
17791779
default: GGML_ASSERT(false);
17801780
};
17811781

1782-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1783-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1782+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1783+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
17841784
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
17851785

17861786
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
@@ -1820,8 +1820,8 @@ void ggml_metal_graph_compute(
18201820
default: GGML_ASSERT(false && "not implemented");
18211821
}
18221822

1823-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1824-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1823+
if (id_src0) [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1824+
if (id_dst) [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
18251825
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
18261826
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
18271827
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];

0 commit comments

Comments
 (0)