Skip to content

Commit ac1e322

Browse files
author
Iwan Kawrakow
committed
This seems to fix it.
1 parent 2c18ef1 commit ac1e322

File tree

2 files changed

+22
-16
lines changed

2 files changed

+22
-16
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2244,7 +2244,8 @@ static inline void prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n
22442244

22452245
for (int i = 0; i < (int)n_as; ++i) cum_moe_counts[i] -= moe_counts[i];
22462246

2247-
CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(), cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream));
2247+
CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(),
2248+
cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream));
22482249
CUDA_CHECK(cudaStreamSynchronize(stream));
22492250

22502251
}
@@ -2254,6 +2255,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
22542255
const ggml_tensor * src1 = dst->src[1];
22552256
const ggml_tensor * ids = dst->src[2];
22562257

2258+
CUDA_CHECK(cudaMemset((char *)dst->data, 0, ggml_nbytes(dst)));
2259+
22572260
if (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 &&
22582261
ggml_is_quantized(src0->type) &&
22592262
ggml_backend_buffer_is_cuda(src0->buffer) &&
@@ -2519,13 +2522,16 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
25192522
auto local_src0 = *next->src[0];
25202523
local_src0.ne[2] = local_src0.ne[3] = 1;
25212524

2525+
CUDA_CHECK(cudaMemset(next->data, 0, ggml_nbytes(next)));
2526+
25222527
ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, ids, &local_next,
25232528
(const char *)next->src[0]->data, nullptr, dst_quantized.get(), (float *)next->data,
25242529
0, next->src[0]->ne[1], 1, dst_padded_col_size, stream);
25252530
CUDA_CHECK(cudaGetLastError());
25262531

25272532
return true;
25282533
} else {
2534+
CUDA_CHECK(cudaMemset(dst->data, 0, ggml_nbytes(dst)));
25292535
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst),
25302536
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
25312537
CUDA_CHECK(cudaGetLastError());
@@ -2534,7 +2540,6 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
25342540
}
25352541
}
25362542

2537-
25382543
GGML_TENSOR_BINARY_OP_LOCALS
25392544

25402545
GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_1->buffer) && "mul_mat_id does not support split buffers");

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -150,20 +150,21 @@ static __global__ void mul_mat_vec_q(
150150
char * cdst = (char *)dst + i2*nb2;
151151
int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2;
152152
if (i02 < 0) {
153-
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
154-
constexpr int rows_per_cuda_block = 1;
155-
#else
156-
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
157-
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
158-
const int row0 = rows_per_cuda_block*blockIdx.x;
159-
if (threadIdx.y == 0) {
160-
dst = (float *)cdst;
161-
for (int j = 0; j < ncols_y; ++j) {
162-
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
163-
dst[j*nrows_dst + row0 + threadIdx.x] = 0;
164-
}
165-
}
166-
}
153+
// We clear the buffer via cudaMemset instead
154+
//#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
155+
// constexpr int rows_per_cuda_block = 1;
156+
//#else
157+
// constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
158+
//#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
159+
// const int row0 = rows_per_cuda_block*blockIdx.x;
160+
// if (threadIdx.y == 0) {
161+
// dst = (float *)cdst;
162+
// for (int j = 0; j < ncols_y; ++j) {
163+
// if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
164+
// dst[j*nrows_dst + row0 + threadIdx.x] = 0;
165+
// }
166+
// }
167+
// }
167168
return;
168169
}
169170
const char * cx = (const char *)vx + i02*nb02;

0 commit comments

Comments
 (0)