```cpp
// Define TILE_SIZE for shared memory tiling
#define TILE_SIZE 16

// CUDA kernel for optimized low-precision GEMM operation using shared memory and tiling
__global__ void gemm_lowbit_kernel(
    const half* __restrict__ A,
    const half* __restrict__ B,
    half* __restrict__ C,
    int M, int N, int K)
{
    // Shared memory for tiles of A and B
    __shared__ half As[TILE_SIZE][TILE_SIZE];
    __shared__ half Bs[TILE_SIZE][TILE_SIZE];

    // Calculate row and column indices of the C element to work on
    int row = blockIdx.y * TILE_SIZE + threadIdx.y; // Row index of C to compute
    int col = blockIdx.x * TILE_SIZE + threadIdx.x; // Column index of C to compute

    // Initialize the accumulator to zero
    float sum = 0.0f;

    // Loop over tiles
    for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; ++t) {
        // Load elements of A and B into shared memory if within bounds
        if (row < M && (t * TILE_SIZE + threadIdx.x) < K)
            As[threadIdx.y][threadIdx.x] = A[row * K + t * TILE_SIZE + threadIdx.x];
        else
            As[threadIdx.y][threadIdx.x] = __float2half(0.0f);

        if (col < N && (t * TILE_SIZE + threadIdx.y) < K)
            Bs[threadIdx.y][threadIdx.x] = B[(t * TILE_SIZE + threadIdx.y) * N + col];
        else
            Bs[threadIdx.y][threadIdx.x] = __float2half(0.0f);

        __syncthreads(); // Synchronize to ensure data is loaded

        // Compute partial dot product for this tile
        #pragma unroll
        for (int k = 0; k < TILE_SIZE; ++k) {
            half a_element = As[threadIdx.y][k];
            half b_element = Bs[k][threadIdx.x];
            sum += __half2float(__hmul(a_element, b_element));
        }

        __syncthreads(); // Synchronize before loading the next tile
    }

    // Write the result to the output matrix if within bounds
    if (row < M && col < N)
        C[row * N + col] = __float2half(sum);
}

// Wrapper function to call the CUDA kernel
void gemm_lowbit_cuda(at::Tensor a, at::Tensor b, at::Tensor c, int M, int N, int K) {
    // Ensure that input tensors are contiguous and on the correct device
    a = a.contiguous();
    b = b.contiguous();
    c = c.contiguous();

    // Define block and grid dimensions
    dim3 threads(TILE_SIZE, TILE_SIZE);
    dim3 blocks((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);

    // Get the CUDA stream from PyTorch
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // Launch the optimized kernel
    gemm_lowbit_kernel<<<blocks, threads, 0, stream>>>(
        reinterpret_cast<const half*>(a.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(b.data_ptr<at::Half>()),
        reinterpret_cast<half*>(c.data_ptr<at::Half>()),
        M, N, K);

    // Check for kernel launch errors
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        printf("CUDA kernel launch failed: %s\n", cudaGetErrorString(err));
    }
}
```
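Not part of the issue itself, but for context: a wrapper like this is normally exposed to Python through a PyTorch C++ extension. A minimal sketch of such a binding is below; the exported name `gemm_lowbit` and the use of `TORCH_EXTENSION_NAME` (i.e. a `torch.utils.cpp_extension` build) are assumptions for illustration, not taken from the repository.

```cpp
#include <torch/extension.h>

// Declaration of the wrapper defined in the snippet above.
void gemm_lowbit_cuda(at::Tensor a, at::Tensor b, at::Tensor c, int M, int N, int K);

// Hypothetical binding: exposes the wrapper to Python as `gemm_lowbit`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("gemm_lowbit", &gemm_lowbit_cuda,
          "Low-precision GEMM (fp16 inputs/outputs, fp32 accumulation, CUDA)");
}
```

From Python, the function would then be called with contiguous float16 CUDA tensors of shapes (M, K), (K, N), and a preallocated (M, N) output.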
Just a few suggested CUDA upgrades:
```cpp
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

// Define TILE_SIZE for shared memory tiling
#define TILE_SIZE 16

// CUDA kernel for optimized low-precision GEMM operation using shared memory and tiling
__global__ void gemm_lowbit_kernel(
    const half* __restrict__ A,
    const half* __restrict__ B,
    half* __restrict__ C,
    int M, int N, int K)
{
    // Shared memory for tiles of A and B
    __shared__ half As[TILE_SIZE][TILE_SIZE];
    __shared__ half Bs[TILE_SIZE][TILE_SIZE];

    // ... (kernel body as in the original post)
}

// Wrapper function to call the CUDA kernel
void gemm_lowbit_cuda(at::Tensor a, at::Tensor b, at::Tensor c, int M, int N, int K) {
    // Ensure that input tensors are contiguous and on the correct device
    a = a.contiguous();
    b = b.contiguous();
    c = c.contiguous();

    // ... (remainder of the wrapper as in the original post)
}
```
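One further upgrade worth considering (a sketch only, not code from the issue; the helper name `gemm_lowbit_cuda_checked` is made up for illustration): validate device, dtype, and shapes before launching, so a bad call fails with a clear Python exception instead of reading out of bounds on the GPU.

```cpp
#include <torch/extension.h>

// Declaration of the wrapper defined earlier in the thread.
void gemm_lowbit_cuda(at::Tensor a, at::Tensor b, at::Tensor c, int M, int N, int K);

// Hypothetical hardening wrapper: checks inputs, then delegates to gemm_lowbit_cuda.
void gemm_lowbit_cuda_checked(at::Tensor a, at::Tensor b, at::Tensor c, int M, int N, int K) {
    TORCH_CHECK(a.is_cuda() && b.is_cuda() && c.is_cuda(), "a, b, and c must be CUDA tensors");
    TORCH_CHECK(a.scalar_type() == at::kHalf && b.scalar_type() == at::kHalf &&
                c.scalar_type() == at::kHalf, "a, b, and c must be float16");
    TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2, "a, b, and c must be 2-D");
    TORCH_CHECK(a.size(0) == M && a.size(1) == K, "a must have shape (M, K)");
    TORCH_CHECK(b.size(0) == K && b.size(1) == N, "b must have shape (K, N)");
    TORCH_CHECK(c.size(0) == M && c.size(1) == N, "c must have shape (M, N)");

    gemm_lowbit_cuda(a, b, c, M, N, K);
}
```

Inside `gemm_lowbit_cuda` itself, the `printf`-based error check could likewise be replaced with PyTorch's `C10_CUDA_KERNEL_LAUNCH_CHECK()` so a failed launch surfaces as an exception on the Python side.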