
Commit f932bbe

ORippler authored and rick-github committed
Increase performance for Gemma3n models on NVGPUs by enabling CUDA Graph execution (ollama#11525)
* Enable CUDA Graphs for gemma3n. Similar to ggml-org/llama.cpp#14741, though ollama has a slightly different model graph than llama.cpp, which requires different workaround checks.

* Remove residual check by reshaping differently in gemma3n model. This should make the heuristics more robust.
1 parent cc23043 commit f932bbe
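For background (not part of this commit): CUDA Graphs let the runtime record a fixed sequence of kernel launches once and replay the whole sequence with a single launch, cutting per-kernel launch overhead. The snippet below is a minimal, self-contained capture-and-replay sketch against the public CUDA runtime API; the scale kernel, buffer size, and loop count are invented for illustration, error checking is omitted, and cudaGraphInstantiate is used in its CUDA 12 three-argument form. It is not ollama's code, only the pattern the commit enables for gemma3n.

#include <cstdio>
#include <cuda_runtime.h>

// Toy kernel standing in for one node of the model graph.
__global__ void scale(float *x, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= s;
}

int main() {
    const int n = 1 << 20;
    float *d = nullptr;
    cudaMalloc(&d, n * sizeof(float));

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Capture a fixed sequence of launches into a graph instead of
    // submitting them one by one on every decode step.
    cudaGraph_t graph;
    cudaGraphExec_t graph_exec;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    for (int step = 0; step < 4; ++step) {
        scale<<<(n + 255) / 256, 256, 0, stream>>>(d, 1.0f, n);
    }
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&graph_exec, graph, 0);   // CUDA 12-style signature

    // Replay: one launch submits the whole recorded sequence.
    cudaGraphLaunch(graph_exec, stream);
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(graph_exec);
    cudaGraphDestroy(graph);
    cudaStreamDestroy(stream);
    cudaFree(d);
    printf("replayed captured graph\n");
    return 0;
}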

File tree

5 files changed, +67 −10 lines changed


llama/patches/0019-metal-add-mean-kernel-14267.patch

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ ggml-ci
  2 files changed, 67 insertions(+), 14 deletions(-)
 
 diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
-index ee4f2dcb..f20f5615 100644
+index a9eeebc6..110c9ece 100644
 --- a/ggml/src/ggml-metal/ggml-metal.m
 +++ b/ggml/src/ggml-metal/ggml-metal.m
 @@ -489,6 +489,7 @@ enum ggml_metal_kernel_type {

llama/patches/0020-CUDA-add-mean-operation-14313.patch

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ index 64fb4ff4..5b9a0fe3 100644
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index 4c829153..9e64e5ae 100644
+index d6960174..2b9fabf4 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
 @@ -35,6 +35,7 @@
Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Oliver Simons <osimons@nvidia.com>
Date: Tue, 22 Jul 2025 11:02:28 +0200
Subject: [PATCH] Enable CUDA Graphs for gemma3n.

Similar to
https://github.com/ggml-org/llama.cpp/pull/14741,
though ollama has a slightly different model graph
than llama.cpp which requires different workaround
checks.
---
 ggml/src/ggml-cuda/ggml-cuda.cu | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 2b9fabf4..28ccf4be 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2474,6 +2474,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
     cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
 
+    const std::string gemma3n_per_layer_proj_src1_name = " (reshaped)";
+    const std::string gemma3n_node_name = "node_";
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -2495,12 +2498,17 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #endif
         }
 
-        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
-            // disable CUDA graphs for batch size > 1 for now.
-            // Changes in batch size or context size can cause changes to the grid size of some kernels.
+        // workarounds to exclude Gemma3n's `project_per_layer_input` operation from the batch-size heuristic, specific to ollama's implementation of gemma3n
+        // number of layers is different for per_layer_proj between gemma3n:2b and gemma3n:4b, which is why we don't check that value here
+        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && !(node->ne[0] == 256
+            && node->ne[2] == 1
+            && node->ne[3] == 1
+            && node->src[0] ? std::string(node->src[0]->name).find(gemma3n_node_name) != std::string::npos : false
+            && node->src[1] ? node->src[1]->name == gemma3n_per_layer_proj_src1_name : false)) {
+            // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
             use_cuda_graph = false;
 #ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+            GGML_LOG_INFO("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
 #endif
         }
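The exemption in this patch keys off tensor shapes and names rather than graph structure: the ADD's output must be 256 wide with singleton dims 2 and 3, src[0]'s name must contain "node_", and src[1]'s name must be exactly " (reshaped)". Below is a hypothetical, stand-alone restatement of that matching idea using a plain struct instead of ggml_tensor; FakeNode, looks_like_gemma3n_per_layer_proj, and the sample values are invented for the example, and the sub-conditions are grouped explicitly with added null checks. It is a sketch of the idea, not ollama's code.

#include <cstdint>
#include <iostream>
#include <string>

// Hypothetical stand-in for the few ggml_tensor fields the check looks at.
struct FakeNode {
    std::string name;   // e.g. "node_42" or " (reshaped)"
    int64_t     ne[4];  // tensor dimensions
};

// True when a node looks like Gemma3n's per-layer-projection ADD and should
// therefore be exempted from the "batch size > 1" rule.
static bool looks_like_gemma3n_per_layer_proj(const FakeNode &dst,
                                              const FakeNode *src0,
                                              const FakeNode *src1) {
    const std::string per_layer_proj_src1_name = " (reshaped)";
    const std::string node_name_substr         = "node_";

    return dst.ne[0] == 256 && dst.ne[2] == 1 && dst.ne[3] == 1
        && src0 && src0->name.find(node_name_substr) != std::string::npos
        && src1 && src1->name == per_layer_proj_src1_name;
}

int main() {
    // Invented example values, shaped like the case the patch describes.
    FakeNode dst  = {"node_12",      {256, 30, 1, 1}};
    FakeNode src0 = {"node_11",      {256, 30, 1, 1}};
    FakeNode src1 = {" (reshaped)",  {256, 30, 1, 1}};
    FakeNode other = {"other_op",    {512, 30, 1, 1}};

    std::cout << std::boolalpha
              << looks_like_gemma3n_per_layer_proj(dst, &src0, &src1) << "\n"   // true
              << looks_like_gemma3n_per_layer_proj(other, &src0, &src1) << "\n"; // false
    return 0;
}

Any C++11 compiler will build this; the first call matches the sample node, the second fails on the shape check.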

ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 12 additions & 4 deletions

@@ -2474,6 +2474,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
     cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
 
+    const std::string gemma3n_per_layer_proj_src1_name = " (reshaped)";
+    const std::string gemma3n_node_name = "node_";
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -2495,12 +2498,17 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #endif
         }
 
-        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
-            // disable CUDA graphs for batch size > 1 for now.
-            // Changes in batch size or context size can cause changes to the grid size of some kernels.
+        // workarounds to exclude Gemma3n's `project_per_layer_input` operation from the batch-size heuristic, specific to ollama's implementation of gemma3n
+        // number of layers is different for per_layer_proj between gemma3n:2b and gemma3n:4b, which is why we don't check that value here
+        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && !(node->ne[0] == 256
+            && node->ne[2] == 1
+            && node->ne[3] == 1
+            && node->src[0] ? std::string(node->src[0]->name).find(gemma3n_node_name) != std::string::npos : false
+            && node->src[1] ? node->src[1]->name == gemma3n_per_layer_proj_src1_name : false)) {
+            // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
             use_cuda_graph = false;
 #ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+            GGML_LOG_INFO("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
 #endif
         }

model/models/gemma3n/model_text.go

Lines changed: 3 additions & 4 deletions

@@ -203,10 +203,9 @@ func (a AltUp) Predict(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions
 	coefficients := a.PredictionCoefficient.Forward(ctx, modalities)
 	coefficients = coefficients.Reshape(ctx, opts.altupInputs, opts.altupInputs, coefficients.Dim(1), coefficients.Dim(2))
 
-	hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-	predictions := coefficients.Mulmat(ctx, hiddenStates)
-	predictions = predictions.Add(ctx, hiddenStates)
-	return predictions.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
+	predictions := coefficients.Mulmat(ctx, hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx))
+	predictions = predictions.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
+	return predictions.Add(ctx, hiddenStates)
 }
 
 func (a AltUp) Correct(ctx ml.Context, predictions, activated, one ml.Tensor, opts *TextOptions) ml.Tensor {
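The reordering in Predict is an algebraic no-op, assuming Permute here follows ggml's convention (axis i of the input moves to the position given by the i-th argument), so that Permute(ctx, 2, 0, 1, 3) undoes Permute(ctx, 1, 2, 0, 3), and that Contiguous only changes memory layout. Writing that permutation as P, the coefficients as C, and hiddenStates as H, a sketch of the equivalence between the old and new return values:

\[
P^{-1}\bigl(C\,P(H) + P(H)\bigr) \;=\; P^{-1}\bigl(C\,P(H)\bigr) + P^{-1}\bigl(P(H)\bigr) \;=\; P^{-1}\bigl(C\,P(H)\bigr) + H .
\]

The left side is the previous code (permute, Mulmat, Add, permute back); the right side is the new code, where the residual Add's second operand is the original hiddenStates, the "reshaping differently" that, per the commit message, removes the need for a residual-specific check in the CUDA-graph heuristic.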
