test-backend-ops : add performance eval mode + improve CUDA repeat and binary broadcast ops performance #636

Merged 18 commits on Dec 7, 2023
Changes from 7 commits
145 changes: 67 additions & 78 deletions src/ggml-cuda.cu
@@ -434,7 +434,6 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define WARP_SIZE 32
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses

#define CUDA_ADDMUL_BLOCK_SIZE 256
#define CUDA_GELU_BLOCK_SIZE 256
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256
@@ -501,6 +500,10 @@ static size_t g_scratch_offset = 0;

static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
}

static __device__ __forceinline__ float op_add(const float a, const float b) {
return a + b;
}
@@ -515,29 +518,36 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {

template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
int ne0,/* int ne1, int ne2, */int ne3,
int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3,
/*int s10,*/ int s11, int s12, int s13) {
const int i0 = blockDim.x*blockIdx.x + threadIdx.x;
const int i1 = blockIdx.y;
const int i2 = blockIdx.z / ne3;
const int i3 = blockIdx.z % ne3;
const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;

if (i0 >= ne0) {
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}

const int i10 = i0 % ne10;
const int i11 = i1 % ne11;
const int i12 = i2 % ne12;
const int i13 = i3 % ne13;

const size_t i_dst = i3*s3 + i2*s2 + i1*s1 + i0;
const size_t i_src0 = i_dst;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11 + i10;
const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i_src0;

const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
dst_t * dst_row = dst + i_dst;

for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
const int i10 = i0 % ne10;

dst[i_dst] = (dst_t)bin_op((float)src0[i_src0], (float)src1[i_src1]);
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
}
}
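
Note: for reference, the broadcast rule this kernel implements can be sketched on the host as a plain nested loop. This is a minimal illustration only; it assumes contiguous float tensors, and bin_bcast_ref is a hypothetical name, not part of this PR:

static void bin_bcast_ref(const float * src0, const float * src1, float * dst,
                          int ne0, int ne1, int ne2, int ne3,        // dst / src0 shape
                          int ne10, int ne11, int ne12, int ne13,    // src1 shape
                          float (*op)(float, float)) {
    for (int i3 = 0; i3 < ne3; i3++) {
    for (int i2 = 0; i2 < ne2; i2++) {
    for (int i1 = 0; i1 < ne1; i1++) {
        // src1 indices wrap around its (possibly smaller) dimensions
        const int i13 = i3 % ne13;
        const int i12 = i2 % ne12;
        const int i11 = i1 % ne11;
        for (int i0 = 0; i0 < ne0; i0++) {
            const int i10 = i0 % ne10;
            const size_t i_dst  = ((size_t)i3*ne2*ne1   + (size_t)i2*ne1   + i1 )*(size_t)ne0  + i0;
            const size_t i_src1 = ((size_t)i13*ne12*ne11 + (size_t)i12*ne11 + i11)*(size_t)ne10 + i10;
            dst[i_dst] = op(src0[i_dst], src1[i_src1]);
        }
    }}}
}

The kernel collapses the three outer loops into the grid, hoists the per-row offsets into src0_row/src1_row/dst_row, and strides over i0 so a capped number of threads per row can cover any ne0.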

static __global__ void gelu_f32(const float * x, float * dst, const int k) {
@@ -4859,11 +4869,22 @@ struct bin_bcast_cuda {
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);

const int num_blocks_x = (ne0 + CUDA_ADDMUL_BLOCK_SIZE - 1) / CUDA_ADDMUL_BLOCK_SIZE;
dim3 num_blocks(num_blocks_x, ne1, ne2*ne3);
static const int64_t block_size = 128;

dim3 block_dims;
block_dims.x = std::min<unsigned int>(ne0/2, block_size);
block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
block_dims.z = std::min<unsigned int>(ne2*ne3, block_size / block_dims.x / block_dims.y);

k_bin_bcast<bin_op><<<num_blocks, CUDA_ADDMUL_BLOCK_SIZE, 0, stream>>>(src0_dd, src1_dd, dst_dd,
ne0,/* ne1, ne2, */ne3,
dim3 block_nums(
(ne0/2 + block_dims.x - 1) / block_dims.x,
(ne1 + block_dims.y - 1) / block_dims.y,
(ne2*ne3 + block_dims.z - 1) / block_dims.z
);

k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */s1, s2, s3,
/* s10,*/ s11, s12, s13);
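
A quick worked example of the new launch shape (the sizes below are assumptions for illustration, not taken from this PR):

// ne0 = 4096, ne1 = 32, ne2*ne3 = 4, block_size = 128
// block_dims = (min(4096/2, 128), min(32, 128/128), min(4, 128/128/1)) = (128, 1, 1)
// block_nums = ((2048 + 127)/128, (32 + 0)/1, (4 + 0)/1)               = (16, 32, 4)
// The x dimension launches only ne0/2 = 2048 threads per row, so the grid-stride
// loop in k_bin_bcast (i0 += blockDim.x*gridDim.x) has each thread write 2 elements.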
@@ -6096,63 +6117,6 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
}
}

static void ggml_cuda_op_repeat(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
// guaranteed to be an integer due to the check in ggml_can_repeat
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t ne3 = dst->ne[3];

const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];

const size_t nb0 = dst->nb[0];
const size_t nb1 = dst->nb[1];
const size_t nb2 = dst->nb[2];
const size_t nb3 = dst->nb[3];

const size_t nb00 = src0->nb[0];
const size_t nb01 = src0->nb[1];
const size_t nb02 = src0->nb[2];
const size_t nb03 = src0->nb[3];

const int nr0 = (int)(ne0/ne00);
const int nr1 = (int)(ne1/ne01);
const int nr2 = (int)(ne2/ne02);
const int nr3 = (int)(ne3/ne03);

// TODO: support for transposed / permuted tensors
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb00 == sizeof(float));

// TODO: very inefficient, implement in a kernel, or fewer cudaMemcpyAsync calls for contiguous tensors
for (int i3 = 0; i3 < nr3; i3++) {
for (int k3 = 0; k3 < ne03; k3++) {
for (int i2 = 0; i2 < nr2; i2++) {
for (int k2 = 0; k2 < ne02; k2++) {
for (int i1 = 0; i1 < nr1; i1++) {
for (int k1 = 0; k1 < ne01; k1++) {
for (int i0 = 0; i0 < nr0; i0++) {
CUDA_CHECK(cudaMemcpyAsync(
(char *) dst_d + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0,
(const char *) src0_d + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01,
ne00*nb0, cudaMemcpyDeviceToDevice, stream));
}
}
}
}
}
}
}

(void) src1;
(void) src1_d;
}

static void ggml_cuda_op_get_rows(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
@@ -6215,7 +6179,16 @@ inline void ggml_cuda_op_bin_bcast(
ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
GGML_ASSERT(false);
}
}

static void ggml_cuda_op_repeat(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & main_stream) {

ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, src0, dst, nullptr, src0_d, dst_d, main_stream);

(void) src1;
(void) src1_d;
}
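
With this path, GGML_OP_REPEAT becomes a broadcast whose result depends only on the broadcast operand. A minimal sketch of the mapping (the shapes here are assumed, for illustration only):

// ggml_repeat: src0 [ne00=2, ne01=3] -> dst [ne0=4, ne1=6] becomes
//   dst[i1][i0] = op_repeat(unused, src0[i1 % 3][i0 % 2]) = src0[i1 % 3][i0 % 2]
// dst is passed as both the shape-defining operand and the output, with nullptr as
// its source data, which is why k_bin_bcast guards with "src0 ? (float)src0_row[i0] : 0.0f".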

inline void ggml_cuda_op_add(
@@ -8393,7 +8366,8 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
break;
default:
return false;
} break;
}
break;
case GGML_OP_NORM:
func = ggml_cuda_norm;
break;
@@ -8842,10 +8816,10 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
UNUSED(backend);
}

static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * tensor) {
switch (tensor->op) {
static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
switch (op->op) {
case GGML_OP_UNARY:
switch (ggml_get_unary_op(tensor)) {
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
@@ -8854,7 +8828,23 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
return false;
}
break;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
{
struct ggml_tensor * a;
struct ggml_tensor * b;
if (op->op == GGML_OP_MUL_MAT) {
a = op->src[0];
b = op->src[1];
} else {
a = op->src[2];
b = op->src[1];
}
if (a->ne[3] != b->ne[3]) {
return false;
}
return true;
} break;
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
@@ -8868,7 +8858,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_RMS_NORM:
case GGML_OP_MUL_MAT:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_CLAMP: