Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cuda source cleanup , refactor and fixes #1328

Merged
Merged
Prev Previous commit
Next Next commit
remove vector load
  • Loading branch information
abhilash1910 committed Aug 21, 2024
commit d529e451e0326ce8c8e36c751d4d236000b29f0c
15 changes: 0 additions & 15 deletions csrc/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3001,21 +3001,6 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
//// 9. write outputs to matmul output matrix
//}

template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f)
{
if(limit_base + ITEMS <= limit)
reinterpret_cast<TCAST*>(local)[0] = reinterpret_cast<TCAST*>(buffer)[idx/ITEMS];
else
{
for(int k = 0; k < ITEMS; k++)
{
if(limit_base + k < limit)
local[k] = buffer[idx+k];
else
local[k] = (T)zero_value;
}
}
}

#define WARPS 3
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
Expand Down
Loading