From 6bef412a46aa115574e0e5feeca890c20b59f0c8 Mon Sep 17 00:00:00 2001
From: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com>
Date: Mon, 26 Aug 2024 19:39:29 +0530
Subject: [PATCH] CUDA source cleanup, refactor and fixes (#1328)

* remove kcompress
* fix initial template call
* fix function name
* remove vector load
* cleanup reduce & rearrange
* format
---
 csrc/kernels.cu  | 201 ++++++++---------------------------------------
 csrc/kernels.cuh |   1 -
 csrc/ops.cu      |  41 +++++-----
 3 files changed, 50 insertions(+), 193 deletions(-)

diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index e4d459961..0f8ec4b7e 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -20,6 +20,7 @@
 #define NUM 4
 #define NUM_BLOCK 4096
 
+__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
 // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
 __device__ float atomicMax(float* address, float val) {
@@ -462,50 +463,6 @@ __device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadran
     }
 }
 
-template <int SIGNED>
-__device__ __forceinline__ unsigned char quantize_quadrant(int QUADRANT, float *__restrict__ const smem_code, float x, float lower, float midpoint, float upper)
-{
-  int lower_pivot = QUADRANT*16-1 - 0;
-  int pivot = QUADRANT*16-1 + 16;
-  int upper_pivot = QUADRANT*16-1 + 31;
-
-  float val = midpoint;
-
-  // i>>=1 = {32, 16, 8, 4, 2, 1}
-  for(int i = 16; i > 0; i>>=1)
-  {
-    if(x > val)
-    {
-      lower_pivot = pivot;
-      lower = val;
-      pivot+=i;
-    }
-    else
-    {
-      upper_pivot = pivot;
-      upper = val;
-      pivot-=i;
-    }
-    val = smem_code[pivot];
-  }
-
-  if(x > val)
-  {
-    midpoint = (upper+val)*0.5f;
-    if(x > midpoint)
-      return upper_pivot;
-    else
-      return pivot;
-  }
-  else
-  {
-    midpoint = (lower+val)*0.5f;
-    if(x < midpoint)
-      return lower_pivot;
-    else
-      return pivot;
-  }
-}
 
 __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n)
 {
   const int tid = threadIdx.x + (blockDim.x*blockIdx.x);
   const int numThreads = blockDim.x*gridDim.x;
 
   for(int i = tid; i < n; i+=numThreads)
   {
     int idx = (index1[i]*maxidx1) + index2[i];
     atomicAdd(&histogram[idx], src[i]);
   }
 }
 
-template <typename T, int BLOCK_SIZE, int NUM_MAX>
-__global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n)
-{
-  typedef cub::WarpReduce<float> WarpReduce;
-  __shared__ typename WarpReduce::TempStorage temp_storage;
-  typedef cub::BlockLoad<T, BLOCK_SIZE/8, 8, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
-  __shared__ typename LoadT::TempStorage loadt;
-
-  const int warp_idx = threadIdx.x/32;
-  const int valid_items = n - (blockIdx.x*BLOCK_SIZE) > BLOCK_SIZE ? BLOCK_SIZE : n - (blockIdx.x*BLOCK_SIZE);
-
-  //  BLOCK_SIZE/32 == number of warps
-  __shared__ int smem_max_indices[8*BLOCK_SIZE/32];
-  __shared__ float smem_max_values[8*BLOCK_SIZE/32];
-
-  T values[8];
-  T max1 = -64000.0f;
-  T max2 = -64000.0f;
-  int max_idx1 = -1;
-  int max_idx2 = -1;
-  int sign1 = -1;
-  int sign2 = -1;
-
-  // 1. load 8 values per thread
-  // 2. compute 2-max in registers (64 max per warp)
-  // 3. do warp reduction + broadcast back
-  // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
-  // 5. Repeat (3) 8 times for top 8 values in 256
-  // 6. store with byte index
-
-  LoadT(loadt).Load(&(A[(blockIdx.x*BLOCK_SIZE)]), values, valid_items, (T)0.0f);
-  #pragma unroll 8
-  for(int i = 0; i < 8; i++)
-  {
-    T absval = fabsf(values[i]);
-    if(absval > max1)
-    {
-      max1 = values[i];
-      sign1 = signbit(values[i]);
-      max_idx1 = 8*threadIdx.x + i;
-    }
-    else if(absval > max2)
-    {
-      max2 = values[i];
-      sign2 = signbit(values[i]);
-      max_idx2 = 8*threadIdx.x + i;
-    }
-  }
-
-  float warp_max;
-  for(int i = 0; i < 8; i++)
-  {
-    // 3. do warp reduction + broadcast back
-    warp_max = WarpReduce(temp_storage).Reduce(max1, cub::Max());
-    warp_max = cub::ShuffleIndex<32>(warp_max, 0, 0xffffffff);
-
-    // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
-    if(warp_max == max1)
-    {
-      smem_max_values[warp_idx*8 + i] = sign1 != 0 ? -max1 : max1;
-      smem_max_indices[warp_idx*8 + i] = max_idx1;
-
-      sign1 = sign2;
-      max1 = max2;
-      max_idx1 = max_idx2;
-
-      max2 = -64000.0f;
-    }
-    __syncwarp();
-  }
-
-  if(threadIdx.x % 32 < 8)
-  {
-    // offset: 8 values per 256 input values
-    //
-    int offset = BLOCK_SIZE*blockIdx.x*BLOCK_SIZE/32*8;
-  }
-
-}
 
 #define THREADS_ESTIMATE 512
 #define NUM_ESTIMATE 8
 #define BLOCK_ESTIMATE 4096
