41 changes: 29 additions & 12 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1227,6 +1227,7 @@ struct vk_op_topk_push_constants {
uint32_t orig_ncols;
uint32_t ncols_input;
uint32_t ncols_output;
uint32_t k;
uint32_t nrows;
uint32_t first_pass;
uint32_t last_pass;
@@ -1673,6 +1674,14 @@ class vk_perf_logger {
timings[name.str()].push_back(time);
return;
}
if (node->op == GGML_OP_TOP_K) {
std::stringstream name;
name << ggml_op_name(node->op) <<
" K=" << node->ne[0] <<
" (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
timings[name.str()].push_back(time);
return;
}
timings[ggml_op_name(node->op)].push_back(time);
}
private:
@@ -4041,7 +4050,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
sizeof(int) * device->subgroup_size +
2 * sizeof(int) +
(BLOCK_SIZE / device->subgroup_size) * sizeof(int);
2 * (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
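Aside, not part of the patch: the nary_shmem budget above now reserves two partials arrays instead of one, matching the new eq_min_partials shared array added to topk_nary_search.comp. As a quick sanity check of the formula, here is a small standalone calculation assuming a hypothetical BLOCK_SIZE of 1024 and a subgroup size of 32 (actual values are device dependent):

```cpp
// Hypothetical sizes, for illustration only.
#include <cstddef>
#include <cstdio>

int main() {
    const size_t BLOCK_SIZE = 1024;
    const size_t SUBGROUP   = 32;
    size_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE               // dst_row: one ivec2 per thread
                      + sizeof(int) * SUBGROUP                     // counts[SUBGROUP_SIZE]
                      + 2 * sizeof(int)                            // sh_min_idx, sh_total
                      + 2 * (BLOCK_SIZE / SUBGROUP) * sizeof(int); // offset_partials + eq_min_partials
    printf("%zu bytes of shared memory\n", nary_shmem);            // prints 8584, well under typical limits
    return 0;
}
```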
@@ -10345,17 +10354,8 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
uint32_t nrows = ggml_nrows(src0);
uint32_t k = dst->ne[0];

vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };
vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 };

// Reserve space for ivec2 per element, double buffered
const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
const size_t x_sz = dbl_buf_size * 2;
uint32_t dbl_buf_index = 0;

if (ctx->prealloc_size_x < x_sz) {
ctx->prealloc_size_x = x_sz;
ggml_vk_preallocate_buffers(ctx, subctx);
}
if (ctx->prealloc_x_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
@@ -10370,8 +10370,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
// largest elements. Repeat until we have the top K elements.
// Need to do at least one iteration to write out the results.
bool done_one_iter = false;
uint32_t dbl_buf_index = 0;
size_t dbl_buf_size;
while (num_elements > k || !done_one_iter) {
done_one_iter = true;

// Prefer going as small as num_topk_pipelines - 3 for perf reasons.
// But if K is larger, then we need a larger workgroup
@@ -10411,6 +10412,21 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
// Number of elements remaining after this pass
uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);

pc2.ncols_output = num_dst_elements;

if (!done_one_iter) {
// Reserve space for ivec2 per element, double buffered
// K per workgroup per row
dbl_buf_size = num_dst_elements * nrows * 2 * sizeof(int);
dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
const size_t x_sz = dbl_buf_size * 2;

if (ctx->prealloc_size_x < x_sz) {
ctx->prealloc_size_x = x_sz;
ggml_vk_preallocate_buffers(ctx, subctx);
}
}

vk_subbuffer src_buf;
vk_subbuffer dst_buf;

@@ -10436,6 +10452,7 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
if (num_elements > k) {
ggml_vk_sync_buffers(ctx, subctx);
}
done_one_iter = true;
}
ctx->prealloc_x_need_sync = true;
}
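Aside, not part of the patch: with this change the intermediate double buffer is sized from the first pass's output, k candidates per workgroup per row, instead of from the full ncols per row, and the allocation is deferred into the loop where num_dst_elements is known. A minimal host-side sketch of the same sizing logic, where wg_size, the alignment value, and the helper name are assumptions rather than the actual ggml API:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Assumed stand-in for ROUNDUP_POW2(x, a) with a power-of-two alignment.
static size_t round_up_pow2(size_t v, size_t a) { return (v + a - 1) & ~(a - 1); }

int main() {
    const uint32_t ncols = 100000, nrows = 4, k = 40;
    const uint32_t wg_size   = 1024; // plays the role of pipeline->wg_denoms[0]
    const size_t   min_align = 64;   // minStorageBufferOffsetAlignment is device dependent

    // Survivors of the first pass: k per full workgroup, plus a partial tail block.
    const uint32_t num_dst_elements = (ncols / wg_size) * k + std::min(k, ncols % wg_size);

    // One ivec2 (index, value bits) per survivor per row, double buffered.
    size_t dbl_buf_size = size_t(num_dst_elements) * nrows * 2 * sizeof(int);
    dbl_buf_size = round_up_pow2(dbl_buf_size, min_align);
    const size_t x_sz = dbl_buf_size * 2;

    printf("survivors per row: %u, prealloc bytes: %zu\n", (unsigned) num_dst_elements, x_sz);
    return 0;
}
```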
19 changes: 12 additions & 7 deletions ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp
@@ -19,6 +19,7 @@ layout (push_constant) uniform parameter {
uint orig_ncols;
uint ncols_input;
uint ncols_output;
uint k;
uint nrows;
uint first_pass;
uint last_pass;
@@ -36,15 +37,15 @@ void topk(bool needs_bounds_check, const uint row) {
const uint row_offset = row * p.ncols_input;
dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
} else {
const uint row_offset = row * p.orig_ncols;
const uint row_offset = row * p.ncols_input;
dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
}
} else {
dst_row[col] = ivec2(p.orig_ncols, 0);
}
barrier();

if (p.ncols_output == 1) {
if (p.k == 1) {
// Fast path for single output - just do a max reduction
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
if (col < s) {
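Aside, not part of the patch: the p.k == 1 branch here is a plain shared-memory tree reduction that keeps the (index, value) pair with the larger value at every step. A CPU analogue of the same idea, assuming a power-of-two block size:

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Illustrative only: 'lane' plays the role of dst_row, one entry per thread.
static std::pair<int, float> argmax_tree(std::vector<std::pair<int, float>> lane) {
    for (size_t s = lane.size() / 2; s >= 1; s /= 2) {
        for (size_t i = 0; i < s; ++i) {      // the threads with col < s in the shader
            if (lane[i + s].second > lane[i].second) {
                lane[i] = lane[i + s];        // keep the larger value and its index
            }
        }
    }
    return lane[0];                           // the single top-1 result
}
```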
@@ -84,13 +85,17 @@
}
}

if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
if (col < p.k) {
if (p.last_pass != 0) {
const uint row_offset = row * p.ncols_output;
data_d[row_offset + col] = dst_row[col].x;
if (gl_GlobalInvocationID.x < p.ncols_input) {
const uint row_offset = row * p.k;
data_d[row_offset + col] = dst_row[col].x;
}
} else {
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
data_t[row_offset + col] = dst_row[col];
if (gl_WorkGroupID.x * p.k + col < p.ncols_output) {
const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
data_t[row_offset + col] = dst_row[col];
}
}
}
}
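Aside, not part of the patch: the indexing change in this shader means that on non-final passes workgroup w writes its k candidates of row r at offset r * ncols_output + w * k, guarded against ncols_output, while the final pass writes k indices per row. A CPU reference sketch of that intermediate layout (the selection via std::partial_sort and all names here are mine, not the shader's):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

struct Cand { uint32_t idx; float val; };

// One reduction pass over a single row: each block of block_size elements keeps its
// top-k, and block w stores them at offset w * k in the intermediate row (this layout
// assumes block_size >= k, which the pipeline selection in the host code guarantees).
static std::vector<Cand> topk_pass(const std::vector<Cand> & row, uint32_t block_size, uint32_t k) {
    const uint32_t n = (uint32_t) row.size();
    const uint32_t ncols_output = (n / block_size) * k + std::min(k, n % block_size);
    std::vector<Cand> out(ncols_output);
    for (uint32_t w = 0; w * block_size < n; ++w) {
        std::vector<Cand> block(row.begin() + w * block_size,
                                row.begin() + std::min<size_t>(n, size_t(w) * block_size + block_size));
        const uint32_t kept = std::min<uint32_t>(k, (uint32_t) block.size());
        std::partial_sort(block.begin(), block.begin() + kept, block.end(),
                          [](const Cand & a, const Cand & b) { return a.val > b.val; });
        for (uint32_t c = 0; c < kept; ++c) {
            out[w * k + c] = block[c];   // w * k + c < ncols_output, the guard in the shader
        }
    }
    return out;                          // feed back into topk_pass until only k entries remain
}
```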
97 changes: 72 additions & 25 deletions ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp
@@ -25,6 +25,7 @@ layout (push_constant) uniform parameter {
uint orig_ncols;
uint ncols_input;
uint ncols_output;
uint k;
uint nrows;
uint first_pass;
uint last_pass;
@@ -37,6 +38,7 @@ shared int counts[SUBGROUP_SIZE];
shared int sh_min_idx;
shared uint sh_total;
shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
shared uint eq_min_partials[BLOCK_SIZE / SUBGROUP_SIZE];

// Map float values to uint such that comparisons still work.
// Positive values set the high bit, negative values are inverted.
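Aside, not part of the patch: a C++ rendition of the order-preserving mapping described in the comment above; the shader's f2ui may differ in detail, this is just the standard trick:

```cpp
#include <cstdint>
#include <cstring>

// Monotone float -> uint mapping: unsigned comparison of the results matches the
// usual float ordering (negatives below positives, magnitudes preserved).
static uint32_t f2ui(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof u);                     // bit-cast, like floatBitsToUint
    return (u & 0x80000000u) ? ~u : (u | 0x80000000u); // invert negatives, set high bit on positives
}
// e.g. f2ui(-2.0f) < f2ui(-0.5f) < f2ui(0.0f) < f2ui(3.0f) holds as unsigned integers.
```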
@@ -60,15 +62,15 @@ void topk(const uint row) {
const uint row_offset = row * p.ncols_input;
dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
} else {
const uint row_offset = row * p.orig_ncols;
const uint row_offset = row * p.ncols_input;
dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
}
} else {
dst_row[tid] = ivec2(p.orig_ncols, 0xFF800000); // -inf
}
barrier();

if (p.ncols_output == 1) {
if (p.k == 1) {
// Fast path for single output - just do a max reduction
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
if (tid < s) {
@@ -98,7 +100,7 @@
uint range_max = 0xFF800000;
// How many are above the current range, and how many we need to find.
uint total = 0;
uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
uint limit = min(p.k, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);

while (mask != 0) {
barrier();
Expand Down Expand Up @@ -139,7 +141,7 @@ void topk(const uint row) {
range_max = range_min + ((min_idx + 1) << shift);
range_min = range_min + (min_idx << shift);

if (total == p.ncols_output) {
if (total == p.k) {
break;
}
total -= counts[min_idx];
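Aside, not part of the patch: the loop above is a radix-style search for the top-k threshold. Each round buckets the surviving values by their next group of bits (one bucket per subgroup lane), keeps the bucket that straddles the k-th largest value, and narrows [range_min, range_max) until exactly k values lie above range_min or the bits run out. A scalar sketch of the same idea using fixed 4-bit digits (the digit width and the names are mine):

```cpp
#include <cstdint>
#include <vector>

// Returns a threshold t such that at least k keys are >= t. Keys are assumed to already
// be order-preserving uints (e.g. produced by f2ui above), with 1 <= k <= keys.size().
static uint32_t topk_threshold(const std::vector<uint32_t> & keys, uint32_t k) {
    uint32_t prefix = 0;   // high bits of the k-th largest value fixed so far
    uint32_t fixed  = 0;   // mask of the bits already fixed
    for (int shift = 28; shift >= 0; shift -= 4) {
        uint32_t counts[16] = {};
        for (uint32_t key : keys) {
            if ((key & fixed) == prefix) {
                counts[(key >> shift) & 0xF]++;  // bucket by the next digit
            }
        }
        int d = 15;                              // scan buckets from the top down
        while (counts[d] < k) {
            k -= counts[d];                      // like "total -= counts[min_idx]" above
            --d;
        }
        prefix |= uint32_t(d) << shift;          // narrow into the straddling bucket
        fixed  |= 0xFu << shift;
    }
    return prefix;   // the value of the k-th largest key; ties may exist at exactly this value
}
```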
@@ -155,37 +157,82 @@
// We need to compact these values to the start of the dst_row array.
// Have each subgroup count how many items it'll store, so other
// subgroups can compute their base offset.
bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
uvec4 b = subgroupBallot(top);
uint bit_count = subgroupBallotBitCount(b);
if ((tid % SUBGROUP_SIZE) == 0) {
offset_partials[tid / SUBGROUP_SIZE] = bit_count;
}
barrier();
// Values strictly greater than range_min must be stored. For values equal
// to range_min, there can be ties and it's possible we'll need to store
// an arbitrary subset of them.
// If total == p.k, have a fast path where we don't need to handle ties.
if (total == p.k) {
bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
uvec4 b = subgroupBallot(top);
uint bit_count = subgroupBallotBitCount(b);
if ((tid % SUBGROUP_SIZE) == 0) {
offset_partials[tid / SUBGROUP_SIZE] = bit_count;
}
barrier();

uint out_idx = 0;
[[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
if (i < tid / SUBGROUP_SIZE) {
out_idx += offset_partials[i];
uint out_idx = 0;
[[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
if (i < tid / SUBGROUP_SIZE) {
out_idx += offset_partials[i];
}
}
}

uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
if (top) {
// TODO: Copy directly to the output?
dst_row[out_idx + bit_count_ex] = v;
uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
if (top) {
// TODO: Copy directly to the output?
dst_row[out_idx + bit_count_ex] = v;
}
} else {
bool top = f2ui(intBitsToFloat(v.y)) > range_min;
bool eq_min = f2ui(intBitsToFloat(v.y)) == range_min;
uvec4 b_top = subgroupBallot(top);
uvec4 b_eq_min = subgroupBallot(eq_min);
uint bit_count_top = subgroupBallotBitCount(b_top);
uint bit_count_eq_min = subgroupBallotBitCount(b_eq_min);
if ((tid % SUBGROUP_SIZE) == 0) {
offset_partials[tid / SUBGROUP_SIZE] = bit_count_top;
eq_min_partials[tid / SUBGROUP_SIZE] = bit_count_eq_min;
}
barrier();

uint out_idx = 0;
uint eq_min_base = 0;
uint eq_min_idx = 0;
[[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
if (i < tid / SUBGROUP_SIZE) {
out_idx += offset_partials[i];
eq_min_idx += eq_min_partials[i];
}
eq_min_base += offset_partials[i];
}
// range_min values are stored at the end
eq_min_idx += eq_min_base;

uint bit_count_ex_top = subgroupBallotExclusiveBitCount(b_top);
uint bit_count_ex_eq_min = subgroupBallotExclusiveBitCount(b_eq_min);
if (top) {
// TODO: Copy directly to the output?
dst_row[out_idx + bit_count_ex_top] = v;
}
if (eq_min && eq_min_idx + bit_count_ex_eq_min < p.k) {
dst_row[eq_min_idx + bit_count_ex_eq_min] = v;
}
}

barrier();
}

if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
if (tid < p.k) {
if (p.last_pass != 0) {
const uint row_offset = row * p.ncols_output;
data_d[row_offset + tid] = dst_row[tid].x;
if (gl_GlobalInvocationID.x < p.ncols_input) {
const uint row_offset = row * p.k;
data_d[row_offset + tid] = dst_row[tid].x;
}
} else {
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
data_t[row_offset + tid] = dst_row[tid];
if (gl_WorkGroupID.x * p.k + tid < p.ncols_output) {
const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
data_t[row_offset + tid] = dst_row[tid];
}
}
}
}
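Aside, not part of the patch: the new two-ballot path handles the case where the threshold bucket still holds more candidates than needed. Values strictly above range_min always survive, while values equal to range_min only fill the remaining slots up to k. A CPU equivalent of that compaction, with names of my own choosing:

```cpp
#include <cstdint>
#include <vector>

struct Cand { uint32_t idx; float val; };

// Illustrative only: keep everything strictly above the threshold, then pad with
// threshold-valued ties until k entries are kept, mirroring the eq_min ballot and
// the "eq_min_idx + bit_count_ex_eq_min < p.k" guard in the shader.
static std::vector<Cand> compact_topk(const std::vector<Cand> & in, float threshold, uint32_t k) {
    std::vector<Cand> out;
    out.reserve(k);
    for (const Cand & c : in) {
        if (c.val > threshold) out.push_back(c);   // the strictly "top" values come first
    }
    for (const Cand & c : in) {
        if (out.size() >= k) break;
        if (c.val == threshold) out.push_back(c);  // an arbitrary subset of the ties
    }
    return out;
}
```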