6565#    define  WEBGPU_WAIT_ANY_TIMEOUT_MS        UINT64_MAX
6666#else 
6767#    define  WEBGPU_COMMAND_SUBMIT_BATCH_SIZE  8 
68- #    define  WEBGPU_WAIT_ANY_TIMEOUT_MS        1 
68+ #    define  WEBGPU_WAIT_ANY_TIMEOUT_MS        0 
6969#endif 
7070
7171/*  Constants */ 
7272
7373#define  WEBGPU_MUL_MAT_WG_SIZE                256 
7474#define  WEBGPU_NUM_PARAM_BUFS                 32 
75+ //  Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool
76+ #define  WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD   WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
7577#define  WEBGPU_PARAMS_BUF_SIZE_BYTES          128   //  enough for 32 parameters
7678#define  WEBGPU_NUM_SET_ROWS_ERROR_BUFS        32 
7779#define  WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES  4 
@@ -107,6 +109,11 @@ struct webgpu_pool_bufs {
107109    wgpu::Buffer dev_buf;
108110};
109111
112+ //  The futures to wait on for a single queue submission
113+ struct  webgpu_submission_futures  {
114+     std::vector<wgpu::FutureWaitInfo> futures;
115+ };
116+ 
110117//  Holds a pool of parameter buffers for WebGPU operations
111118struct  webgpu_buf_pool  {
112119    std::vector<webgpu_pool_bufs> free;
@@ -243,6 +250,7 @@ struct webgpu_context_struct {
243250    uint32_t  max_wg_size_x;
244251
245252    std::recursive_mutex mutex;
253+     std::atomic_int      inflight_threads = 0 ;
246254
247255    webgpu_buf_pool param_buf_pool;
248256    webgpu_buf_pool set_rows_error_buf_pool;
@@ -365,12 +373,19 @@ static void ggml_webgpu_create_buffer(wgpu::Device &    device,
365373/* * WebGPU Actions */ 
366374
367375//  Wait for the queue to finish processing all submitted work
368- static  void  ggml_backend_webgpu_wait (webgpu_context &                                 ctx,
369-                                      std::vector<std::vector<wgpu::FutureWaitInfo>> & futures,
370-                                      uint64_t                                          timeout_ms = UINT64_MAX) {
376+ static  void  ggml_backend_webgpu_wait (webgpu_context &                         ctx,
377+                                      std::vector<webgpu_submission_futures> & futures,
378+                                      uint64_t                                  timeout_ms = UINT64_MAX) {
379+     //  If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
380+     //  inflight_max may be 0, meaning that we must wait on all futures.
381+     int  inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / ctx->inflight_threads ;
382+     while  (futures.size () >= inflight_max && futures.size () > 0 ) {
383+         ctx->instance .WaitAny (futures[0 ].futures .size (), futures[0 ].futures .data (), UINT64_MAX);
384+         futures.erase (futures.begin ());
385+     }
371386    size_t  i = 0 ;
372387    while  (i < futures.size ()) {
373-         auto  waitStatus = ctx->instance .WaitAny (futures[i].size (), futures[i].data (), timeout_ms);
388+         auto  waitStatus = ctx->instance .WaitAny (futures[i].futures .size (), futures[i].futures .data (), timeout_ms);
374389        switch  (waitStatus) {
375390            case  wgpu::WaitStatus::Success:
376391                futures.erase (futures.begin () + i);
@@ -424,8 +439,7 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
424439}
425440#endif 
426441
427- static  std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit (webgpu_context              ctx,
428-                                                                     std::vector<webgpu_command> commands) {
442+ static  webgpu_submission_futures ggml_backend_webgpu_submit (webgpu_context ctx, std::vector<webgpu_command> commands) {
429443    std::vector<wgpu::CommandBuffer> command_buffers;
430444    std::vector<webgpu_pool_bufs>    params_bufs;
431445    std::vector<webgpu_pool_bufs>    set_rows_error_bufs;
@@ -484,9 +498,9 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(webgpu_conte
484498                if  (status != wgpu::MapAsyncStatus::Success) {
485499                    GGML_LOG_ERROR (" ggml_webgpu: Failed to map timestamp buffer: %s\n ", std::string (message).c_str ());
486500                } else  {
487-                     const  uint64_t  * ts_data = (const  uint64_t  *) ts_bufs.host_buf .GetConstMappedRange ();
501+                     const  uint64_t  * ts_data     = (const  uint64_t  *) ts_bufs.host_buf .GetConstMappedRange ();
488502                    //  WebGPU timestamps are in ns; convert to ms
489-                     double  elapsed_ms = double (ts_data[1 ] - ts_data[0 ]) * 1e-6 ;
503+                     double             elapsed_ms = double (ts_data[1 ] - ts_data[0 ]) * 1e-6 ;
490504                    ctx->shader_gpu_time_ms [label] += elapsed_ms;
491505                    //  We can't unmap in here due to WebGPU reentrancy limitations.
492506                    ctx->timestamp_query_buf_pool .free_bufs ({ ts_bufs });
@@ -495,7 +509,7 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(webgpu_conte
495509        futures.push_back ({ f });
496510    }
497511#endif 
498-     return  futures;
512+     return  {  futures } ;
499513}
500514
501515static  webgpu_command ggml_backend_webgpu_build (webgpu_context &                  ctx,
@@ -588,7 +602,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
588602    uint32_t  wg_x         = ((size + 3 ) + bytes_per_wg - 1 ) / bytes_per_wg;
589603
590604    webgpu_command command = ggml_backend_webgpu_build (ctx, ctx->memset_pipeline , params, entries, wg_x);
591-     std::vector<std::vector<wgpu::FutureWaitInfo> > futures = { ggml_backend_webgpu_submit (ctx, { command }) };
605+     std::vector<webgpu_submission_futures > futures = { ggml_backend_webgpu_submit (ctx, { command }) };
592606    ggml_backend_webgpu_wait (ctx, futures);
593607}
594608
@@ -1255,25 +1269,31 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
12551269
12561270    WEBGPU_CPU_PROFILE_TOTAL_START (graph_compute);
12571271
1258-     std::vector<webgpu_command>                    commands;
1259-     std::vector<std::vector<wgpu::FutureWaitInfo>> futures;
1272+     ctx->inflight_threads ++;
1273+ 
1274+     std::vector<webgpu_command>            commands;
1275+     std::vector<webgpu_submission_futures> futures;
12601276    for  (int  i = 0 ; i < cgraph->n_nodes ; i++) {
12611277        if  (auto  cmd = ggml_webgpu_encode_node (ctx, cgraph->nodes [i])) {
12621278            commands.push_back (*cmd);
12631279        }
1264-         if  (commands.size () >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
1265-             std::vector<wgpu::FutureWaitInfo> new_futures = ggml_backend_webgpu_submit (ctx, commands);
1266-             //  check if previous futures have finished
1280+         //  compute the batch size based on the number of inflight threads
1281+         int  batch_size = std::min (std::max (1 , WEBGPU_NUM_PARAM_BUFS / ctx->inflight_threads ),
1282+                                   WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
1283+         if  (commands.size () >= batch_size) {
1284+             futures.push_back (ggml_backend_webgpu_submit (ctx, commands));
1285+             //  Process events and check for completed submissions
1286+             ctx->instance .ProcessEvents ();
12671287            ggml_backend_webgpu_wait (ctx, futures, WEBGPU_WAIT_ANY_TIMEOUT_MS);
1268-             futures.push_back ({ new_futures });
12691288            commands.clear ();
12701289        }
12711290    }
12721291    if  (!commands.empty ()) {
1273-         std::vector<wgpu::FutureWaitInfo>  new_futures = ggml_backend_webgpu_submit (ctx, commands);
1274-         futures.push_back ({  new_futures } );
1292+         webgpu_submission_futures  new_futures = ggml_backend_webgpu_submit (ctx, commands);
1293+         futures.push_back (new_futures);
12751294    }
12761295    ggml_backend_webgpu_wait (ctx, futures);
1296+     ctx->inflight_threads --;
12771297    WEBGPU_CPU_PROFILE_TOTAL_END (graph_compute, ctx);
12781298    return  GGML_STATUS_SUCCESS;
12791299}
0 commit comments