Test different cache hints
The weight data is a purely streaming workload, so try marking its load instructions with cache hints. This can potentially free up cache capacity to hold other data.
ankan-ban committed Oct 19, 2023
1 parent e0acf27 commit 6cea66c
Showing 1 changed file with 50 additions and 4 deletions.
54 changes: 50 additions & 4 deletions gpu_kernels.h
@@ -3,6 +3,52 @@
#include <cuda_runtime_api.h>
#include <cub/cub.cuh>


// utility functions to load from memory (try different cache hints for the streaming weight data)
#define USE_NO_CACHE_ALLOCATE_FOR_WEIGHT_LOADS 1    // use the PTX "L1::no_allocate" eviction hint (don't allocate an L1 line)
#define USE_LDCS_FOR_WEIGHT_LOADS 0                 // use __ldcs() (cache-streaming load, data likely accessed only once)

__forceinline__ __device__ uint4 loadFromMem(const uint4* ptr) {
    uint4 ret;
#if USE_NO_CACHE_ALLOCATE_FOR_WEIGHT_LOADS
    // 128-bit load with the L1::no_allocate hint: the data streams through without taking up an L1 line
    asm volatile("ld.global.L1::no_allocate.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
#elif USE_LDCS_FOR_WEIGHT_LOADS
    ret = __ldcs(ptr);
#else
    ret = *ptr;
#endif
    return ret;
}

__forceinline__ __device__ uint32_t loadFromMem(const uint32_t* ptr) {
    uint32_t ret;
#if USE_NO_CACHE_ALLOCATE_FOR_WEIGHT_LOADS
    asm volatile("ld.global.L1::no_allocate.u32 %0, [%1];" : "=r"(ret) : "l"(ptr));
#elif USE_LDCS_FOR_WEIGHT_LOADS
    ret = __ldcs(ptr);
#else
    ret = *ptr;
#endif
    return ret;
}

__forceinline__ __device__ half loadFromMem(const half* ptr) {
    half ret;
#if USE_NO_CACHE_ALLOCATE_FOR_WEIGHT_LOADS
    uint16_t temp;
    asm volatile("ld.global.L1::no_allocate.u16 %0, [%1];" : "=h"(temp) : "l"(ptr));
    ret = __ushort_as_half(temp);
#elif USE_LDCS_FOR_WEIGHT_LOADS
    ret = __ldcs(ptr);
#else
    ret = *ptr;
#endif
    return ret;
}




// ----------------------------------------------------------------------------
// GPU kernels

@@ -76,7 +122,7 @@ __global__ void mat_vec_kernel(half* op, const half* ip, const half* wt, int n,
    if (j < n) {
        half w[8];
        half ip[8];
-       *((uint4*)(&w)) = *((uint4*)(&weight[index * w_row_stride + j]));
+       *((uint4*)(&w)) = loadFromMem((uint4*)(&weight[index * w_row_stride + j]));
        *((uint4*)(&ip)) = *((uint4*)(&input[j]));
        for (int el = 0; el < 8; el++)
            sum += float(w[el]) * float(ip[el]);
@@ -132,15 +178,15 @@ __global__ void mat_vec_kernel_int4(half* __restrict__ output, const half* __res

    float sum = 0;
    for (int ygq = 0; ygq * 128 + threadIdx.x * 4 < packed_weights_height; ygq++) {   // each iteration of this loop covers 8 x 128 elements in y dimension of weight matrix (weight matrix is column major)
-       uint32_t packed_q_z = q_zeros[index * packed_zeros_height + ygq];
+       uint32_t packed_q_z = loadFromMem(&q_zeros[index * packed_zeros_height + ygq]);

        // load weights in one go (32 elements from weight matrix loaded by each thread in one read)
        uint32_t loaded_packed_wts[4];
-       *((uint4*)(&loaded_packed_wts[0])) = *((uint4*)(&q_weight[index * packed_weights_height + ygq * 128 + threadIdx.x * 4]));
+       *((uint4*)(&loaded_packed_wts[0])) = loadFromMem((uint4*)(&q_weight[index * packed_weights_height + ygq * 128 + threadIdx.x * 4]));

        int group_y = ygq * 8 + (threadIdx.x / 4);
        float q_z = (float)(packed_q_z >> (4 * (threadIdx.x / 4)) & 0xF);
-       float scale = (float)scales[index * scales_height + group_y];
+       float scale = (float)loadFromMem(&scales[index * scales_height + group_y]);
        int y_base = ygq * 1024 + threadIdx.x * 32;

        for (int qi = 0; qi < 4; qi++) {   // each iteration of this loop covers 256 elements in y dimension of weight matrix
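
For reference, a minimal sketch (not part of this commit) of how such streaming loads are typically used: the weight matrix is touched exactly once per matvec, so it is loaded with a streaming hint, while the input vector, which every output row reuses, is loaded normally so it can stay resident in cache. The kernel name, launch shape, and parameter names below are illustrative assumptions, not code from this repository; the hint shown is __ldcs, the second of the two options above.

#include <cuda_fp16.h>

// Hypothetical example kernel: dot product of each weight row with a shared input vector.
// Weights are streamed with __ldcs (read once, don't pollute the cache);
// the input vector is loaded normally since it is reused by every row.
__global__ void dot_rows_streaming(half* out, const half* __restrict__ weights,
                                   const half* __restrict__ input, int n_cols) {
    int row = blockIdx.x;                               // one warp-sized block per output row
    float sum = 0.0f;
    for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
        half w = __ldcs(&weights[row * n_cols + col]);  // streaming load of weight element
        half x = input[col];                            // regular (cacheable) load of reused input
        sum += __half2float(w) * __half2float(x);
    }
    // warp-level reduction (assumes blockDim.x == 32)
    for (int offset = 16; offset > 0; offset /= 2)
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    if (threadIdx.x == 0)
        out[row] = __float2half(sum);
}

// Example launch: one 32-thread block per row of the weight matrix.
// dot_rows_streaming<<<n_rows, 32>>>(out, weights, input, n_cols);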
