Skip to content

Commit 933c5be

Browse files
authored
whisper : support ggml_conv with CUDA and Metal (#1473)
* ggml : add CUDA support for ggml_conv * whisper : remove ggml_repeat for conv bias + single backend * cuda : fix im2col kernel * metal : add im2col support + mul mat-vec f16 x f16 * bench-all : add q4 models
1 parent c99e290 commit 933c5be

File tree

7 files changed

+604
-1225
lines changed

7 files changed

+604
-1225
lines changed

extra/bench-all.sh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ else
1818
fi
1919

2020
models=( \
21-
"tiny" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
22-
"base" "base-q5_0" "base-q5_1" "base-q8_0" \
23-
"small" "small-q5_0" "small-q5_1" "small-q8_0" \
24-
"medium" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
25-
"large" "large-q5_0" "large-q5_1" "large-q8_0" \
21+
"tiny" "tiny-q4_0" "tiny-q4_1" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
22+
"base" "base-q4_0" "base-q4_1" "base-q5_0" "base-q5_1" "base-q8_0" \
23+
"small" "small-q4_0" "small-q4_1" "small-q5_0" "small-q5_1" "small-q8_0" \
24+
"medium" "medium-q4_0" "medium-q4_1" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
25+
"large" "large-q4_0" "large-q4_1" "large-q5_0" "large-q5_1" "large-q8_0" \
2626
)
2727

2828
if [ "$encoder_only" -eq 0 ]; then

ggml-cuda.cu

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4476,6 +4476,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
44764476
*dsti = __float2half(*xi);
44774477
}
44784478

4479+
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
4480+
const half * xi = (const half *) cxi;
4481+
half * dsti = (half *) cdsti;
4482+
4483+
*dsti = *xi;
4484+
}
4485+
44794486
template <cpy_kernel_t cpy_1>
44804487
static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
44814488
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -4729,6 +4736,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
47294736
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
47304737
}
47314738

4739+
static __global__ void im2col_f32_f16(
4740+
const float * x, half * dst,
4741+
int ofs0, int ofs1, int IW, int IH, int CHW,
4742+
int s0, int s1, int p0, int p1, int d0, int d1) {
4743+
const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
4744+
const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
4745+
4746+
const int offset_dst =
4747+
(threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
4748+
(blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
4749+
4750+
if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
4751+
const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
4752+
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
4753+
} else {
4754+
dst[offset_dst] = __float2half(0.0f);
4755+
}
4756+
}
4757+
47324758
template<int qk, int qr, dequantize_kernel_t dq>
47334759
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
47344760
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
@@ -5618,6 +5644,16 @@ static void ggml_cpy_f32_f16_cuda(
56185644
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
56195645
}
56205646

5647+
static void ggml_cpy_f16_f16_cuda(
5648+
const char * cx, char * cdst, const int ne,
5649+
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
5650+
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
5651+
5652+
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
5653+
cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
5654+
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
5655+
}
5656+
56215657
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
56225658
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
56235659
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
@@ -5701,6 +5737,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
57015737
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
57025738
}
57035739

5740+
static void im2col_f32_f16_cuda(const float * x, half * dst,
5741+
int OH, int IW, int IH, int OW, int IC,
5742+
int KH, int KW, int N, int ofs0, int ofs1,
5743+
int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
5744+
dim3 block_nums(IC, OH, OW);
5745+
dim3 block_dims(N, KH, KW);
5746+
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
5747+
}
5748+
57045749
// buffer pool for cuda
57055750
#define MAX_CUDA_BUFFERS 256
57065751

@@ -6483,7 +6528,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
64836528
src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
64846529
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
64856530
}
6486-
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
6531+
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
64876532
size_t dst_f16_as = 0;
64886533
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
64896534

@@ -6659,6 +6704,45 @@ inline void ggml_cuda_op_alibi(
66596704
(void) src1_dd;
66606705
}
66616706

6707+
inline void ggml_cuda_op_im2col(
6708+
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6709+
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
6710+
6711+
GGML_ASSERT(src0->type == GGML_TYPE_F16);
6712+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
6713+
GGML_ASSERT( dst->type == GGML_TYPE_F16);
6714+
6715+
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
6716+
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
6717+
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
6718+
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
6719+
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
6720+
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
6721+
6722+
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
6723+
6724+
const int64_t N = src1->ne[is_2D ? 3 : 2];
6725+
const int64_t IC = src1->ne[is_2D ? 2 : 1];
6726+
const int64_t IH = is_2D ? src1->ne[1] : 1;
6727+
const int64_t IW = src1->ne[0];
6728+
6729+
const int64_t KH = is_2D ? src0->ne[1] : 1;
6730+
const int64_t KW = src0->ne[0];
6731+
6732+
const int64_t OH = is_2D ? dst->ne[2] : 1;
6733+
const int64_t OW = dst->ne[1];
6734+
6735+
const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
6736+
const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
6737+
6738+
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
6739+
OH, IW, IH, OW, IC, KH, KW, N,
6740+
ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
6741+
6742+
(void) src0;
6743+
(void) src0_dd;
6744+
}
6745+
66626746
inline void ggml_cuda_op_diag_mask_inf(
66636747
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
66646748
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7549,6 +7633,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
75497633
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
75507634
ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
75517635
ne10, ne11, nb10, nb11, nb12, main_stream);
7636+
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
7637+
ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
7638+
ne10, ne11, nb10, nb11, nb12, main_stream);
75527639
} else {
75537640
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
75547641
ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7580,6 +7667,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
75807667
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
75817668
}
75827669

