Skip to content

Commit d56e061

Browse files
committed
Deadlock avoidance
1 parent d501ef4 commit d56e061

File tree

2 files changed

+42
-21
lines changed

2 files changed

+42
-21
lines changed

.github/workflows/build.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ jobs:
175175
id: cmake_build
176176
run: |
177177
export CMAKE_PREFIX_PATH=dawn
178-
cmake -B build -DGGML_WEBGPU=ON -DGGML_WEBGPU_SERIALIZE_SUBMIT=ON -DGGML_METAL=OFF -DGGML_BLAS=OFF
178+
cmake -B build -DGGML_WEBGPU=ON -DGGML_METAL=OFF -DGGML_BLAS=OFF
179179
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
180180
181181
- name: Test
@@ -502,7 +502,7 @@ jobs:
502502
id: cmake_build
503503
run: |
504504
export Dawn_DIR=dawn/lib64/cmake/Dawn
505-
cmake -B build -DGGML_WEBGPU=ON -DGGML_WEBGPU_SERIALIZE_SUBMIT=ON
505+
cmake -B build -DGGML_WEBGPU=ON
506506
cmake --build build --config Release -j $(nproc)
507507
508508
- name: Test

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include <webgpu/webgpu_cpp.h>
1313

14+
#include <atomic>
1415
#include <condition_variable>
1516
#include <cstring>
1617
#include <iostream>
@@ -65,13 +66,15 @@
6566
# define WEBGPU_WAIT_ANY_TIMEOUT_MS UINT64_MAX
6667
#else
6768
# define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8
68-
# define WEBGPU_WAIT_ANY_TIMEOUT_MS 1
69+
# define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
6970
#endif
7071

7172
/* Constants */
7273

7374
#define WEBGPU_MUL_MAT_WG_SIZE 256
7475
#define WEBGPU_NUM_PARAM_BUFS 32
76+
// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool
77+
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
7578
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
7679
#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32
7780
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
@@ -107,6 +110,11 @@ struct webgpu_pool_bufs {
107110
wgpu::Buffer dev_buf;
108111
};
109112

113+
// The futures to wait on for a single queue submission
114+
struct webgpu_submission_futures {
115+
std::vector<wgpu::FutureWaitInfo> futures;
116+
};
117+
110118
// Holds a pool of parameter buffers for WebGPU operations
111119
struct webgpu_buf_pool {
112120
std::vector<webgpu_pool_bufs> free;
@@ -243,6 +251,7 @@ struct webgpu_context_struct {
243251
uint32_t max_wg_size_x;
244252

245253
std::recursive_mutex mutex;
254+
std::atomic_int inflight_threads = 0;
246255

247256
webgpu_buf_pool param_buf_pool;
248257
webgpu_buf_pool set_rows_error_buf_pool;
@@ -365,12 +374,19 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
365374
/** WebGPU Actions */
366375

367376
// Wait for the queue to finish processing all submitted work
368-
static void ggml_backend_webgpu_wait(webgpu_context & ctx,
369-
std::vector<std::vector<wgpu::FutureWaitInfo>> & futures,
370-
uint64_t timeout_ms = UINT64_MAX) {
377+
static void ggml_backend_webgpu_wait(webgpu_context & ctx,
378+
std::vector<webgpu_submission_futures> & futures,
379+
uint64_t timeout_ms = UINT64_MAX) {
380+
// If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
381+
// inflight_max may be 0, meaning that we must wait on all futures.
382+
int inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / ctx->inflight_threads;
383+
while (futures.size() >= inflight_max && futures.size() > 0) {
384+
ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
385+
futures.erase(futures.begin());
386+
}
371387
size_t i = 0;
372388
while (i < futures.size()) {
373-
auto waitStatus = ctx->instance.WaitAny(futures[i].size(), futures[i].data(), timeout_ms);
389+
auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);
374390
switch (waitStatus) {
375391
case wgpu::WaitStatus::Success:
376392
futures.erase(futures.begin() + i);
@@ -424,8 +440,7 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
424440
}
425441
#endif
426442

427-
static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(webgpu_context ctx,
428-
std::vector<webgpu_command> commands) {
443+
static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, std::vector<webgpu_command> commands) {
429444
std::vector<wgpu::CommandBuffer> command_buffers;
430445
std::vector<webgpu_pool_bufs> params_bufs;
431446
std::vector<webgpu_pool_bufs> set_rows_error_bufs;
@@ -484,9 +499,9 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(webgpu_conte
484499
if (status != wgpu::MapAsyncStatus::Success) {
485500
GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
486501
} else {
487-
const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange();
502+
const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange();
488503
// WebGPU timestamps are in ns; convert to ms
489-
double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
504+
double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
490505
ctx->shader_gpu_time_ms[label] += elapsed_ms;
491506
// We can't unmap in here due to WebGPU reentrancy limitations.
492507
ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
@@ -495,7 +510,7 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(webgpu_conte
495510
futures.push_back({ f });
496511
}
497512
#endif
498-
return futures;
513+
return { futures };
499514
}
500515

501516
static webgpu_command ggml_backend_webgpu_build(webgpu_context & ctx,
@@ -588,7 +603,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
588603
uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;
589604

590605
webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipeline, params, entries, wg_x);
591-
std::vector<std::vector<wgpu::FutureWaitInfo>> futures = { ggml_backend_webgpu_submit(ctx, { command }) };
606+
std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command }) };
592607
ggml_backend_webgpu_wait(ctx, futures);
593608
}
594609

@@ -1255,25 +1270,31 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
12551270

12561271
WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
12571272

1258-
std::vector<webgpu_command> commands;
1259-
std::vector<std::vector<wgpu::FutureWaitInfo>> futures;
1273+
ctx->inflight_threads++;
1274+
1275+
std::vector<webgpu_command> commands;
1276+
std::vector<webgpu_submission_futures> futures;
12601277
for (int i = 0; i < cgraph->n_nodes; i++) {
12611278
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
12621279
commands.push_back(*cmd);
12631280
}
1264-
if (commands.size() >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
1265-
std::vector<wgpu::FutureWaitInfo> new_futures = ggml_backend_webgpu_submit(ctx, commands);
1266-
// check if previous futures have finished
1281+
// compute the batch size based on the number of inflight threads
1282+
int batch_size = std::min(std::max(1, WEBGPU_NUM_PARAM_BUFS / ctx->inflight_threads),
1283+
WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
1284+
if (commands.size() >= batch_size) {
1285+
futures.push_back(ggml_backend_webgpu_submit(ctx, commands));
1286+
// Process events and check for completed submissions
1287+
ctx->instance.ProcessEvents();
12671288
ggml_backend_webgpu_wait(ctx, futures, WEBGPU_WAIT_ANY_TIMEOUT_MS);
1268-
futures.push_back({ new_futures });
12691289
commands.clear();
12701290
}
12711291
}
12721292
if (!commands.empty()) {
1273-
std::vector<wgpu::FutureWaitInfo> new_futures = ggml_backend_webgpu_submit(ctx, commands);
1274-
futures.push_back({ new_futures });
1293+
webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands);
1294+
futures.push_back(new_futures);
12751295
}
12761296
ggml_backend_webgpu_wait(ctx, futures);
1297+
ctx->inflight_threads--;
12771298
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx);
12781299
return GGML_STATUS_SUCCESS;
12791300
}

0 commit comments

Comments (0)