Skip to content

Commit 3b099bc

Browse files
CUDA: fix MMQ writeback for int8 tensor cores (#8100)
1 parent a818f30 commit 3b099bc

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

ggml-cuda/mmq.cuh

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2054,15 +2054,13 @@ static __device__ __forceinline__ void mmq_write_back_mma(
20542054
static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
20552055
#endif // INT8_MMA_AVAILABLE
20562056

2057-
dst += (threadIdx.y % ntx) * mma_C::J*stride;
2058-
20592057
#pragma unroll
20602058
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
20612059
#pragma unroll
20622060
for (int n = 0; n < ntx; ++n) {
20632061
#pragma unroll
20642062
for (int l = 0; l < mma_C::ne; ++l) {
2065-
const int j = j0 + mma_C::get_j(l);
2063+
const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
20662064

20672065
if (j > j_max) {
20682066
continue;

0 commit comments

Comments
 (0)