@@ -1560,7 +1437,8 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
           s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0];
           switch(OPTIMIZER)
           {
-              case MOMENTUM:
+              case ADAGRAD:
+              case MOMENTUM:
                   if(step == 1)
                     s1_vals[j] = (float)g_vals[j];
                   else
@@ -1663,6 +1541,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
           if(weight_decay > 0.0f) {
             switch(OPTIMIZER) {
+              case ADAGRAD:
               case MOMENTUM:
               case RMSPROP:
                 g_val += ((float)p_vals[j])*weight_decay;
@@ -1675,8 +1554,8 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
 
           s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];
 
-          switch(OPTIMIZER)
-          {
+          switch(OPTIMIZER){
+              case ADAGRAD:
               case MOMENTUM:
                   if(step == 1)
                     s1_vals[j] = g_vals[j];
@@ -3055,45 +2934,6 @@ template __global__ void kExtractOutliers(char *A, int *idx, char *
     }
 }
 
-
-//template <int QUANT_TYPE, typename INPT, typename COMPT, typename OUTT> __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB)
-//{
-//// element-wise kernel
-//// 1. Load batch x k into registers
-//// 2. Load k x k into registers
-//// 3. dequantize and store in second pair of k x k
-//// 4. matmul
-//// 5. sum with cub
-//// 6. store outputs
-//// TC kernel
-//// use k warps per thread block
-//// 1. threadblock use read-only cache to read in register tile for A into shared memory
-//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
-//// 3. each warp reads a segment of values 16x32 from B
-//// 4. do dequantization from register of B into second pair of registers
-//// 5. store (4) into fragment
-//// 6. matmul aggregate into fragment C
-//// 7. aggregate files of C into shared memory block C
-//// 8. sum (7)
-//// 9. write outputs to matmul output matrix
-//}
-
-template <typename T, int ITEMS> __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f)
-{
-  if(limit_base + ITEMS <= limit)
-    reinterpret_cast<float4*>(local)[0] = reinterpret_cast<float4*>(buffer)[idx/ITEMS];
-  else
-  {
-    for(int k = 0; k < ITEMS; k++)
-    {
-      if(limit_base + k < limit)
-        local[k] = buffer[idx+k];
-      else
-        local[k] = (T)zero_value;
-    }
-  }
-}
-
 #define WARPS 3
 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
 {
@@ -3311,13 +3151,28 @@ template __device__ void printnonzero(T *A, int num_values, const c
     printf("%s %i %f\n", strval, i, (float)A[i]);
 }
 
-template __device__ void printnonzero(float *A, int num_values, const char*strval);
-template __device__ void printnonzero(half *A, int num_values, const char*strval);
-__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
 {
+  //// element-wise kernel
+  //// 1. Load batch x k into registers
+  //// 2. Load k x k into registers
+  //// 3. dequantize and store in second pair of k x k
+  //// 4. matmul
+  //// 5. sum with cub
+  //// 6. store outputs
+  //// TC kernel
+  //// use k warps per thread block
+  //// 1. threadblock use read-only cache to read in register tile for A into shared memory
+  //// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
+  //// 3. each warp reads a segment of values 16x32 from B
+  //// 4. do dequantization from register of B into second pair of registers
+  //// 5. store (4) into fragment
+  //// 6. matmul aggregate into fragment C
+  //// 7. aggregate files of C into shared memory block C
+  //// 8. sum (7)
+  //// 9. write outputs to matmul output matrix
 #if __CUDA_ARCH__ >= 750
     using namespace nvcuda;
     int col_offset = blockIdx.x *32;
@@ -3911,6 +3766,8 @@ MAKE_PreconditionStatic8bit1State(RMSPROP, half)
 MAKE_PreconditionStatic8bit1State(RMSPROP, float)
 MAKE_PreconditionStatic8bit1State(LION, half)
 MAKE_PreconditionStatic8bit1State(LION, float)
+MAKE_PreconditionStatic8bit1State(ADAGRAD, half)
+MAKE_PreconditionStatic8bit1State(ADAGRAD, float)
 
 #define MAKE_optimizerStatic8bit1State(oname, gtype) \
 template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
@@ -3930,6 +3787,9 @@ MAKE_optimizerStatic8bit1State(RMSPROP, half)
 MAKE_optimizerStatic8bit1State(RMSPROP, float)
 MAKE_optimizerStatic8bit1State(LION, half)
 MAKE_optimizerStatic8bit1State(LION, float)
+MAKE_optimizerStatic8bit1State(ADAGRAD, half)
+MAKE_optimizerStatic8bit1State(ADAGRAD, float)
+
 
 #define MAKE_PreconditionStatic8bit2State(oname, gtype) \
 template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \
@@ -4075,3 +3935,6 @@ MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
+
+template __device__ void printnonzero(float *A, int num_values, const char*strval);
+template __device__ void printnonzero(half *A, int num_values, const char*strval);
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index a7fe3d700..15f31cbed 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -9,7 +9,6 @@
 #ifndef kernels
 #define kernels
 
-//template <int QUANT_TYPE, typename INP_TYPE, typename COMP_TYPE, typename OUT_TYPE> __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB);
 
 template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n);
 
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 68ee919f0..ade3b13d1 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -91,13 +91,6 @@ template void dequantizeBlockwise(float *code, unsign
 }
 
-
-//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB)
-//{
-//  int num_blocks = (colsB+32-1)/32;
-//  kMatmul_inference_4bit<<<num_blocks, 32>>>(A, B, out, lda, ldb, rowsA, colsA, colsB);
-//  CUDA_CHECK_RETURN(cudaPeekAtLastError());
-//}
-
 template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
@@ -362,10 +355,6 @@ template int get_leading_dim(int dim1, int dim2)
     }
 }
 
-template int get_leading_dim<ROW>(int dim1, int dim2);
-template int get_leading_dim<COL>(int dim1, int dim2);
-template int get_leading_dim<COL32>(int dim1, int dim2);
-
 template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
 {
 #ifdef NO_CUBLASLT
@@ -411,15 +400,6 @@ template void trans
 #endif
 }
 
-template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int8_t, ROW, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int8_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int32_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
-template void transform<int8_t, ROW, COL_TURING, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int32_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
-
 template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
 {
 #ifdef NO_CUBLASLT
@@ -693,9 +673,9 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out
   //cout << m << endl;
   //cout << n << endl;
   //cout << k << endl;
-  //if(bits == 32)
+  if(bits == 32)
     //gemm_device<T, 32, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
-    //gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+    gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
   if(bits == 16)
     //gemm_device<T, 16, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
     gemm_device<T, 16, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
@@ -841,6 +821,9 @@ MAKE_optimizerStatic8bit(RMSPROP, half)
 MAKE_optimizerStatic8bit(RMSPROP, float)
 MAKE_optimizerStatic8bit(LION, half)
 MAKE_optimizerStatic8bit(LION, float)
+MAKE_optimizerStatic8bit(ADAGRAD, half)
+MAKE_optimizerStatic8bit(ADAGRAD, float)
+
 
 #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
 template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
@@ -849,6 +832,7 @@
 MAKE_optimizerStatic8bitBlockwise(half, ADAM);
 MAKE_optimizerStatic8bitBlockwise(float, ADAM);
+MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM);
 MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
 MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
 MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
@@ -862,4 +846,15 @@ MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
 template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
 template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
 
-MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM);
+template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int8_t, ROW, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int8_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int32_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
+template void transform<int8_t, ROW, COL_TURING, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int32_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
+
+template int get_leading_dim<ROW>(int dim1, int dim2);
+template int get_leading_dim<COL>(int dim1, int dim2);
+template int get_leading_dim<COL32>(int dim1, int dim2);