Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed May 27, 2024
1 parent 9d8f12c commit 444113b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 24 deletions.
32 changes: 15 additions & 17 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -3378,27 +3378,25 @@ kernel void kernel_concat(
device const char * src;

int64_t o[4] = {0, 0, 0, 0};

if (dim > 0 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
src = src0;
o[dim] = 0;
} else {
src = src1;
o[dim] = dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03);
}
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));

//if (dim > 0) {
// if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
// src = src0;
// o[dim] = 0;
// b[dim] = dim == 1 ? nb01 : (dim == 2 ? nb02 : nb03);
// } else {
// src = src1;
// o[dim] = dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03);
// b[dim] = dim == 1 ? nb11 : (dim == 2 ? nb12 : nb13);
// }
//}

for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
if (dim == 0) {
if (i0 < ne00) {
src = src0;
o[dim] = 0;
} else {
src = src1;
o[dim] = ne00;
}
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
}

device const float * x = (device const float *)(src + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
device const float * x = (device const float *)(src + (i3 - o[3])*b[3] + (i2 - o[2])*b[2] + (i1 - o[1])*b[1] + (i0 - o[0])*b[0]);
device float * y = (device float *)(dst + (i3 )*nb3 + (i2 )*nb2 + (i1 )*nb1 + (i0 )*nb0);

*y = *x;
Expand Down
12 changes: 5 additions & 7 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -10985,25 +10985,23 @@ static void ggml_compute_forward_concat_f32(

GGML_ASSERT(dim >= 0 && dim < 4);

const char * src;
const float * x;

int64_t o[4] = {0, 0, 0, 0};
o[dim] = src0->ne[dim];

// TODO: smarter multi-threading
for (int i3 = 0; i3 < ne3; i3++) {
for (int i2 = ith; i2 < ne2; i2 += nth) {
for (int i1 = 0; i1 < ne1; i1++) {
for (int i0 = 0; i0 < ne0; i0++) {
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
src = (const char *) src0->data;
o[dim] = 0;
x = (const float *) (src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
} else {
src = (const char *) src1->data;
o[dim] = src0->ne[dim];
x = (const float *) (src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
}

const float * x = (const float *)( src + (i0 - o[0]) * nb10 + (i1 - o[1]) * nb11 + (i2 - o[2]) * nb12 + (i3 - o[3]) * nb13);
float * y = ( float *)((char *)dst->data + (i0 ) * nb0 + (i1 ) * nb1 + (i2 ) * nb2 + (i3 ) * nb3);
float * y = ( float *)((char *)dst->data + (i0 ) * nb0 + (i1 ) * nb1 + (i2 ) * nb2 + (i3 ) * nb3);

*y = *x;
}
Expand Down

0 comments on commit 444113b

Please sign in to comment.