7670+
void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7671+
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
7672+
}
7673+
75837674
static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
75847675
(void) src0;
75857676
(void) src1;
@@ -7943,6 +8034,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
79438034
case GGML_OP_ALIBI:
79448035
func = ggml_cuda_alibi;
79458036
break;
8037+
case GGML_OP_IM2COL:
8038+
func = ggml_cuda_im2col;
8039+
break;
79468040
default:
79478041
return false;
79488042
}

ggml-metal.m

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
GGML_METAL_DECL_KERNEL(rms_norm);
8787
GGML_METAL_DECL_KERNEL(norm);
8888
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
89+
GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
8990
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
9091
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
9192
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
@@ -114,6 +115,7 @@
114115
GGML_METAL_DECL_KERNEL(rope_f32);
115116
GGML_METAL_DECL_KERNEL(rope_f16);
116117
GGML_METAL_DECL_KERNEL(alibi_f32);
118+
GGML_METAL_DECL_KERNEL(im2col_f16);
117119
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
118120
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
119121
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
@@ -287,6 +289,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
287289
GGML_METAL_ADD_KERNEL(rms_norm);
288290
GGML_METAL_ADD_KERNEL(norm);
289291
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
292+
GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
290293
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
291294
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
292295
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
@@ -317,6 +320,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
317320
GGML_METAL_ADD_KERNEL(rope_f32);
318321
GGML_METAL_ADD_KERNEL(rope_f16);
319322
GGML_METAL_ADD_KERNEL(alibi_f32);
323+
GGML_METAL_ADD_KERNEL(im2col_f16);
320324
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
321325
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
322326
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
@@ -386,6 +390,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
386390
GGML_METAL_DEL_KERNEL(rms_norm);
387391
GGML_METAL_DEL_KERNEL(norm);
388392
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
393+
GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
389394
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
390395
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
391396
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
@@ -416,6 +421,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
416421
GGML_METAL_DEL_KERNEL(rope_f32);
417422
GGML_METAL_DEL_KERNEL(rope_f16);
418423
GGML_METAL_DEL_KERNEL(alibi_f32);
424+
GGML_METAL_DEL_KERNEL(im2col_f16);
419425
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
420426
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
421427
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
@@ -1139,20 +1145,26 @@ void ggml_metal_graph_compute(
11391145
switch (src0t) {
11401146
case GGML_TYPE_F32:
11411147
{
1148+
GGML_ASSERT(src1t == GGML_TYPE_F32);
11421149
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
11431150
nrows = 4;
11441151
} break;
11451152
case GGML_TYPE_F16:
11461153
{
11471154
nth0 = 32;
11481155
nth1 = 1;
1149-
if (ne11 * ne12 < 4) {
1150-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
1151-
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1152-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
1153-
nrows = ne11;
1156+
if (src1t == GGML_TYPE_F32) {
1157+
if (ne11 * ne12 < 4) {
1158+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
1159+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1160+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
1161+
nrows = ne11;
1162+
} else {
1163+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
1164+
nrows = 4;
1165+
}
11541166
} else {
1155-
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
1167+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
11561168
nrows = 4;
11571169
}
11581170
} break;
@@ -1464,6 +1476,58 @@ void ggml_metal_graph_compute(
14641476

14651477
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
14661478
} break;
1479+
case GGML_OP_IM2COL:
1480+
{
1481+
GGML_ASSERT(src0->type == GGML_TYPE_F16);
1482+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
1483+
GGML_ASSERT( dst->type == GGML_TYPE_F16);
1484+
1485+
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
1486+
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
1487+
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
1488+
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
1489+
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
1490+
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
1491+
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
1492+
1493+
const int32_t N = src1->ne[is_2D ? 3 : 2];
1494+
const int32_t IC = src1->ne[is_2D ? 2 : 1];
1495+
const int32_t IH = is_2D ? src1->ne[1] : 1;
1496+
const int32_t IW = src1->ne[0];
1497+
1498+
const int32_t KH = is_2D ? src0->ne[1] : 1;
1499+
const int32_t KW = src0->ne[0];
1500+
1501+
const int32_t OH = is_2D ? dst->ne[2] : 1;
1502+
const int32_t OW = dst->ne[1];
1503+
1504+
const int32_t CHW = IC * KH * KW;
1505+
1506+
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
1507+
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
1508+
1509+
switch (src0->type) {
1510+
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
1511+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
1512+
default: GGML_ASSERT(false);
1513+
};
1514+
1515+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
1516+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1517+
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
1518+
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
1519+
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
1520+
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
1521+
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
1522+
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
1523+
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
1524+
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
1525+
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
1526+
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
1527+
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
1528+
1529+
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
1530+
} break;
14671531
case GGML_OP_DUP:
14681532
case GGML_OP_CPY:
14691533
case GGML_OP_CONT:

0 commit comments

Comments
 (0)