ggml : group all experts in a single ggml_mul_mat_id #6505

Merged: 22 commits, Apr 18, 2024
Changes from 1 commit
216 changes: 175 additions & 41 deletions ggml-cuda.cu
@@ -1231,7 +1231,7 @@ static void ggml_cuda_op_mul_mat_cublas(

if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool());
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
if (src0->type != GGML_TYPE_F16) {
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1241,7 +1241,7 @@ static void ggml_cuda_op_mul_mat_cublas(
}
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();

ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool());
ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
if (src1->type != GGML_TYPE_F16) {
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1250,7 +1250,7 @@ static void ggml_cuda_op_mul_mat_cublas(
to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
}
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(), row_diff*src1_ncols);
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);

const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;
@@ -1960,20 +1960,84 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
}
}

struct mmid_row_mapping {
int64_t i1;
int64_t i2;
};

static __global__ void k_copy_src1_to_contiguous(const char * src1_original, char * src1_contiguous,
int * cur_src1_row, mmid_row_mapping * row_mapping,
const char * ids_dev, int64_t i02, int64_t ids_nb1, int64_t ids_nb0,
int64_t ids_ne1, int64_t n_ids,
int64_t ne11,
size_t nb11, size_t nb12) {
int64_t iid1 = blockIdx.x;
int64_t id = blockIdx.y;

if (iid1 >= ids_ne1 || id >= n_ids) {
return;
}

const int32_t row_id_i = *(const int32_t *) (ids_dev + iid1*ids_nb1 + id*ids_nb0);

if (row_id_i != i02) {
return;
}

const int64_t i11 = id % ne11;
const int64_t i12 = iid1;

__shared__ int src1_row;
if (threadIdx.x == 0) {
src1_row = atomicAdd(cur_src1_row, 1);
row_mapping[src1_row] = {id, iid1};
}
__syncthreads();

const char * src1_row_original = src1_original + i11*nb11 + i12*nb12;
char * src1_row_contiguous = src1_contiguous + src1_row*nb11;

for (int i = threadIdx.x; i < nb11; i += blockDim.x) {
src1_row_contiguous[i] = src1_row_original[i];
}
}

static __global__ void k_copy_dst_from_contiguous(char * dst_original, const char * dst_contiguous,
const mmid_row_mapping * row_mapping,
int64_t n_rows,
int64_t nb1, int64_t nb2) {
int64_t i = blockIdx.x;

if (i >= n_rows) {
return;
}

const int64_t i1 = row_mapping[i].i1;
const int64_t i2 = row_mapping[i].i2;

const char * dst_row_contiguous = dst_contiguous + i*nb1;
char * dst_row_original = dst_original + i1*nb1 + i2*nb2;

for (int j = threadIdx.x; j < nb1; j += blockDim.x) {
dst_row_original[j] = dst_row_contiguous[j];
}
}

//#define MMID_MEMCPY

static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * ids = dst->src[2];

GGML_TENSOR_BINARY_OP_LOCALS

GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");

cudaStream_t stream = ctx.stream();

const size_t nb11 = src1->nb[1];
const size_t nb1 = dst->nb[1];

const int32_t id = ((int32_t *) dst->op_params)[0];
const int32_t n_as = src0->ne[2];
const int64_t n_as = ne02;
const int64_t n_ids = ids->ne[0];

std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data;
@@ -1982,27 +2046,47 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *

ggml_tensor src0_row = *src0;
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;
ggml_tensor dst_row = *dst;

char * src0_original = (char *) src0->data;
char * src1_original = (char *) src1->data;
char * dst_original = (char *) dst->data;

src0_row.ne[2] = 1;
src0_row.ne[3] = 1;
src0_row.nb[3] = src0->nb[2];
src0_row.nb[3] = nb02;

if (src1->ne[1] == 1) {
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
src1_row.ne[1] = 1;
src1_row.ne[2] = 1;
src1_row.ne[3] = 1;
src1_row.nb[2] = nb11;
src1_row.nb[3] = nb11;

GGML_ASSERT(row_id >= 0 && row_id < n_as);
dst_row.ne[1] = 1;
dst_row.ne[2] = 1;
dst_row.ne[3] = 1;
dst_row.nb[2] = nb1;
dst_row.nb[3] = nb1;

src0_row.data = src0_original + row_id*src0->nb[2];
src1_row.data = src1_original + i01*src1->nb[1];
dst_row.data = dst_original + i01*dst->nb[1];
if (ne12 == 1) {
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);

ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
GGML_ASSERT(i02 >= 0 && i02 < n_as);

const int64_t i11 = id % ne11;
const int64_t i12 = iid1;

const int64_t i1 = id;
const int64_t i2 = i12;

src0_row.data = src0_original + i02*nb02;
src1_row.data = src1_original + i11*nb11 + i12*nb12;
dst_row.data = dst_original + i1*nb1 + i2*nb2;

ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
}
}
} else {
ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
@@ -2011,55 +2095,104 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
src1_row.data = src1_contiguous.get();
dst_row.data = dst_contiguous.get();

for (int32_t row_id = 0; row_id < n_as; ++row_id) {
for (int64_t i02 = 0; i02 < n_as; i02++) {
int64_t num_src1_rows = 0;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);

if (row_id_i != row_id) {
continue;
}
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);

if (row_id_i != i02) {
continue;
}

GGML_ASSERT(row_id >= 0 && row_id < n_as);
GGML_ASSERT(i02 >= 0 && i02 < n_as);

CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
nb11, cudaMemcpyDeviceToDevice, stream));
num_src1_rows++;
#ifdef MMID_MEMCPY
const int64_t i11 = id % ne11;
const int64_t i12 = iid1;
CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11,
src1_original + i11*nb11 + i12*nb12,
nb11, cudaMemcpyDeviceToDevice, stream));
#endif
num_src1_rows++;
}
}

if (num_src1_rows == 0) {
continue;
}

src0_row.data = src0_original + row_id*src0->nb[2];
#ifndef MMID_MEMCPY
ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));

src1_row.ne[1] = num_src1_rows;
dst_row.ne[1] = num_src1_rows;
{
dim3 block_dims(std::min((uint)nb11, 1024u));
dim3 grid_dims(ids->ne[1], n_ids);
k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
src1_original, src1_contiguous.get(),
dev_cur_src1_row.get(), dev_row_mapping.get(),
ids_dev, i02, ids->nb[1], ids->nb[0],
ids->ne[1], n_ids,
ne11,
nb11, nb12);
CUDA_CHECK(cudaGetLastError());
}
#endif

src0_row.data = src0_original + i02*nb02;

GGML_ASSERT(nb11 == sizeof(float)*ne10);
GGML_ASSERT(nb1 == sizeof(float)*ne0);

src1_row.ne[1] = num_src1_rows;
src1_row.nb[1] = nb11;
src1_row.nb[2] = num_src1_rows*nb11;
src1_row.nb[3] = num_src1_rows*nb11;

dst_row.ne[1] = num_src1_rows;
dst_row.nb[1] = nb1;
dst_row.nb[2] = num_src1_rows*nb1;
dst_row.nb[3] = num_src1_rows*nb1;

ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);

#ifndef MMID_MEMCPY
{
dim3 block_dims(std::min((uint)nb1, 1024u));
dim3 grid_dims(num_src1_rows);
k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
dst_original, dst_contiguous.get(),
dev_row_mapping.get(),
num_src1_rows, nb1, nb2);
CUDA_CHECK(cudaGetLastError());
}
#endif

#ifdef MMID_MEMCPY
num_src1_rows = 0;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);

if (row_id_i != row_id) {
continue;
}
if (row_id_i != i02) {
continue;
}

GGML_ASSERT(row_id >= 0 && row_id < n_as);
GGML_ASSERT(i02 >= 0 && i02 < n_as);

CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
nb1, cudaMemcpyDeviceToDevice, stream));
num_src1_rows++;
const int64_t i1 = id;
const int64_t i2 = iid1;

CUDA_CHECK(cudaMemcpyAsync(dst_original + i1*nb1 + i2*nb2,
dst_contiguous.get() + num_src1_rows*nb1,
nb1, cudaMemcpyDeviceToDevice, stream));
num_src1_rows++;
}
}
#endif
}
}
}
@@ -2487,7 +2620,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
const int min_batch_size = 32;

return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
(op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);

GGML_UNUSED(backend);
}
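For reference, a minimal host-side sketch of the gather/compute/scatter scheme that the new `ggml_cuda_mul_mat_id` path implements: for each expert, the `src1` rows routed to that expert are packed into a contiguous buffer (the role of `k_copy_src1_to_contiguous`), one matrix multiplication is run over all packed rows, and the results are scattered back to their original (expert slot, token) positions (the role of `k_copy_dst_from_contiguous`). This is a simplified CPU-only illustration, not code from the PR; the sizes, the identity "weights", and the `row_mapping` layout are invented for the example.

```cuda
#include <algorithm>
#include <cstdio>
#include <vector>

// (expert slot, token) position of each gathered row, loosely mirroring mmid_row_mapping
struct row_mapping { int i1; int i2; };

int main() {
    const int n_expert = 4;   // number of experts (n_as)
    const int n_used   = 2;   // experts used per token (n_ids)
    const int n_tok    = 3;   // tokens in the batch
    const int n_in     = 8;   // input row size
    const int n_out    = 8;   // output row size

    // ids[t*n_used + s] = expert chosen for token t, slot s
    const std::vector<int> ids = {0, 2,   1, 2,   0, 1};

    std::vector<float> src1(n_tok * n_in, 1.0f);              // token activations
    std::vector<float> dst (n_tok * n_used * n_out, 0.0f);    // per-slot outputs

    for (int e = 0; e < n_expert; ++e) {
        // 1) gather: pack the rows routed to expert e into a contiguous buffer
        std::vector<float>       gathered;
        std::vector<row_mapping> mapping;
        for (int t = 0; t < n_tok; ++t) {
            for (int s = 0; s < n_used; ++s) {
                if (ids[t*n_used + s] != e) continue;
                gathered.insert(gathered.end(), src1.begin() + t*n_in, src1.begin() + (t + 1)*n_in);
                mapping.push_back({s, t});
            }
        }
        if (mapping.empty()) continue;

        // 2) one "matmul" per expert over all gathered rows (identity weights for brevity)
        std::vector<float> out(mapping.size() * n_out);
        for (size_t r = 0; r < mapping.size(); ++r) {
            for (int o = 0; o < n_out; ++o) {
                out[r*n_out + o] = gathered[r*n_in + o % n_in];
            }
        }

        // 3) scatter: copy each result row back to its (slot, token) position in dst
        for (size_t r = 0; r < mapping.size(); ++r) {
            std::copy(out.begin() + r*n_out, out.begin() + (r + 1)*n_out,
                      dst.begin() + (mapping[r].i2*n_used + mapping[r].i1)*n_out);
        }
    }

    std::printf("dst[0] = %.1f\n", dst[0]);
    return 0;
}
```

The point of packing is that each expert then performs a single matrix multiplication over all of its rows rather than one multiplication per routed row; the `MMID_MEMCPY` fallback in the diff does the same packing with one `cudaMemcpyAsync` per row instead of the copy kernels.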
2 changes: 2 additions & 0 deletions ggml-cuda/convert.cu
@@ -45,6 +45,8 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
vals[ix] = x0[ix];
}

__syncthreads();

#pragma unroll
for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
if (need_check && i0 + iy + 2*threadIdx.x >= k) {
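The `__syncthreads()` added to `dequantize_block_q8_0_f16` separates a phase where the block cooperatively stages values into shared memory from a phase where threads read entries that other threads wrote. The stand-alone kernel below shows the same two-phase pattern; it is an illustration of the race the barrier prevents, not code from this PR, and the names and sizes are invented for the example.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

__global__ void shared_stage_then_read(const float * in, float * out, int n) {
    extern __shared__ float vals[];

    // phase 1: each thread stages a strided subset of the input into shared memory
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        vals[i] = in[i];
    }

    __syncthreads(); // without this, phase 2 may read slots not yet written by other threads

    // phase 2: each thread reads a different subset (reversed order) of shared memory
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        out[i] = vals[n - 1 - i];
    }
}

int main() {
    const int n = 1024;
    float *in, *out;
    cudaMallocManaged(&in,  n*sizeof(float));
    cudaMallocManaged(&out, n*sizeof(float));
    for (int i = 0; i < n; ++i) in[i] = (float) i;

    shared_stage_then_read<<<1, 256, n*sizeof(float)>>>(in, out, n);
    cudaDeviceSynchronize();

    printf("out[0] = %.1f (expected %.1f)\n", out[0], (float)(n - 1));
    cudaFree(in);
    cudaFree(out);
    return 0;
}
```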