@@ -501,6 +501,8 @@ struct vk_device_struct {
 
     ggml_backend_buffer_type buffer_type;
 
+    bool disable_fusion;
+
 #ifdef GGML_VULKAN_MEMORY_DEBUG
     std::unique_ptr<vk_memory_logger> memory_logger;
 #endif
@@ -1091,8 +1093,8 @@ static size_t vk_skip_checks;
 static size_t vk_output_tensor;
 
 static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
-static void ggml_vk_check_results_0(ggml_tensor * tensor);
-static void ggml_vk_check_results_1(ggml_tensor * tensor);
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
 #endif
 
 typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
@@ -3507,6 +3509,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->idx = idx;
 
+        device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
+
         return device;
     }
 
@@ -7654,8 +7658,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
 }
 
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    float * op_params = (float *)dst->op_params;
+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -8885,7 +8888,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
     }
 }
 
-static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
+static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
 
 // Returns true if node has enqueued work into the queue, false otherwise
 // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
@@ -9146,9 +9149,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             // fused rms_norm + mul
             ggml_tensor *mul = cgraph->nodes[node_idx + 1];
             ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
-            ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun);
         } else {
-            ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun);
         }
         break;
     case GGML_OP_RMS_NORM_BACK:
@@ -9308,7 +9311,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         ctx->compute_ctx.reset();
 
-        bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
+        bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
         if (!ok) {
             if (node->op == GGML_OP_UNARY) {
                 std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
@@ -9323,7 +9326,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     return true;
 }
 
-static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
+static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
+    GGML_UNUSED(cgraph);
     ggml_backend_buffer * buf = nullptr;
 
     switch (tensor->op) {
@@ -9433,7 +9437,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     // Only run if ctx hasn't been submitted yet
     if (!subctx->seqs.empty()) {
 #ifdef GGML_VULKAN_CHECK_RESULTS
-        ggml_vk_check_results_0(tensor);
+        ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
         use_fence = true;
 #endif
 
@@ -9453,7 +9457,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
             ggml_vk_wait_for_fence(ctx);
         }
 #ifdef GGML_VULKAN_CHECK_RESULTS
-        ggml_vk_check_results_1(tensor);
+        ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
 #endif
     }
 
@@ -9900,6 +9904,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
     return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
 }
 
+static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+        // additional constraints specific to this fusion
+        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+
+        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+        // rms_norm only supports f32
+        if (mul->src[0]->type != GGML_TYPE_F32 ||
+            mul->src[1]->type != GGML_TYPE_F32 ||
+            mul->type != GGML_TYPE_F32) {
+            return false;
+        }
+        // if rms_norm is the B operand, then we don't handle broadcast
+        if (rms_norm == mul->src[1] &&
+            mul->src[0]->ne[1] != rms_norm->ne[1]) {
+            return false;
+        }
+        // rms_norm shader assumes contiguous rows
+        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+            return false;
+        }
+    }
+    return true;
+}
+
 static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -9913,7 +9948,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
 
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
            ctx->num_additional_fused_ops = 1;
         }
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -9983,7 +10018,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }
 
-        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
            ctx->num_additional_fused_ops = 1;
         }
 
@@ -10760,11 +10795,21 @@ void * comp_result;
 size_t comp_size;
 size_t comp_nb[GGML_MAX_DIMS];
 size_t check_counter = 0;
-static void ggml_vk_check_results_0(ggml_tensor * tensor) {
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
+    ggml_tensor * tensor = cgraph->nodes[tensor_idx];
     if (tensor->op == GGML_OP_TRANSPOSE) {
         return;
     }
 
+    bool fused_rms_norm_mul = false;
+    int rms_norm_idx = -1;
+    if (ctx->num_additional_fused_ops == 1 &&
+        tensor->op == GGML_OP_RMS_NORM &&
+        cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
+        fused_rms_norm_mul = true;
+        tensor = cgraph->nodes[tensor_idx + 1];
+    }
+
     check_counter++;
     if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
         return;
@@ -10792,6 +10837,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
 
     for (int i = 0; i < 6; i++) {
         ggml_tensor * srci = tensor->src[i];
+        if (fused_rms_norm_mul) {
+            rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
+            ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
+            switch (i) {
+            case 0: srci = rms_norm->src[0]; break;
+            case 1: srci = tensor->src[1 - rms_norm_idx]; break;
+            default: continue;
+            }
+        }
         if (srci == nullptr) {
             continue;
         }
@@ -10849,7 +10903,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_SUB) {
         tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_MUL) {
-        tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
+        if (fused_rms_norm_mul) {
+            tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
+            tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
+        } else {
+            tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
+        }
     } else if (tensor->op == GGML_OP_DIV) {
         tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_CONCAT) {
@@ -11040,10 +11099,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         GGML_ABORT("fatal error");
     }
 
-    ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
-    ggml_build_forward_expand(cgraph, tensor_clone);
+    ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
+    ggml_build_forward_expand(cgraph_cpu, tensor_clone);
 
-    ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
+    ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
 
     if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
         ggml_vk_print_tensor(tensor_clone, "tensor_clone");
@@ -11066,10 +11125,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
 }
 
-static void ggml_vk_check_results_1(ggml_tensor * tensor) {
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
+    ggml_tensor * tensor = cgraph->nodes[tensor_idx];
     if (tensor->op == GGML_OP_TRANSPOSE) {
         return;
     }
+    bool fused_rms_norm_mul = false;
+    if (ctx->num_additional_fused_ops == 1 &&
+        tensor->op == GGML_OP_RMS_NORM &&
+        cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
+        fused_rms_norm_mul = true;
+        tensor = cgraph->nodes[tensor_idx + 1];
+    }
+
     if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
         return;
     }