@@ -2849,6 +2849,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
     cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
+    std::uint8_t batch_size_counter = 0;
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
@@ -2872,12 +2873,18 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
         }
 
         if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
-            // disable CUDA graphs for batch size > 1 for now.
-            // Changes in batch size or context size can cause changes to the grid size of some kernels.
-            use_cuda_graph = false;
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
-#endif
+            // disable CUDA graphs for batch size > 1 for now. The heuristic here allows the use of CUDA graphs
+            // for Gemma3n, which uses a single matrix-matrix addition as part of `project_per_layer_input`, while
+            // still detecting batched execution for all graphs with more than one such GGML_OP_ADD node. See also
+            // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773.
+            // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
+            ++batch_size_counter;
+            if (batch_size_counter > 1) {
+                use_cuda_graph = false;
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to repeated batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+            }
         }
 
         if (node->op == GGML_OP_MULTI_ADD && node->ne[1] > 1) {
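
For readers skimming the diff, here is a minimal standalone sketch of the counting heuristic in isolation. The `Op`, `Node`, and `graph_is_cuda_graph_compatible` names are hypothetical stand-ins chosen for illustration, not the real ggml types or API; only the counter logic mirrors the change above.

```cpp
#include <cstdint>
#include <vector>

enum class Op { Add, Other };

// Hypothetical stand-in for a graph node; src1_ne1 mirrors
// node->src[1]->ne[1] in the real code (1 means no batching,
// and a missing src[1] is treated as src1_ne1 == 1).
struct Node {
    Op           op       = Op::Other;
    std::int64_t src1_ne1 = 1;
};

// Returns true if CUDA graphs would stay enabled under the heuristic:
// a single batched Add (as Gemma3n emits in project_per_layer_input)
// is tolerated; a second one is taken as evidence of batched execution.
static bool graph_is_cuda_graph_compatible(const std::vector<Node> & nodes) {
    std::uint8_t batch_size_counter = 0;
    for (const Node & node : nodes) {
        if (node.op == Op::Add && node.src1_ne1 > 1) {
            if (++batch_size_counter > 1) {
                return false; // more than one batched Add: disable CUDA graphs
            }
        }
    }
    return true;
}

int main() {
    std::vector<Node> gemma3n_like = { {Op::Add, 4}, {Op::Other, 1} }; // stays compatible
    std::vector<Node> batched      = { {Op::Add, 4}, {Op::Add, 4} };   // graphs disabled
    return (graph_is_cuda_graph_compatible(gemma3n_like) &&
           !graph_is_cuda_graph_compatible(batched)) ? 0 : 1;
}
```

The counter is deliberately reset for every compatibility check (it is a local in the real function as well), so one lone batched addition per graph never disables capture, while any second occurrence does.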