Skip to content

metal : batch rows copy in a single threadgroup #14384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions ggml/src/ggml-metal/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -2450,6 +2450,7 @@ static bool ggml_metal_encode_node(
nth *= 2;
}

nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
nth = MIN(nth, ne00);

ggml_metal_kargs_sum_rows args = {
Expand Down Expand Up @@ -3780,6 +3781,7 @@ static bool ggml_metal_encode_node(
nth *= 2;
}

nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
nth = MIN(nth, ne00/4);

ggml_metal_kargs_rms_norm args = {
Expand Down Expand Up @@ -3816,6 +3818,7 @@ static bool ggml_metal_encode_node(
nth *= 2;
}

nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
nth = MIN(nth, ne00/4);

ggml_metal_kargs_l2_norm args = {
Expand Down Expand Up @@ -3888,6 +3891,7 @@ static bool ggml_metal_encode_node(
nth *= 2;
}

nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
nth = MIN(nth, ne00/4);

ggml_metal_kargs_norm args = {
Expand Down Expand Up @@ -4974,8 +4978,39 @@ static bool ggml_metal_encode_node(
default: GGML_ABORT("not implemented");
}

GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);

// TODO: support
//const int32_t nk00 = ne00/ggml_blck_size(dst->type);
const int32_t nk00 = ne00;

int nth = 32; // SIMD width

while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
nth *= 2;
}

nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);

// when rows are small, we can batch them together in a single threadgroup
int nrptg = 1;

// TODO: relax this constraint in the future
if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) {
if (nth > nk00) {
nrptg = (nth + nk00 - 1)/nk00;
nth = nk00;

if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
nrptg--;
}
}
}

nth = MIN(nth, nk00);

ggml_metal_kargs_cpy args = {
/*.ne00 =*/ ne00,
/*.ne00 =*/ nk00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
Expand All @@ -4998,11 +5033,7 @@ static bool ggml_metal_encode_node(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];

GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];

[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
} break;
case GGML_OP_SET:
{
Expand Down
11 changes: 8 additions & 3 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -4306,11 +4306,16 @@ kernel void kernel_cpy(
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
ushort3 tptg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = tgpig[0];
const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;

if (i01 >= args.ne01) {
return;
}

const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

Expand All @@ -4321,7 +4326,7 @@ kernel void kernel_cpy(

device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
dst_data[i00] = (T1) src[0];
}
Expand Down
Loading