Skip to content

Commit

Permalink
metal : add more general support for ggml_get_rows + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Dec 9, 2023
1 parent 9064b1c commit 2cbcba8
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 25 deletions.
16 changes: 9 additions & 7 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -805,8 +805,9 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_GET_ROWS:
case GGML_OP_CONCAT:
case GGML_OP_ADD:
case GGML_OP_MUL:
Expand All @@ -828,7 +829,6 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
case GGML_OP_MUL_MAT_ID:
return true;
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_GET_ROWS:
{
return op->ne[0] % 4 == 0;
}
Expand Down Expand Up @@ -1568,16 +1568,18 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false && "not implemented");
}

[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:7];

const int64_t n = ggml_nelements(src1);

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break;
case GGML_OP_RMS_NORM:
{
Expand Down
62 changes: 56 additions & 6 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -3223,21 +3223,69 @@ kernel void kernel_get_rows(
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb1,
uint tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tptg[[threads_per_threadgroup]]) {
const int i = tgpig;
const int r = ((device int32_t *) src1)[i];
uint tptg [[threads_per_threadgroup]]) {
const int64_t i = tgpig;
const int64_t r = ((device int32_t *) src1)[i];

for (int ind = tiitg; ind < ne00/16; ind += tptg) {
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg) {
float4x4 temp;
dequantize_func(
((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
*(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
}
}

kernel void kernel_get_rows_f32(
device const void * src0,
device const int * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb1,
uint tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tptg [[threads_per_threadgroup]]) {
const int64_t i = tgpig;
const int64_t r = ((device int32_t *) src1)[i];

const int64_t i02 = i/ne10;

for (int ind = tiitg; ind < ne00; ind += tptg) {
((device float *) ((device char *) dst + i*nb1))[ind] =
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
}
}

kernel void kernel_get_rows_f16(
device const void * src0,
device const int * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb1,
uint tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tptg [[threads_per_threadgroup]]) {
const int64_t i = tgpig;
const int64_t r = ((device int32_t *) src1)[i];

const int64_t i02 = i/ne10;

for (int ind = tiitg; ind < ne00; ind += tptg) {
((device float *) ((device char *) dst + i*nb1))[ind] =
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
}
}

#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
#define BLOCK_SIZE_K 32
Expand Down Expand Up @@ -3490,11 +3538,13 @@ typedef void (get_rows_t)(
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb1,
uint, uint, uint);

template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
Expand Down
6 changes: 3 additions & 3 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -10363,7 +10363,7 @@ static void ggml_compute_forward_get_rows_q(

dequantize_row_q(
(const void *) ((char *) src0->data + i02*nb02 + r*nb01),
(float *) ((char *) dst->data + i*dst->nb[1]), nc);
(float *) ((char *) dst->data + i*nb1), nc);
}
}

Expand Down Expand Up @@ -10396,7 +10396,7 @@ static void ggml_compute_forward_get_rows_f16(

for (int j = 0; j < nc; ++j) {
ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i02*nb02 + r*nb01))[j];
((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
((float *) ((char *) dst->data + i*nb1))[j] = GGML_FP16_TO_FP32(v);
}
}
}
Expand Down Expand Up @@ -10429,7 +10429,7 @@ static void ggml_compute_forward_get_rows_f32(
const int64_t i02 = i/ne10;

ggml_vec_cpy_f32(nc,
(float *) ((char *) dst->data + i*dst->nb[1]),
(float *) ((char *) dst->data + i*nb1),
(float *) ((char *) src0->data + i02*nb02 + r*nb01));
}
}
Expand Down
19 changes: 10 additions & 9 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,17 +488,18 @@ struct test_get_rows : public test_case {
const int n; // cols
const int m; // rows
const int r; // rows to get
const int b; // batch size

std::string vars() override {
return VARS_TO_STR4(type, n, m, r);
}

test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3)
: type(type), n(n), m(m), r(r) {}
test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1)
: type(type), n(n), m(m), r(r), b(b) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * in = ggml_new_tensor_2d(ctx, type, n, m);
ggml_tensor * rows = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, r);
ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b);
ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
ggml_tensor * out = ggml_get_rows(ctx, in, rows);
return out;
}
Expand All @@ -507,11 +508,11 @@ struct test_get_rows : public test_case {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
// rows
std::vector<int> data(r);
for (int i = 0; i < r; i++) {
std::vector<int> data(r*b);
for (int i = 0; i < r*b; i++) {
data[i] = rand() % m;
}
ggml_backend_tensor_set(t, data.data(), 0, r * sizeof(int));
ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
} else {
init_tensor_uniform(t);
}
Expand Down Expand Up @@ -1125,8 +1126,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}

for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
test_cases.emplace_back(new test_get_rows(type, 10, 5, 3));
test_cases.emplace_back(new test_get_rows(type, 16, 5, 3));
test_cases.emplace_back(new test_get_rows(type, 10, 5, 3, 7));
test_cases.emplace_back(new test_get_rows(type, 16, 5, 3, 7));
}

test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
Expand Down

0 comments on commit 2cbcba8

Please sign in to comment.