Add more guards on compute capability in Marlin kernel (AutoGPTQ#550)
add more guards
fxmarty authored Feb 16, 2024
1 parent c1a1ef3 commit b05e059
Showing 1 changed file with 12 additions and 0 deletions.
autogptq_extension/marlin/marlin_cuda_kernel.cu (12 additions, 0 deletions)
@@ -126,12 +126,16 @@ __device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag
 
 // Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout.
 __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile(
     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
     : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)
   );
+#else
+  assert(0);
+#endif
 }
 
 // Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to
@@ -181,6 +185,7 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
 
 // Wait until barrier reaches `count`, then lock for current threadblock.
 __device__ inline void barrier_acquire(int* lock, int count) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   if (threadIdx.x == 0) {
     int state = -1;
     do
@@ -189,10 +194,14 @@ __device__ inline void barrier_acquire(int* lock, int count) {
     while (state != count);
   }
   __syncthreads();
+#else
+  assert(0);
+#endif
 }
 
 // Release barrier and increment visitation count.
 __device__ inline void barrier_release(int* lock, bool reset = false) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   __syncthreads();
   if (threadIdx.x == 0) {
     if (reset) {
@@ -204,6 +213,9 @@ __device__ inline void barrier_release(int* lock, bool reset = false) {
     asm volatile ("fence.acq_rel.gpu;\n");
     asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val));
   }
+#else
+  assert(0);
+#endif
 }
 
 
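As a standalone illustration of the pattern this commit adds (a sketch under assumptions, not code from the repository): a __device__ function wrapped in defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 compiles its real body only in device passes targeting sm_80 or newer, while older-architecture device passes (and the host pass, where __CUDA_ARCH__ is undefined) fall through to a device-side assert(0) that traps the kernel at run time. The file and function names below (guard_demo.cu, guarded_kernel, arch_gated_value) are hypothetical.

// guard_demo.cu -- illustrative sketch only; not taken from marlin_cuda_kernel.cu.
// Build for multiple architectures, e.g.:
//   nvcc -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 guard_demo.cu
#include <cassert>
#include <cstdio>

__device__ inline int arch_gated_value() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  return 800;    // sm_80+ device pass: the real work (ldmatrix, barriers, ...) would live here.
#else
  assert(0);     // pre-sm_80 device pass (or host pass): trap instead of emitting unsupported code.
  return 0;
#endif
}

__global__ void guarded_kernel(int* out) {
  out[0] = arch_gated_value();
}

int main() {
  int* d_out = nullptr;
  cudaMalloc(&d_out, sizeof(int));
  guarded_kernel<<<1, 1>>>(d_out);
  // On a pre-sm_80 GPU the device-side assert fires and the error surfaces here.
  cudaError_t err = cudaDeviceSynchronize();
  printf("kernel status: %s\n", cudaGetErrorString(err));
  cudaFree(d_out);
  return 0;
}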

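Because the guarded paths only do real work when compiled for sm_80+, callers typically also verify the GPU's compute capability on the host before dispatching to Marlin. The sketch below is illustrative and not part of this commit; the helper name device_supports_marlin is hypothetical, but the CUDA runtime calls (cudaGetDevice, cudaGetDeviceProperties) are standard.

// capability_check.cu -- hypothetical host-side guard, not part of AutoGPTQ.
#include <cstdio>
#include <cuda_runtime.h>

// Returns true if the active device reports compute capability >= 8.0 (Ampere or newer),
// matching the __CUDA_ARCH__ >= 800 guard used in the kernel above.
bool device_supports_marlin() {
  int device = 0;
  cudaGetDevice(&device);
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, device);
  return prop.major >= 8;
}

int main() {
  if (!device_supports_marlin()) {
    fprintf(stderr, "Marlin requires a GPU with compute capability >= 8.0.\n");
    return 1;
  }
  printf("GPU supports the sm_80 code path.\n");
  return 0;
}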