@@ -5609,7 +5609,7 @@ inline void ggml_cuda_op_mul_mat_q(
5609
5609
// nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
5610
5610
const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff;
5611
5611
5612
- const int nchannels = buffers_contiguous ? 1 : ne02;
5612
+ const int64_t nchannels = buffers_contiguous ? 1 : ne02;
5613
5613
5614
5614
const int64_t padded_row_size = ne10 % MATRIX_ROW_PADDING == 0 ?
5615
5615
ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
@@ -5620,9 +5620,11 @@ inline void ggml_cuda_op_mul_mat_q(
5620
5620
quantize_row_q8_1_cuda (src1_ddf_i, src1_q8_1, ne10, ne11, padded_row_size, nchannels,
5621
5621
src1_row_stride, src1_channel_stride, cudaStream_main);
5622
5622
5623
- const int row_stride = buffers_contiguous ? ne10 / ggml_blck_size (src0->type ) : nb01 / ggml_type_size (src0->type );
5624
- const int channel_stride_x = buffers_contiguous ? ne10*ne11 / ggml_blck_size (src0->type ) : nb02 / ggml_type_size (src0->type );
5625
- const int channel_stride_y = padded_row_size*ne11 / QK8_1;
5623
+ const int64_t src0_blck_size = ggml_blck_size (src0->type );
5624
+ const int64_t ne10_whole_blck = ne10 % src0_blck_size == 0 ? ne10 : ne10 - ne10 % src0_blck_size + src0_blck_size;
5625
+ const int64_t row_stride = buffers_contiguous ? ne10_whole_blck / ggml_blck_size (src0->type ) : nb01 / ggml_type_size (src0->type );
5626
+ const int64_t channel_stride_x = buffers_contiguous ? ne10_whole_blck*ne11 / ggml_blck_size (src0->type ) : nb02 / ggml_type_size (src0->type );
5627
+ const int64_t channel_stride_y = padded_row_size*ne11 / QK8_1;
5626
5628
5627
5629
switch (src0->type ) {
5628
5630
case GGML_TYPE_Q4_0:
@@ -6221,6 +6223,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
6221
6223
if (src0_is_f32) {
6222
6224
src0_ddf[id] = (float *) ggml_cuda_pool_malloc (row_diff*ne00 * sizeof (float ), &src0_asf[id]);
6223
6225
} else {
6226
+ GGML_ASSERT (ne00 % ggml_blck_size (src0->type ) == 0 );
6224
6227
const int64_t nelements = row_diff*ne00;
6225
6228
const int64_t nelements_padded = ne00 % MATRIX_ROW_PADDING == 0 ?
6226
6229
nelements : nelements - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
0 commit comments