 
 #include <webgpu/webgpu_cpp.h>
 
+#include <atomic>
 #include <condition_variable>
 #include <cstring>
 #include <iostream>
@@ -65,13 +66,15 @@
 #    define WEBGPU_WAIT_ANY_TIMEOUT_MS        UINT64_MAX
 #else
 #    define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE  8
-#    define WEBGPU_WAIT_ANY_TIMEOUT_MS        1
+#    define WEBGPU_WAIT_ANY_TIMEOUT_MS        0
 #endif
 
 /* Constants */
 
 #define WEBGPU_MUL_MAT_WG_SIZE                256
 #define WEBGPU_NUM_PARAM_BUFS                 32
+// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool
+#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD   WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
 #define WEBGPU_PARAMS_BUF_SIZE_BYTES          128  // enough for 32 parameters
 #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS        32
 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES  4
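The new WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD constant ties the submission throttle to the size of the parameter-buffer pool. Below is a minimal standalone sketch of that arithmetic, assuming (as the ratio suggests) that each command in a submitted batch holds one buffer from the pool; the kXxx names only mirror the defines above and are not part of the backend.

```cpp
#include <cstdio>

// Values mirror the defines above.
constexpr int kNumParamBufs    = 32;                               // WEBGPU_NUM_PARAM_BUFS
constexpr int kSubmitBatchSize = 8;                                // WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
constexpr int kMaxInflightSubs = kNumParamBufs / kSubmitBatchSize; // WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD == 4

int main() {
    // For a single thread: 4 in-flight submissions of 8 commands each can hold
    // 32 parameter buffers, i.e. exactly the pool size.
    std::printf("%d submissions x %d commands = %d buffers (pool size %d)\n",
                kMaxInflightSubs, kSubmitBatchSize,
                kMaxInflightSubs * kSubmitBatchSize, kNumParamBufs);
    return 0;
}
```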
@@ -107,6 +110,11 @@ struct webgpu_pool_bufs {
     wgpu::Buffer dev_buf;
 };
 
+// The futures to wait on for a single queue submission
+struct webgpu_submission_futures {
+    std::vector<wgpu::FutureWaitInfo> futures;
+};
+
 // Holds a pool of parameter buffers for WebGPU operations
 struct webgpu_buf_pool {
     std::vector<webgpu_pool_bufs> free;
@@ -243,6 +251,7 @@ struct webgpu_context_struct {
     uint32_t max_wg_size_x;
 
     std::recursive_mutex mutex;
+    std::atomic_int      inflight_threads = 0;
 
     webgpu_buf_pool param_buf_pool;
     webgpu_buf_pool set_rows_error_buf_pool;
@@ -365,12 +374,19 @@ static void ggml_webgpu_create_buffer(wgpu::Device &    device,
 /** WebGPU Actions */
 
 // Wait for the queue to finish processing all submitted work
-static void ggml_backend_webgpu_wait(webgpu_context &                                 ctx,
-                                     std::vector<std::vector<wgpu::FutureWaitInfo>> & futures,
-                                     uint64_t                                         timeout_ms = UINT64_MAX) {
+static void ggml_backend_webgpu_wait(webgpu_context &                         ctx,
+                                     std::vector<webgpu_submission_futures> & futures,
+                                     uint64_t                                 timeout_ms = UINT64_MAX) {
+    // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
+    // inflight_max may be 0, meaning that we must wait on all futures.
+    int inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / ctx->inflight_threads;
+    while (futures.size() >= inflight_max && futures.size() > 0) {
+        ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
+        futures.erase(futures.begin());
+    }
     size_t i = 0;
     while (i < futures.size()) {
-        auto waitStatus = ctx->instance.WaitAny(futures[i].size(), futures[i].data(), timeout_ms);
+        auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);
         switch (waitStatus) {
             case wgpu::WaitStatus::Success:
                 futures.erase(futures.begin() + i);
@@ -424,8 +440,7 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
 }
 #endif
 
-static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(webgpu_context              ctx,
-                                                                    std::vector<webgpu_command> commands) {
+static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, std::vector<webgpu_command> commands) {
     std::vector<wgpu::CommandBuffer> command_buffers;
     std::vector<webgpu_pool_bufs>    params_bufs;
     std::vector<webgpu_pool_bufs>    set_rows_error_bufs;
@@ -484,9 +499,9 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(webgpu_conte
                 if (status != wgpu::MapAsyncStatus::Success) {
                     GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
                 } else {
-                    const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange();
+                    const uint64_t * ts_data    = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange();
                     // WebGPU timestamps are in ns; convert to ms
-                    double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
+                    double           elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
                     ctx->shader_gpu_time_ms[label] += elapsed_ms;
                     // We can't unmap in here due to WebGPU reentrancy limitations.
                     ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
@@ -495,7 +510,7 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(webgpu_conte
         futures.push_back({ f });
     }
 #endif
-    return futures;
+    return { futures };
 }
 
 static webgpu_command ggml_backend_webgpu_build(webgpu_context &                  ctx,
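The `return { futures };` above relies on webgpu_submission_futures being an aggregate with a single vector member, so the braced list initializes that member from the local vector. A tiny self-contained illustration of the same pattern (int stands in for wgpu::FutureWaitInfo so the sketch builds without Dawn):

```cpp
#include <vector>

struct submission_futures {       // stand-in for webgpu_submission_futures
    std::vector<int> futures;     // stand-in for std::vector<wgpu::FutureWaitInfo>
};

static submission_futures make_submission() {
    std::vector<int> futures = { 1, 2, 3 };
    return { futures };           // aggregate-initializes the wrapper from the local vector
}

int main() {
    return make_submission().futures.size() == 3 ? 0 : 1;
}
```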
@@ -588,7 +603,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
     uint32_t wg_x         = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;
 
     webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipeline, params, entries, wg_x);
-    std::vector<std::vector<wgpu::FutureWaitInfo>> futures = { ggml_backend_webgpu_submit(ctx, { command }) };
+    std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command }) };
     ggml_backend_webgpu_wait(ctx, futures);
 }
 
@@ -1255,25 +1270,31 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
 
     WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
 
-    std::vector<webgpu_command>                    commands;
-    std::vector<std::vector<wgpu::FutureWaitInfo>> futures;
+    ctx->inflight_threads++;
+
+    std::vector<webgpu_command>            commands;
+    std::vector<webgpu_submission_futures> futures;
     for (int i = 0; i < cgraph->n_nodes; i++) {
         if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
             commands.push_back(*cmd);
         }
-        if (commands.size() >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
-            std::vector<wgpu::FutureWaitInfo> new_futures = ggml_backend_webgpu_submit(ctx, commands);
-            // check if previous futures have finished
+        // compute the batch size based on the number of inflight threads
+        int batch_size = std::min(std::max(1, WEBGPU_NUM_PARAM_BUFS / ctx->inflight_threads),
+                                  WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
+        if (commands.size() >= batch_size) {
+            futures.push_back(ggml_backend_webgpu_submit(ctx, commands));
+            // Process events and check for completed submissions
+            ctx->instance.ProcessEvents();
             ggml_backend_webgpu_wait(ctx, futures, WEBGPU_WAIT_ANY_TIMEOUT_MS);
-            futures.push_back({ new_futures });
             commands.clear();
         }
     }
     if (!commands.empty()) {
-        std::vector<wgpu::FutureWaitInfo> new_futures = ggml_backend_webgpu_submit(ctx, commands);
-        futures.push_back({ new_futures });
+        webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands);
+        futures.push_back(new_futures);
     }
     ggml_backend_webgpu_wait(ctx, futures);
+    ctx->inflight_threads--;
     WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx);
     return GGML_STATUS_SUCCESS;
 }
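Taken together, the batch-size computation in graph_compute and the inflight_max check in ggml_backend_webgpu_wait shrink each thread's budget as more threads submit work, keeping the total demand on the 32-entry parameter-buffer pool roughly constant. Below is a standalone sketch of that arithmetic using the constant values from this file; the helper names and the rough buffer estimate are illustrative only, not code from the backend.

```cpp
#include <algorithm>
#include <cstdio>

constexpr int kNumParamBufs         = 32;                               // WEBGPU_NUM_PARAM_BUFS
constexpr int kSubmitBatchSize      = 8;                                // WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
constexpr int kMaxInflightPerThread = kNumParamBufs / kSubmitBatchSize; // WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD

int main() {
    const int thread_counts[] = { 1, 2, 4, 8 };
    for (int threads : thread_counts) {
        // graph_compute: commands batched per submission shrink as more threads are in flight
        int batch_size   = std::min(std::max(1, kNumParamBufs / threads), kSubmitBatchSize);
        // ggml_backend_webgpu_wait: submissions a thread may keep in flight;
        // 0 means the wait loop drains every outstanding submission before continuing
        int inflight_max = kMaxInflightPerThread / threads;
        // rough bound on parameter buffers held across all threads at once
        int approx_bufs  = threads * batch_size * std::max(inflight_max, 1);
        std::printf("threads=%d  batch_size=%d  inflight_max=%d  ~param bufs=%d\n",
                    threads, batch_size, inflight_max, approx_bufs);
    }
    return 0;
}
```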