|  | 
| 82 | 82 |     GGML_METAL_KERNEL_TYPE_RMS_NORM, | 
| 83 | 83 |     GGML_METAL_KERNEL_TYPE_GROUP_NORM, | 
| 84 | 84 |     GGML_METAL_KERNEL_TYPE_NORM, | 
|  | 85 | +    GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, | 
|  | 86 | +    GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, | 
| 85 | 87 |     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, | 
| 86 | 88 |     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, | 
| 87 | 89 |     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, | 
| @@ -542,6 +544,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ | 
| 542 | 544 |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                      rms_norm,                       ctx->support_simdgroup_reduction); | 
| 543 | 545 |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                    group_norm,                     ctx->support_simdgroup_reduction); | 
| 544 | 546 |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                          norm,                           true); | 
|  | 547 | +        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,                  ssm_conv_f32,                   true); | 
|  | 548 | +        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                  ssm_scan_f32,                   true); | 
| 545 | 549 |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                mul_mv_f32_f32,                 ctx->support_simdgroup_reduction); | 
| 546 | 550 |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,                mul_mv_f16_f16,                 ctx->support_simdgroup_reduction); | 
| 547 | 551 |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,                mul_mv_f16_f32,                 ctx->support_simdgroup_reduction); | 
| @@ -803,6 +807,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx | 
| 803 | 807 |                 return false; | 
| 804 | 808 |             } | 
| 805 | 809 |             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels | 
|  | 810 | +        case GGML_OP_SSM_CONV: | 
|  | 811 | +        case GGML_OP_SSM_SCAN: | 
|  | 812 | +            return true; | 
| 806 | 813 |         case GGML_OP_MUL_MAT: | 
| 807 | 814 |         case GGML_OP_MUL_MAT_ID: | 
| 808 | 815 |             return ctx->support_simdgroup_reduction && | 
| @@ -1538,6 +1545,121 @@ static enum ggml_status ggml_metal_graph_compute( | 
| 1538 | 1545 |                             [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; | 
| 1539 | 1546 |                         } | 
| 1540 | 1547 |                     } break; | 
|  | 1548 | +                case GGML_OP_SSM_CONV: | 
|  | 1549 | +                    { | 
|  | 1550 | +                        GGML_ASSERT(src0t == GGML_TYPE_F32); | 
|  | 1551 | +                        GGML_ASSERT(src1t == GGML_TYPE_F32); | 
|  | 1552 | + | 
|  | 1553 | +                        GGML_ASSERT(ggml_is_contiguous(src0)); | 
|  | 1554 | +                        GGML_ASSERT(ggml_is_contiguous(src1)); | 
|  | 1555 | + | 
|  | 1556 | +                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; | 
|  | 1557 | + | 
|  | 1558 | +                        [encoder setComputePipelineState:pipeline]; | 
|  | 1559 | +                        [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0]; | 
|  | 1560 | +                        [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1]; | 
|  | 1561 | +                        [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2]; | 
|  | 1562 | +                        [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:3]; | 
|  | 1563 | +                        [encoder setBytes:&ne01    length:sizeof(ne01) atIndex:4]; | 
|  | 1564 | +                        [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:5]; | 
|  | 1565 | +                        [encoder setBytes:&nb00    length:sizeof(nb00) atIndex:6]; | 
|  | 1566 | +                        [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:7]; | 
|  | 1567 | +                        [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:8]; | 
|  | 1568 | +                        [encoder setBytes:&ne10    length:sizeof(ne10) atIndex:9]; | 
|  | 1569 | +                        [encoder setBytes:&ne11    length:sizeof(ne11) atIndex:10]; | 
|  | 1570 | +                        [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:11]; | 
|  | 1571 | +                        [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:12]; | 
|  | 1572 | +                        [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13]; | 
|  | 1573 | +                        [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14]; | 
|  | 1574 | +                        [encoder setBytes:&ne2     length:sizeof(ne2)  atIndex:15]; | 
|  | 1575 | +                        [encoder setBytes:&nb0     length:sizeof(nb0)  atIndex:16]; | 
|  | 1576 | +                        [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:17]; | 
|  | 1577 | +                        [encoder setBytes:&nb2     length:sizeof(nb2)  atIndex:18]; | 
|  | 1578 | + | 
|  | 1579 | +                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; | 
|  | 1580 | +                    } break; | 
|  | 1581 | +                case GGML_OP_SSM_SCAN: | 
|  | 1582 | +                    { | 
|  | 1583 | +                        struct ggml_tensor * src3 = gf->nodes[i]->src[3]; | 
|  | 1584 | +                        struct ggml_tensor * src4 = gf->nodes[i]->src[4]; | 
|  | 1585 | +                        struct ggml_tensor * src5 = gf->nodes[i]->src[5]; | 
|  | 1586 | + | 
|  | 1587 | +                        GGML_ASSERT(src3); | 
|  | 1588 | +                        GGML_ASSERT(src4); | 
|  | 1589 | +                        GGML_ASSERT(src5); | 
|  | 1590 | + | 
|  | 1591 | +                        size_t offs_src3 = 0; | 
|  | 1592 | +                        size_t offs_src4 = 0; | 
|  | 1593 | +                        size_t offs_src5 = 0; | 
|  | 1594 | + | 
|  | 1595 | +                        id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; | 
|  | 1596 | +                        id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; | 
|  | 1597 | +                        id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; | 
|  | 1598 | + | 
|  | 1599 | +                        const int64_t  ne30 = src3->ne[0]; GGML_UNUSED(ne30); | 
|  | 1600 | +                        const int64_t  ne31 = src3->ne[1]; GGML_UNUSED(ne31); | 
|  | 1601 | + | 
|  | 1602 | +                        const uint64_t nb30 = src3->nb[0]; | 
|  | 1603 | +                        const uint64_t nb31 = src3->nb[1]; | 
|  | 1604 | + | 
|  | 1605 | +                        const int64_t  ne40 = src4->ne[0]; GGML_UNUSED(ne40); | 
|  | 1606 | +                        const int64_t  ne41 = src4->ne[1]; GGML_UNUSED(ne41); | 
|  | 1607 | +                        const int64_t  ne42 = src4->ne[2]; GGML_UNUSED(ne42); | 
|  | 1608 | + | 
|  | 1609 | +                        const uint64_t nb40 = src4->nb[0]; | 
|  | 1610 | +                        const uint64_t nb41 = src4->nb[1]; | 
|  | 1611 | +                        const uint64_t nb42 = src4->nb[2]; | 
|  | 1612 | + | 
|  | 1613 | +                        const int64_t  ne50 = src5->ne[0]; GGML_UNUSED(ne50); | 
|  | 1614 | +                        const int64_t  ne51 = src5->ne[1]; GGML_UNUSED(ne51); | 
|  | 1615 | +                        const int64_t  ne52 = src5->ne[2]; GGML_UNUSED(ne52); | 
|  | 1616 | + | 
|  | 1617 | +                        const uint64_t nb50 = src5->nb[0]; | 
|  | 1618 | +                        const uint64_t nb51 = src5->nb[1]; | 
|  | 1619 | +                        const uint64_t nb52 = src5->nb[2]; | 
|  | 1620 | + | 
|  | 1621 | +                        const int64_t d_state      = ne00; | 
|  | 1622 | +                        const int64_t d_inner      = ne01; | 
|  | 1623 | +                        const int64_t n_seq_tokens = ne11; | 
|  | 1624 | +                        const int64_t n_seqs       = ne02; | 
|  | 1625 | + | 
|  | 1626 | +                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; | 
|  | 1627 | + | 
|  | 1628 | +                        [encoder setComputePipelineState:pipeline]; | 
|  | 1629 | +                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; | 
|  | 1630 | +                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; | 
|  | 1631 | +                        [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; | 
|  | 1632 | +                        [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; | 
|  | 1633 | +                        [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; | 
|  | 1634 | +                        [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; | 
|  | 1635 | +                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6]; | 
|  | 1636 | + | 
|  | 1637 | +                        [encoder setBytes:&d_state      length:sizeof(d_state)      atIndex:7]; | 
|  | 1638 | +                        [encoder setBytes:&d_inner      length:sizeof(d_inner)      atIndex:8]; | 
|  | 1639 | +                        [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; | 
|  | 1640 | +                        [encoder setBytes:&n_seqs       length:sizeof(n_seqs)       atIndex:10]; | 
|  | 1641 | + | 
|  | 1642 | +                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; | 
|  | 1643 | +                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; | 
|  | 1644 | +                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; | 
|  | 1645 | +                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; | 
|  | 1646 | +                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; | 
|  | 1647 | +                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; | 
|  | 1648 | +                        [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; | 
|  | 1649 | +                        [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; | 
|  | 1650 | +                        [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; | 
|  | 1651 | +                        [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; | 
|  | 1652 | +                        [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; | 
|  | 1653 | +                        [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; | 
|  | 1654 | +                        [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; | 
|  | 1655 | +                        [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; | 
|  | 1656 | +                        [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; | 
|  | 1657 | +                        [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; | 
|  | 1658 | +                        [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; | 
|  | 1659 | +                        [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; | 
|  | 1660 | + | 
|  | 1661 | +                        [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; | 
|  | 1662 | +                    } break; | 
| 1541 | 1663 |                 case GGML_OP_MUL_MAT: | 
| 1542 | 1664 |                     { | 
| 1543 | 1665 |                         GGML_ASSERT(ne00 == ne10); | 
|  | 
0 commit comments