Add affine transform kernel with activation functions #6

Draft: wants to merge 1 commit into base: main
17 changes: 11 additions & 6 deletions optimus/ops/kernels/fp32_gemm.cuh
@@ -1,6 +1,6 @@
 #include <cmath>
 #include <iostream>
-#include "optimus/ops/kernels/smem_load.cuh"
+#include "optimus/ops/kernels/fp32_smem_load.cuh"
 #include "optimus/utils/array_utils.h"
 #include "optimus/utils/cuda_utils.h"
@@ -42,12 +42,17 @@ __global__ void __launch_bounds__(NUM_THREADS)
              chunk_idx++) {
 
         if (M % 4 == 0 && N % 4 == 0 && K % 4 == 0) {
-            Float4VectorizedSMeMLoad<T, M_chunk_size, N_chunk_size,
-                                     K_chunk_size>(A, B, A_chunk, B_chunk,
-                                                   chunk_idx, M, N, K);
+            FP32Float4VectorizedSMeMLoad<T, M_chunk_size, N_chunk_size,
+                                         K_chunk_size>(A, B, A_chunk, B_chunk,
+                                                       chunk_idx, M, N, K);
+        } else if (M % 2 == 0 && N % 2 == 0 && K % 2 == 0) {
+            FP32Float2VectorizedSMeMLoad<T, M_chunk_size, N_chunk_size,
+                                         K_chunk_size>(A, B, A_chunk, B_chunk,
+                                                       chunk_idx, M, N, K);
         } else {
-            NonVectorizedSMeMLoad<T, M_chunk_size, N_chunk_size, K_chunk_size>(
-                A, B, A_chunk, B_chunk, chunk_idx, M, N, K);
+            FP32NonVectorizedSMeMLoad<T, M_chunk_size, N_chunk_size,
+                                      K_chunk_size>(A, B, A_chunk, B_chunk,
+                                                    chunk_idx, M, N, K);
         }
 
         __syncthreads();
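Reviewer note: the new dispatch picks the widest vectorized shared-memory load whose vector width divides all three matrix dimensions, falling back to scalar loads otherwise. A minimal sketch of that selection rule (illustrative only; pick_vector_width is a hypothetical helper, not part of this PR):

// Hypothetical helper mirroring the branch in fp32_gemm.cuh: float4 (128-bit)
// loads need every dimension divisible by 4, float2 (64-bit) loads need
// divisibility by 2, and everything else takes the scalar path.
__host__ __device__ inline int pick_vector_width(int M, int N, int K) {
    if (M % 4 == 0 && N % 4 == 0 && K % 4 == 0) return 4;  // FP32Float4VectorizedSMeMLoad
    if (M % 2 == 0 && N % 2 == 0 && K % 2 == 0) return 2;  // FP32Float2VectorizedSMeMLoad
    return 1;                                              // FP32NonVectorizedSMeMLoad
}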
optimus/ops/kernels/fp32_smem_load.cuh
@@ -4,12 +4,12 @@ namespace optimus {
 namespace ops {
 template <typename T, const int M_chunk_size, const int N_chunk_size,
           const int K_chunk_size>
-__device__ void NonVectorizedSMeMLoad(const T *__restrict__ A,
-                                      const T *__restrict__ B,
-                                      T A_chunk[][M_chunk_size + 1],
-                                      T B_chunk[][K_chunk_size + 1],
-                                      const int chunk_idx, const int M,
-                                      const int N, const int K) {
+__device__ void FP32NonVectorizedSMeMLoad(const T *__restrict__ A,
+                                          const T *__restrict__ B,
+                                          T A_chunk[][M_chunk_size + 1],
+                                          T B_chunk[][K_chunk_size + 1],
+                                          const int chunk_idx, const int M,
+                                          const int N, const int K) {
     const int A_thread_cols = N_chunk_size;
     const int A_thread_rows = blockDim.x / A_thread_cols;
     const int thread_row_in_A = threadIdx.x / A_thread_cols;
@@ -51,12 +51,59 @@ __device__ void NonVectorizedSMeMLoad(const T *__restrict__ A,

 template <typename T, const int M_chunk_size, const int N_chunk_size,
           const int K_chunk_size>
-__device__ void Float4VectorizedSMeMLoad(const T *__restrict__ A,
-                                         const T *__restrict__ B,
-                                         T A_chunk[][M_chunk_size + 1],
-                                         T B_chunk[][K_chunk_size + 1],
-                                         const int chunk_idx, const int M,
-                                         const int N, const int K) {
+__device__ void FP32Float2VectorizedSMeMLoad(const T *__restrict__ A,
+                                             const T *__restrict__ B,
+                                             T A_chunk[][M_chunk_size + 1],
+                                             T B_chunk[][K_chunk_size + 1],
+                                             const int chunk_idx, const int M,
+                                             const int N, const int K) {
+    const int A_thread_cols = N_chunk_size / 2;
+    const int A_thread_rows = blockDim.x / A_thread_cols;
+    const int thread_row_in_A = threadIdx.x / A_thread_cols;
+    const int thread_col_in_A = threadIdx.x % A_thread_cols;
+
+    const int B_thread_cols = K_chunk_size / 2;
+    const int B_thread_rows = blockDim.x / B_thread_cols;
+    const int thread_row_in_B = threadIdx.x / B_thread_cols;
+    const int thread_col_in_B = threadIdx.x % B_thread_cols;
+
+    for (int A_row_offset = 0; A_row_offset < M_chunk_size;
+         A_row_offset += A_thread_rows) {
+        const int current_thread_row_in_A = thread_row_in_A + A_row_offset;
+        const int current_A_row =
+            (blockIdx.x * M_chunk_size) + current_thread_row_in_A;
+        const int current_A_col =
+            chunk_idx * N_chunk_size + thread_col_in_A * 2;
+        const int A_index = current_A_row * N + current_A_col;
+
+        const float2 A_load = reinterpret_cast<const float2 *>(&A[A_index])[0];
+        A_chunk[thread_col_in_A * 2 + 0][current_thread_row_in_A] = A_load.x;
+        A_chunk[thread_col_in_A * 2 + 1][current_thread_row_in_A] = A_load.y;
+    }
+
+    for (int B_row_offset = 0; B_row_offset < N_chunk_size;
+         B_row_offset += B_thread_rows) {
+        const int current_thread_row_in_B = thread_row_in_B + B_row_offset;
+        const int current_B_row =
+            chunk_idx * N_chunk_size + current_thread_row_in_B;
+        const int current_B_col =
+            (blockIdx.y * K_chunk_size) + thread_col_in_B * 2;
+        const int B_index = current_B_row * K + current_B_col;
+
+        const float2 B_load = reinterpret_cast<const float2 *>(&B[B_index])[0];
+        B_chunk[current_thread_row_in_B][thread_col_in_B * 2 + 0] = B_load.x;
+        B_chunk[current_thread_row_in_B][thread_col_in_B * 2 + 1] = B_load.y;
+    }
+}
+
+template <typename T, const int M_chunk_size, const int N_chunk_size,
+          const int K_chunk_size>
+__device__ void FP32Float4VectorizedSMeMLoad(const T *__restrict__ A,
+                                             const T *__restrict__ B,
+                                             T A_chunk[][M_chunk_size + 1],
+                                             T B_chunk[][K_chunk_size + 1],
+                                             const int chunk_idx, const int M,
+                                             const int N, const int K) {
     const int A_thread_cols = N_chunk_size / 4;
     const int A_thread_rows = blockDim.x / A_thread_cols;
     const int thread_row_in_A = threadIdx.x / A_thread_cols;
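A note on the new float2 path for reviewers: each thread fetches two consecutive floats with a single 64-bit transaction via reinterpret_cast<const float2 *>, which is only valid when the source address is 8-byte aligned; the M % 2 == 0 / N % 2 == 0 / K % 2 == 0 guard in fp32_gemm.cuh, together with the alignment cudaMalloc guarantees for the base pointers, is what makes that hold. The A tile is also written transposed, A_chunk[col][row], into a buffer padded to M_chunk_size + 1, the usual trick for avoiding shared-memory bank conflicts on column-wise reads; FP32Float4VectorizedSMeMLoad (whose body is unchanged context and mostly hidden above) is the 4-wide analogue with a 16-byte alignment requirement. A self-contained sketch of the pattern, with hypothetical tile sizes and a hypothetical kernel name:

// Illustrative CUDA sketch, not code from this PR: a float2-vectorized load
// from global memory into a padded, transposed shared-memory tile, the same
// pattern FP32Float2VectorizedSMeMLoad uses. Assumes blockDim.x == 512 and
// an even leading dimension N, so every load address is 8-byte aligned.
__global__ void padded_transposed_load(const float *__restrict__ A, int N) {
    __shared__ float tile[32][32 + 1];  // +1 pad breaks the 32-bank stride

    const int row = threadIdx.x / 16;        // 16 threads cover one 32-float row
    const int col = (threadIdx.x % 16) * 2;  // each thread loads 2 floats

    // One 64-bit transaction instead of two 32-bit ones.
    const float2 v = reinterpret_cast<const float2 *>(&A[row * N + col])[0];
    tile[col + 0][row] = v.x;  // stored transposed, like the PR's A tile
    tile[col + 1][row] = v.y;
    __syncthreads();
}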