@@ -964,9 +964,9 @@ void ggml_metal_graph_compute(
964
964
const int64_t nb = ne00;
965
965
966
966
[encoder setComputePipelineState: ctx->pipeline_concat];
967
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
968
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
969
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
967
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
968
+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
969
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
970
970
[encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
971
971
[encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 4 ];
972
972
[encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 5 ];
@@ -1029,9 +1029,9 @@ void ggml_metal_graph_compute(
1029
1029
default : GGML_ASSERT (false );
1030
1030
}
1031
1031
}
1032
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1033
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1034
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1032
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1033
+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1034
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1035
1035
[encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
1036
1036
[encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 4 ];
1037
1037
[encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 5 ];
@@ -1083,8 +1083,8 @@ void ggml_metal_graph_compute(
1083
1083
[encoder setComputePipelineState: ctx->pipeline_scale];
1084
1084
}
1085
1085
1086
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1087
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1086
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1087
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1088
1088
[encoder setBytes: &scale length: sizeof (scale) atIndex: 2 ];
1089
1089
1090
1090
[encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
@@ -1094,8 +1094,8 @@ void ggml_metal_graph_compute(
1094
1094
case GGML_UNARY_OP_SILU:
1095
1095
{
1096
1096
[encoder setComputePipelineState: ctx->pipeline_silu];
1097
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1098
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1097
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1098
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1099
1099
1100
1100
const int64_t n = ggml_nelements (dst);
1101
1101
GGML_ASSERT (n % 4 == 0 );
@@ -1105,8 +1105,8 @@ void ggml_metal_graph_compute(
1105
1105
case GGML_UNARY_OP_RELU:
1106
1106
{
1107
1107
[encoder setComputePipelineState: ctx->pipeline_relu];
1108
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1109
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1108
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1109
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1110
1110
1111
1111
const int64_t n = ggml_nelements (dst);
1112
1112
@@ -1115,8 +1115,8 @@ void ggml_metal_graph_compute(
1115
1115
case GGML_UNARY_OP_GELU:
1116
1116
{
1117
1117
[encoder setComputePipelineState: ctx->pipeline_gelu];
1118
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1119
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1118
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1119
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1120
1120
1121
1121
const int64_t n = ggml_nelements (dst);
1122
1122
GGML_ASSERT (n % 4 == 0 );
@@ -1134,8 +1134,8 @@ void ggml_metal_graph_compute(
1134
1134
GGML_ASSERT (ggml_is_contiguous (src0));
1135
1135
1136
1136
[encoder setComputePipelineState: ctx->pipeline_sqr];
1137
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1138
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1137
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1138
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1139
1139
1140
1140
const int64_t n = ggml_nelements (dst);
1141
1141
[encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
@@ -1145,8 +1145,8 @@ void ggml_metal_graph_compute(
1145
1145
GGML_ASSERT (src0->nb [0 ] == ggml_type_size (src0->type ));
1146
1146
1147
1147
[encoder setComputePipelineState: ctx->pipeline_sum_rows];
1148
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1149
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1148
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1149
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1150
1150
[encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 2 ];
1151
1151
[encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 3 ];
1152
1152
[encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 4 ];
@@ -1192,9 +1192,9 @@ void ggml_metal_graph_compute(
1192
1192
1193
1193
const float scale = ((float *) dst->op_params )[0 ];
1194
1194
1195
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1196
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1197
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1195
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1196
+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1197
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1198
1198
[encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
1199
1199
[encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 4 ];
1200
1200
[encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 5 ];
@@ -1212,8 +1212,8 @@ void ggml_metal_graph_compute(
1212
1212
} else {
1213
1213
[encoder setComputePipelineState: ctx->pipeline_diag_mask_inf];
1214
1214
}
1215
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1216
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1215
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1216
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1217
1217
[encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 2 ];
1218
1218
[encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 3 ];
1219
1219
[encoder setBytes: &n_past length: sizeof (int ) atIndex: 4 ];
@@ -1286,9 +1286,9 @@ void ggml_metal_graph_compute(
1286
1286
case GGML_TYPE_Q6_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q6_K_f32]; break ;
1287
1287
default : GGML_ASSERT (false && " MUL MAT-MAT not implemented" );
1288
1288
}
1289
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1290
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1291
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1289
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1290
+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1291
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1292
1292
[encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
1293
1293
[encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 4 ];
1294
1294
[encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 5 ];
@@ -1403,9 +1403,9 @@ void ggml_metal_graph_compute(
1403
1403
}
1404
1404
};
1405
1405
1406
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1407
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1408
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1406
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1407
+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1408
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1409
1409
[encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
1410
1410
[encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 4 ];
1411
1411
[encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 5 ];
@@ -1511,9 +1511,9 @@ void ggml_metal_graph_compute(
1511
1511
case GGML_TYPE_Q6_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_id_q6_K_f32]; break ;
1512
1512
default : GGML_ASSERT (false && " MUL_MAT_ID not implemented" );
1513
1513
}
1514
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1515
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1516
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1514
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1515
+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1516
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1517
1517
[encoder setBytes: &ne20 length: sizeof (ne20) atIndex: 3 ];
1518
1518
[encoder setBytes: &ne22 length: sizeof (ne22) atIndex: 4 ];
1519
1519
[encoder setBytes: &nb21 length: sizeof (nb21) atIndex: 5 ];
@@ -1559,9 +1559,9 @@ void ggml_metal_graph_compute(
1559
1559
default : GGML_ASSERT (false && " not implemented" );
1560
1560
}
1561
1561
1562
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1563
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1564
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1562
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1563
+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1564
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1565
1565
[encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 3 ];
1566
1566
[encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 4 ];
1567
1567
[encoder setBytes: &nb1 length: sizeof (uint64_t ) atIndex: 5 ];
@@ -1584,8 +1584,8 @@ void ggml_metal_graph_compute(
1584
1584
}
1585
1585
1586
1586
[encoder setComputePipelineState: ctx->pipeline_rms_norm];
1587
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1588
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1587
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1588
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1589
1589
[encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
1590
1590
[encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 3 ];
1591
1591
[encoder setBytes: &eps length: sizeof ( float ) atIndex: 4 ];
@@ -1603,8 +1603,8 @@ void ggml_metal_graph_compute(
1603
1603
const int nth = MIN (256 , ne00);
1604
1604
1605
1605
[encoder setComputePipelineState: ctx->pipeline_norm];
1606
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1607
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1606
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1607
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1608
1608
[encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
1609
1609
[encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 3 ];
1610
1610
[encoder setBytes: &eps length: sizeof ( float ) atIndex: 4 ];
@@ -1630,8 +1630,8 @@ void ggml_metal_graph_compute(
1630
1630
const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_heads_log2_floor);
1631
1631
1632
1632
[encoder setComputePipelineState: ctx->pipeline_alibi_f32];
1633
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1634
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1633
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1634
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1635
1635
[encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
1636
1636
[encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 3 ];
1637
1637
[encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 4 ];
@@ -1680,9 +1680,9 @@ void ggml_metal_graph_compute(
1680
1680
default : GGML_ASSERT (false );
1681
1681
};
1682
1682
1683
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1684
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1685
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1683
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1684
+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1685
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1686
1686
[encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 3 ];
1687
1687
[encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 4 ];
1688
1688
[encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 5 ];
@@ -1748,8 +1748,8 @@ void ggml_metal_graph_compute(
1748
1748
default : GGML_ASSERT (false );
1749
1749
};
1750
1750
1751
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 0 ];
1752
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1751
+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 0 ];
1752
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1753
1753
[encoder setBytes: &ofs0 length: sizeof ( int32_t ) atIndex: 2 ];
1754
1754
[encoder setBytes: &ofs1 length: sizeof ( int32_t ) atIndex: 3 ];
1755
1755
[encoder setBytes: &IW length: sizeof ( int32_t ) atIndex: 4 ];
@@ -1779,8 +1779,8 @@ void ggml_metal_graph_compute(
1779
1779
default : GGML_ASSERT (false );
1780
1780
};
1781
1781
1782
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1783
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1782
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1783
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1784
1784
[encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
1785
1785
1786
1786
[encoder dispatchThreadgroups: MTLSizeMake (1 , nrows, 1 ) threadsPerThreadgroup: MTLSizeMake (ne00, 1 , 1 )];
@@ -1820,8 +1820,8 @@ void ggml_metal_graph_compute(
1820
1820
default : GGML_ASSERT (false && " not implemented" );
1821
1821
}
1822
1822
1823
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1824
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1823
+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1824
+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1825
1825
[encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
1826
1826
[encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 3 ];
1827
1827
[encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 4 ];
0 commit comments