ggml : group all experts in a single ggml_mul_mat_id #6505

Merged
merged 22 commits on Apr 18, 2024
Changes from 1 commit
cleanup
ggml-ci
slaren committed Apr 17, 2024
commit 997a9b5bd2f2c9c16067cea3a901be96c1203b7d
25 changes: 10 additions & 15 deletions ggml.c
@@ -11018,11 +11018,6 @@ static void ggml_compute_forward_mul_mat_id(
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;

//GGML_ASSERT(ne0 == ne01);
//GGML_ASSERT(ne1 == ne11);
//GGML_ASSERT(ne2 == ne12);
//GGML_ASSERT(ne3 == ne13);

// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
@@ -11041,8 +11036,13 @@ static void ggml_compute_forward_mul_mat_id(
(char *) params->wdata :
(char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));

struct mmid_row_mapping {
int32_t i1;
int32_t i2;
};

int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11]
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]

if (params->type == GGML_TASK_TYPE_INIT) {
if (ith != 0) {
@@ -11069,9 +11069,6 @@ static void ggml_compute_forward_mul_mat_id(
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));

#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
#define MAKE_I64(lo, hi) (((int64_t)(lo)) | (((int64_t)(hi)) << 32))
#define LO_I64(i64) ((int32_t)(i64))
#define HI_I64(i64) ((int32_t)((i64) >> 32))

// group rows by src0 matrix
for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
@@ -11080,7 +11077,7 @@ static void ggml_compute_forward_mul_mat_id(

assert(i02 >= 0 && i02 < n_as);

MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = MAKE_I64(iid1, id);
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
matrix_row_counts[i02] += 1;
}
}
@@ -11143,10 +11140,11 @@ static void ggml_compute_forward_mul_mat_id(
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
const int64_t _i12 = ir1; // logical row index for this expert

const int id = HI_I64(MMID_MATRIX_ROW(cur_a, _i12)); // selected expert index
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
const int id = row_mapping.i1; // selected expert index

const int64_t i11 = id % ne11;
const int64_t i12 = LO_I64(MMID_MATRIX_ROW(cur_a, _i12)); // row index in src1
const int64_t i12 = row_mapping.i2; // row index in src1

const int64_t i1 = id; // selected expert index
const int64_t i2 = i12; // row
@@ -11177,9 +11175,6 @@ static void ggml_compute_forward_mul_mat_id(
}

#undef MMID_MATRIX_ROW
#undef MAKE_I64
#undef LO_I64
#undef HI_I64
}

// ggml_compute_forward_out_prod
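
For readers skimming the diff: the commit drops the MAKE_I64 / LO_I64 / HI_I64 bit-packing macros in favor of an explicit `struct mmid_row_mapping`. The snippet below is a small standalone sketch (not part of ggml) showing that the two encodings carry the same pair of indices; the concrete values are made up for illustration.

```c
#include <assert.h>
#include <stdint.h>

struct mmid_row_mapping {
    int32_t i1; // selected expert index (was the high 32 bits)
    int32_t i2; // row index in src1     (was the low 32 bits)
};

// old encoding, as removed by this commit
#define MAKE_I64(lo, hi) (((int64_t)(lo)) | (((int64_t)(hi)) << 32))
#define LO_I64(i64) ((int32_t)(i64))
#define HI_I64(i64) ((int32_t)((i64) >> 32))

int main(void) {
    const int32_t id = 1, iid1 = 5; // example values only

    const int64_t packed = MAKE_I64(iid1, id);       // old: one packed int64
    const struct mmid_row_mapping m = { id, iid1 };  // new: explicit fields

    assert(HI_I64(packed) == m.i1); // selected expert index
    assert(LO_I64(packed) == m.i2); // row index in src1
    return 0;
}
```

The struct keeps the same 8-byte footprint in wdata while removing the shift/mask bookkeeping.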
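For context on how these mappings are used, here is a minimal self-contained sketch of the row-grouping scheme itself: token rows of src1 are bucketed per expert during the INIT phase, and each expert then processes only the rows that selected it. Names such as N_EXPERTS, N_TOKENS, N_USED and the ids values are illustrative, not taken from ggml.

```c
#include <stdint.h>
#include <stdio.h>

struct mmid_row_mapping {
    int32_t i1; // selected expert slot within ids
    int32_t i2; // row (token) index in src1
};

#define N_EXPERTS 4 // n_as
#define N_TOKENS  3 // ids->ne[1]
#define N_USED    2 // experts used per token (ids->ne[0])

int main(void) {
    // which experts each token routes to (hypothetical routing)
    const int32_t ids[N_TOKENS][N_USED] = { {0, 2}, {2, 3}, {0, 3} };

    int64_t                 matrix_row_counts[N_EXPERTS] = {0};        // [n_as]
    struct mmid_row_mapping matrix_rows[N_EXPERTS][N_TOKENS * N_USED]; // [n_as][...]

    // INIT phase: group rows by the expert (src0 matrix) they selected
    for (int32_t iid1 = 0; iid1 < N_TOKENS; ++iid1) {
        for (int32_t id = 0; id < N_USED; ++id) {
            const int32_t i02 = ids[iid1][id];
            matrix_rows[i02][matrix_row_counts[i02]] = (struct mmid_row_mapping) {id, iid1};
            matrix_row_counts[i02] += 1;
        }
    }

    // COMPUTE phase: each expert walks only the rows that routed to it
    for (int32_t cur_a = 0; cur_a < N_EXPERTS; ++cur_a) {
        printf("expert %d: %d rows\n", cur_a, (int) matrix_row_counts[cur_a]);
        for (int64_t ir = 0; ir < matrix_row_counts[cur_a]; ++ir) {
            const struct mmid_row_mapping rm = matrix_rows[cur_a][ir];
            printf("  slot %d of token %d\n", (int) rm.i1, (int) rm.i2);
        }
    }
    return 0;
}
```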