94 changes: 40 additions & 54 deletions ggml/src/ggml-cuda/conv2d-mm.cu
@@ -1,4 +1,5 @@
#include "conv2d-mm.cuh"
#include "convert.cuh"

#include <cuda_runtime.h>

@@ -13,6 +14,8 @@

#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

typedef uint32_t uint;

uint32_t ceil_div(uint32_t M, uint32_t N);
int get_sm_count();

@@ -51,29 +54,20 @@ __align__(16) struct Params {
uint32_t nb2;
uint32_t nb3;

uint32_t KWmp;
uint32_t KWL;
uint32_t KWKHmp;
uint32_t KWKHL;
uint32_t OWmp;
uint32_t OWL;
uint32_t OWOHmp;
uint32_t OWOHL;
uint3 KW_fastdiv;
uint3 KWKH_fastdiv;
uint3 OW_fastdiv;
uint3 OWOH_fastdiv;
};

__constant__ __device__ Params dp;

// see init_fastdiv_values in ggml-vulkan.cpp
__inline__ __device__ uint fastdiv(uint n, uint mp, uint L) {
return (__umulhi(n, mp) + n) >> L;
}

// --> conv_2d kernel modified to function as a matmul
template <uint BS_K, uint BS_NPQ, uint BS_CRS, uint TS_K, uint TS_NPQ, uint WG_SIZE, uint VEC_SIZE>
template <typename T, uint BS_K, uint BS_NPQ, uint BS_CRS, uint TS_K, uint TS_NPQ, uint WG_SIZE, uint VEC_SIZE>
__global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
uint NPQ,
uint CRS,
const float * knl_data,
const T * knl_data,
const float * src_data,
float * dst_data) {
// Each block computes a tile of the result of size BS_K*BS_NPQ
@@ -98,7 +92,8 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
const uint T_y = threadIdx.x / NT_x;
const uint T_x = threadIdx.x % NT_x;

__shared__ float Ash[BS_K * BS_CRS];
// __shared__ float Ash[BS_K * BS_CRS];
__shared__ T Ash[BS_K * BS_CRS];
__shared__ float Bsh[BS_CRS * BS_NPQ];

const uint Ar = threadIdx.x / BS_CRS;
@@ -135,10 +130,10 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
#else
uint32_t CRS_idx_a = idx_CRS + Ac; //Global CRS_idx (column index of A)
//uint32_t Cin_idx_a = CRS_idx_a / (dp.KW*dp.KH);
uint32_t Cin_idx_a = fastdiv(CRS_idx_a, dp.KWKHmp, dp.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH);
uint32_t Cin_idx_a = fastdiv(CRS_idx_a, dp.KWKH_fastdiv); // divide by (p.KW * p.KH); / (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * dp.KW * dp.KH;
//uint32_t KH_idx_a = (CRS_idx_a - Cin_idx_a*dp.KW*dp.KH) / dp.KW;
uint32_t KH_idx_a = fastdiv(CRS_remainder, dp.KWmp, dp.KWL); // divide by p.KW;
uint32_t KH_idx_a = fastdiv(CRS_remainder, dp.KW_fastdiv); // divide by p.KW;
//uint32_t KW_idx_a = CRS_idx_a - Cin_idx_a*dp.KW*dp.KH - KH_idx_a*dp.KW; // unused
#endif

@@ -148,9 +143,9 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
// General addressing (does not assume contiguity)
//const uint32_t knl_idx = KW_idx_a + KH_idx_a*dp.nb01 + Cin_idx_a*dp.nb02 + K_idx_a*dp.nb03;
// Contiguous addressing
float val = knl_data[min(CRS_idx_a + K_idx_a * dp.nb03, K * CRS - 1)];
T val = knl_data[min(CRS_idx_a + K_idx_a * dp.nb03, K * CRS - 1)];
if (CRS_idx_a >= CRS || K_idx_a >= K) {
val = 0.0;
val = (T)0.0;
}

#ifdef A_TRANS
@@ -173,10 +168,10 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
// Compute indices for N, OH, OW from NPQ_idx
const uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + Bc; /* Global NPQ index (column index of B) */
//const uint32_t N_idx = NPQ_idx / (dp.OH*dp.OW);
uint32_t N_idx = fastdiv(NPQ_idx, dp.OWOHmp, dp.OWOHL); // divide by p.OH * p.OW;
uint32_t N_idx = fastdiv(NPQ_idx, dp.OWOH_fastdiv); // divide by p.OH * p.OW;
uint32_t NPQ_remainder = NPQ_idx - N_idx * dp.OH * dp.OW;
//const uint32_t OH_idx = (NPQ_idx - N_idx*dp.OH*dp.OW) / dp.OW;
uint32_t OH_idx = fastdiv(NPQ_remainder, dp.OWmp, dp.OWL); // divide by p.OW;
uint32_t OH_idx = fastdiv(NPQ_remainder, dp.OW_fastdiv); // divide by p.OW;
const uint32_t OW_idx = NPQ_idx - N_idx * dp.OH * dp.OW - OH_idx * dp.OW;

#ifdef USE_COLLECTIVES
@@ -188,10 +183,10 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
// Compute indices KH, KW, Cin from CRS_idx
uint32_t CRS_idx_b = idx_CRS + r_offset + Br;
//uint32_t Cin_idx_b = CRS_idx_b / (dp.KW*dp.KH);
uint32_t Cin_idx_b = fastdiv(CRS_idx_b, dp.KWKHmp, dp.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH);
uint32_t Cin_idx_b = fastdiv(CRS_idx_b, dp.KWKH_fastdiv); // divide by (p.KW * p.KH); / (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * dp.KW * dp.KH;
//uint32_t KH_idx_b = (CRS_idx_b - Cin_idx_b*dp.KW*dp.KH) / dp.KW;
uint32_t KH_idx_b = fastdiv(CRS_remainder, dp.KWmp, dp.KWL); // divide by p.KW;
uint32_t KH_idx_b = fastdiv(CRS_remainder, dp.KW_fastdiv); // divide by p.KW;
uint32_t KW_idx_b = CRS_idx_b - Cin_idx_b * dp.KW * dp.KH - KH_idx_b * dp.KW;
#endif

@@ -235,9 +230,9 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
# else
uint32_t col_offset = (T_y * TS_K + T_ly);
# endif
regA[T_ly] = Ash[CRS_lidx * BS_K + col_offset];
regA[T_ly] = ggml_cuda_cast<float>(Ash[CRS_lidx * BS_K + col_offset]);
#else
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * BS_CRS + CRS_lidx];
regA[T_ly] = ggml_cuda_cast<float>(Ash[(T_y * TS_K + T_ly) * BS_CRS + CRS_lidx]);
#endif
}
for (uint32_t T_lx = 0; T_lx < TS_NPQ; ++T_lx) {
@@ -267,9 +262,9 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
const uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
const uint32_t NPQ_idx_c = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
//const uint32_t N_idx_c = NPQ_idx_c / (dp.OH*dp.OW);
const uint32_t N_idx_c = fastdiv(NPQ_idx_c, dp.OWOHmp, dp.OWOHL); // divide by p.OH * p.OW;
const uint32_t N_idx_c = fastdiv(NPQ_idx_c, dp.OWOH_fastdiv); // divide by p.OH * p.OW;
//const uint32_t OH_idx_c = (NPQ_idx_c - N_idx_c*dp.OH*dp.OW) / dp.OW;
const uint32_t OH_idx_c = fastdiv(NPQ_idx_c - N_idx_c * dp.OH * dp.OW, dp.OWmp, dp.OWL); // divide by p.OW;
const uint32_t OH_idx_c = fastdiv(NPQ_idx_c - N_idx_c * dp.OH * dp.OW, dp.OW_fastdiv); // divide by p.OW;
const uint32_t OW_idx_c = NPQ_idx_c - N_idx_c * dp.OH * dp.OW - OH_idx_c * dp.OW;
const uint32_t dst_idx = OW_idx_c + OH_idx_c * dp.nb1 + K_idx * dp.nb2 + N_idx_c * dp.nb3;
if (K_idx < K && NPQ_idx_c < NPQ) {
Expand All @@ -279,22 +274,6 @@ __global__ void __launch_bounds__(WG_SIZE, 1) mm(uint K,
}
}

// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
// and a shift:
//
// n/d = (mulhi(n, mp) + n) >> L;
static void init_fastdiv_values(uint32_t d, uint32_t & mp, uint32_t & L) {
// compute L = ceil(log2(d));
L = 0;
while (L < 32 && (uint32_t{ 1 } << L) < d) {
L++;
}

mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
}
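The removed helper above is superseded by shared uint3-based fastdiv utilities: each divisor is precomputed into a single uint3 field (the *_fastdiv members of Params) and the kernel calls fastdiv(n, dp.X_fastdiv). A minimal sketch of what those helpers are assumed to look like, reconstructed from the removed code and the Vulkan-backend comment; the real definitions are expected to live in a shared CUDA header (e.g. common.cuh) and may differ in detail:

static uint3 init_fastdiv_values(uint32_t d) {
    // L = ceil(log2(d))
    uint32_t L = 0;
    while (L < 32 && (uint32_t{ 1 } << L) < d) {
        L++;
    }
    // mp = floor(2^32 * (2^L - d) / d) + 1, see the divcnst-pldi94 paper, figure 4.1
    uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
    // packing the divisor into .z is an assumption; callers here only need .x (mp) and .y (L)
    return make_uint3(mp, L, d);
}

static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fd) {
    // n / d == (__umulhi(n, mp) + n) >> L; e.g. d = 6 gives L = 3, so 30 / 6 -> (10 + 30) >> 3 = 5
    return (__umulhi(n, fd.x) + n) >> fd.y;
}

Precomputing the uint3 once on the host keeps integer division off the hot path in the kernel, which is why Params now carries KW_fastdiv, KWKH_fastdiv, OW_fastdiv and OWOH_fastdiv instead of the separate mp/L pairs.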

constexpr int conv_shapes[][NUM_VARIANTS] = {
{ 128, 64, 32 }, // BS_K
{ 16, 32, 16 }, // BS_CRS
@@ -340,19 +319,26 @@ void ggml_cuda_op_conv_2d_variant(ggml_backend_cuda_context & ctx,
uint32_t NB_K = CEIL_DIV(p.Cout, BS_K);
uint32_t NB_NPQ = CEIL_DIV(NPQ, BS_NPQ);

cudaStream_t stream = ctx.stream();
cudaMemcpyToSymbol(dp, &p, sizeof(Params));
// cudaMemcpyToSymbolAsync(dp, &p, sizeof(Params), 0, cudaMemcpyHostToDevice, stream);

// Kernel arguments
float * src0_data = (float *) src0->data;
float * src1_data = (float *) src1->data;
float * dst_data = (float *) dst->data;

dim3 gridDim(NB_K, NB_NPQ);
dim3 blockDim(WG_SIZE);
cudaStream_t stream = ctx.stream();

mm<BS_K, BS_NPQ, BS_CRS, TS_K, TS_NPQ, WG_SIZE, VEC_SIZE>
if(src0->type == GGML_TYPE_F16) {
half *src0_data = (half *) src0->data;
mm<half, BS_K, BS_NPQ, BS_CRS, TS_K, TS_NPQ, WG_SIZE, VEC_SIZE>
<<<gridDim, blockDim, 0, stream>>>(p.Cout, NPQ, p.Cin * p.KW * p.KH, src0_data, src1_data, dst_data);
} else {
float *src0_data = (float *) src0->data;
mm<float, BS_K, BS_NPQ, BS_CRS, TS_K, TS_NPQ, WG_SIZE, VEC_SIZE>
<<<gridDim, blockDim, 0, stream>>>(p.Cout, NPQ, p.Cin * p.KW * p.KH, src0_data, src1_data, dst_data);
}

}

void ggml_cuda_op_conv2d_mm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -372,13 +358,13 @@ void ggml_cuda_op_conv2d_mm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);

GGML_TENSOR_BINARY_OP_LOCALS

GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(half));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb0 == sizeof(float));

@@ -413,10 +399,10 @@ void ggml_cuda_op_conv2d_mm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
p.nb3 = static_cast<uint32_t>(nb3 / nb0);

init_fastdiv_values(p.KW, p.KWmp, p.KWL);
init_fastdiv_values(p.KW * p.KH, p.KWKHmp, p.KWKHL);
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
init_fastdiv_values(p.OW * p.OH, p.OWOHmp, p.OWOHL);
p.KW_fastdiv = init_fastdiv_values(p.KW);
p.KWKH_fastdiv = init_fastdiv_values(p.KW * p.KH);
p.OW_fastdiv = init_fastdiv_values(p.OW);
p.OWOH_fastdiv = init_fastdiv_values(p.OW * p.OH);

GGML_ASSERT(ne03 == ne2);
GGML_ASSERT(ne02 == ne12);
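The new #include "convert.cuh" pulls in ggml_cuda_cast, used above to promote F16 kernel weights to float when the shared-memory A tile is loaded into registers. A minimal sketch of the assumed helper (the actual definition in convert.cuh handles more source/destination types, e.g. nv_bfloat16, and its exact signature may differ):

#include <cuda_fp16.h>
#include <type_traits>

template <typename dst_t, typename src_t>
static __device__ __forceinline__ dst_t ggml_cuda_cast(src_t x) {
    if constexpr (std::is_same_v<dst_t, src_t>) {
        return x;                      // float -> float path: no conversion needed
    } else {
        return static_cast<dst_t>(x);  // half -> float via __half's conversion operator
    }
}

With this, the same tiled matmul kernel template can be instantiated for either half or float weights while all accumulation stays in float, which is why only Ash (the kernel-weight tile) changes type while Bsh and the output remain float.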
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2463,7 +2463,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
break;
case GGML_OP_CONV_2D:
if (!getenv("GGML_CUDA_USE_LEGACY_CONV") &&
(dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32 &&
(dst->src[1]->type == GGML_TYPE_F32 &&
dst->type == GGML_TYPE_F32)) {
ggml_cuda_op_conv2d_mm(ctx, dst);
} else {
69 changes: 61 additions & 8 deletions tests/test-backend-ops.cpp
@@ -39,6 +39,7 @@
#include <string_view>
#include <thread>
#include <vector>
#include <map>

static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
size_t nels = ggml_nelements(tensor);
@@ -6615,14 +6616,66 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
{ 16, 3, 512, 128, 8 },
};

for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (auto act_case : cases) {
// Direct CONV_2D
test_cases.emplace_back(new test_conv_2d(
{ act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
{ act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
kernel_type, 1, 1, 0, 0, 1, 1, false));
}
// for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
// for (auto act_case : cases) {
// // Direct CONV_2D
// test_cases.emplace_back(new test_conv_2d(
// { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
// { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
// kernel_type, 1, 1, 0, 0, 1, 1, false));
// }
// }

// Stable-diffusion layers
std::map<std::string, uint32_t> idx_sd{
{ "iw", 0 },
{ "ih", 1 },
{ "kw", 2 },
{ "kh", 3 },
{ "Cout", 4 },
{ "Cin", 5 },
{ "B", 6 },
};

// Input image size
uint32_t w = 768;
uint32_t h = 1024;

// Number of filters (base)
uint32_t Cout_b = 128;
uint32_t Cin_b = 128;

std::vector<std::array<uint32_t, 7>> cases_sd = {
{ w / 8, h / 8, 3, 3, Cout_b * 4, Cin_b * 4, 1 }, // x10 (called 10 times)
{ w / 4, h / 4, 3, 3, Cout_b * 4, Cin_b * 4, 1 }, // x7
{ w / 2, h / 2, 3, 3, Cout_b * 2, Cin_b * 2, 1 }, // x5
{ w, h, 3, 3, Cout_b, Cin_b, 1 }, // x5
{ w / 8, h / 8, 1, 1, Cout_b * 4, Cin_b * 4, 1 }, // x4
{ w / 8, h / 8, 1, 1, 4, 4, 1 },
{ w / 8, h / 8, 3, 3, Cout_b * 4, 4, 1 },

{ w / 2, h / 2, 3, 3, Cout_b * 4, Cin_b * 4, 1 },
{ w / 2, h / 2, 3, 3, Cout_b * 2, Cin_b * 4, 1 },
{ w / 2, h / 2, 1, 1, Cout_b * 2, Cin_b * 4, 1 },

{ w, h, 3, 3, Cout_b, Cin_b * 2, 1 },
{ w, h, 1, 1, Cout_b, Cin_b * 2, 1 },
{ w, h, 3, 3, Cout_b * 2, Cin_b * 2, 1 },

{ w, h, 3, 3, 3, Cin_b, 1 },
};

for (auto act_case : cases_sd) {
GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1);
GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1);

uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0;
uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0;

test_cases.emplace_back(new test_conv_2d(
{ act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] },
{ act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] },
GGML_TYPE_F16, 1, 1, p0, p1, 1, 1, false));
}

test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));