
Commit e592be1

vulkan: fix rms_norm+mul fusion (#14545)
The fused operation was grabbing the epsilon value from the wrong place. Add an env var to disable fusion. Add some missing checks for supported shapes/types. Handle fused rms_norm+mul in check_results.
1 parent a0374a6 · commit e592be1
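
The heart of the fix: when rms_norm and the following mul are fused, the tensor handed to the Vulkan op as dst is the mul node, so reading op_params off dst no longer yields the rms_norm epsilon. A minimal sketch of the distinction (illustrative names, not the backend code itself):

    // eps lives in the rms_norm node's op_params, not in the fused mul's.
    const ggml_tensor * rms_norm = cgraph->nodes[node_idx];     // GGML_OP_RMS_NORM
    const ggml_tensor * mul      = cgraph->nodes[node_idx + 1]; // GGML_OP_MUL, the fused dst
    const float eps_wrong = ((const float *) mul->op_params)[0];      // what the buggy path read
    const float eps       = ((const float *) rms_norm->op_params)[0]; // what the fix passes down

Accordingly, the diff below makes ggml_vk_rms_norm take op_params explicitly and passes node->op_params (the rms_norm node) from both the fused and non-fused call sites.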

File tree (2 files changed, +88 −24 lines):

  ggml/src/ggml-vulkan/ggml-vulkan.cpp
  tests/test-backend-ops.cpp

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 87 additions & 19 deletions
@@ -501,6 +501,8 @@ struct vk_device_struct {
 
     ggml_backend_buffer_type buffer_type;
 
+    bool disable_fusion;
+
 #ifdef GGML_VULKAN_MEMORY_DEBUG
     std::unique_ptr<vk_memory_logger> memory_logger;
 #endif
@@ -1091,8 +1093,8 @@ static size_t vk_skip_checks;
 static size_t vk_output_tensor;
 
 static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
-static void ggml_vk_check_results_0(ggml_tensor * tensor);
-static void ggml_vk_check_results_1(ggml_tensor * tensor);
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
 #endif
 
 typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
@@ -3507,6 +3509,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->idx = idx;
 
+        device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
+
         return device;
 }
 
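With disable_fusion wired into the device, fusion can be turned off at runtime for debugging or A/B comparison, e.g. (binary name illustrative; any program using the ggml Vulkan backend works):

    GGML_VK_DISABLE_FUSION=1 ./bin/test-backend-ops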
@@ -7654,8 +7658,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
 }
 
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    float * op_params = (float *)dst->op_params;
+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -8885,7 +8888,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
     }
 }
 
-static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
+static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
 
 // Returns true if node has enqueued work into the queue, false otherwise
 // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
@@ -9146,9 +9149,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             // fused rms_norm + mul
             ggml_tensor *mul = cgraph->nodes[node_idx + 1];
             ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
-            ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun);
         } else {
-            ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun);
         }
         break;
     case GGML_OP_RMS_NORM_BACK:
@@ -9308,7 +9311,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
     ctx->compute_ctx.reset();
 
-    bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
+    bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
     if (!ok) {
         if (node->op == GGML_OP_UNARY) {
             std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
@@ -9323,7 +9326,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     return true;
 }
 
-static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
+static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
+    GGML_UNUSED(cgraph);
     ggml_backend_buffer * buf = nullptr;
 
     switch (tensor->op) {
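Side note: cgraph is threaded through ggml_vk_compute_forward solely for the result-checking path; in builds without GGML_VULKAN_CHECK_RESULTS it goes unreferenced, hence the GGML_UNUSED(cgraph) marker.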
@@ -9433,7 +9437,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
         // Only run if ctx hasn't been submitted yet
         if (!subctx->seqs.empty()) {
 #ifdef GGML_VULKAN_CHECK_RESULTS
-            ggml_vk_check_results_0(tensor);
+            ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
             use_fence = true;
 #endif
 
@@ -9453,7 +9457,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
             ggml_vk_wait_for_fence(ctx);
         }
 #ifdef GGML_VULKAN_CHECK_RESULTS
-        ggml_vk_check_results_1(tensor);
+        ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
 #endif
     }
 
@@ -9900,6 +9904,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
     return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
 }
 
+static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+        // additional constraints specific to this fusion
+        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+
+        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+        // rms_norm only supports f32
+        if (mul->src[0]->type != GGML_TYPE_F32 ||
+            mul->src[1]->type != GGML_TYPE_F32 ||
+            mul->type != GGML_TYPE_F32) {
+            return false;
+        }
+        // if rms_norm is the B operand, then we don't handle broadcast
+        if (rms_norm == mul->src[1] &&
+            mul->src[0]->ne[1] != rms_norm->ne[1]) {
+            return false;
+        }
+        // rms_norm shader assumes contiguous rows
+        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+            return false;
+        }
+    }
+    return true;
+}
+
 static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
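
To make the broadcast restriction concrete, here is a hedged sketch using the standard ggml graph-building API of a case the new check declines to fuse: the rms_norm result is the B operand of mul and has fewer rows than the A operand, so mul has to broadcast it, which the fused shader does not handle:

    // assumes an initialized ggml_context * ctx
    ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 32); // 32 rows
    ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 1);  // 1 row
    ggml_tensor * n = ggml_rms_norm(ctx, b, 1e-6f); // normalized single row
    ggml_tensor * m = ggml_mul(ctx, a, n);          // n broadcasts across a's 32 rows
    // here rms_norm == mul->src[1] and mul->src[0]->ne[1] != rms_norm->ne[1],
    // so ggml_vk_can_fuse() returns false and the two ops run unfused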
@@ -9913,7 +9948,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
 
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
             ctx->num_additional_fused_ops = 1;
         }
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -9983,7 +10018,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }
 
-        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
            ctx->num_additional_fused_ops = 1;
        }
 
@@ -10760,11 +10795,21 @@ void * comp_result;
 size_t comp_size;
 size_t comp_nb[GGML_MAX_DIMS];
 size_t check_counter = 0;
-static void ggml_vk_check_results_0(ggml_tensor * tensor) {
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
+    ggml_tensor * tensor = cgraph->nodes[tensor_idx];
     if (tensor->op == GGML_OP_TRANSPOSE) {
         return;
     }
 
+    bool fused_rms_norm_mul = false;
+    int rms_norm_idx = -1;
+    if (ctx->num_additional_fused_ops == 1 &&
+        tensor->op == GGML_OP_RMS_NORM &&
+        cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
+        fused_rms_norm_mul = true;
+        tensor = cgraph->nodes[tensor_idx + 1];
+    }
+
     check_counter++;
     if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
         return;
@@ -10792,6 +10837,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
 
     for (int i = 0; i < 6; i++) {
         ggml_tensor * srci = tensor->src[i];
+        if (fused_rms_norm_mul) {
+            rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
+            ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
+            switch (i) {
+            case 0: srci = rms_norm->src[0]; break;
+            case 1: srci = tensor->src[1 - rms_norm_idx]; break;
+            default: continue;
+            }
+        }
         if (srci == nullptr) {
             continue;
         }
@@ -10849,7 +10903,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_SUB) {
         tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_MUL) {
-        tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
+        if (fused_rms_norm_mul) {
+            tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
+            tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
+        } else {
+            tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
+        }
     } else if (tensor->op == GGML_OP_DIV) {
         tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_CONCAT) {
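
The check_results path above builds the CPU reference for the fused case out of ggml_rms_norm followed by ggml_mul. For intuition, a self-contained scalar sketch of what that pair computes per row (assuming ggml's usual RMSNorm definition, y[i] = w[i] * x[i] / sqrt(mean(x^2) + eps); illustrative, not the shader code):

    #include <math.h>

    // CPU reference for one row of fused rms_norm+mul
    static void rms_norm_mul_row(const float * x, const float * w, float * y, int n, float eps) {
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            sum += (double) x[i] * x[i];              // accumulate sum of squares
        }
        const float scale = 1.0f / sqrtf((float) (sum / n) + eps);
        for (int i = 0; i < n; i++) {
            y[i] = scale * x[i] * w[i];               // normalize, then apply the mul operand
        }
    }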
@@ -11040,10 +11099,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         GGML_ABORT("fatal error");
     }
 
-    ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
-    ggml_build_forward_expand(cgraph, tensor_clone);
+    ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
+    ggml_build_forward_expand(cgraph_cpu, tensor_clone);
 
-    ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
+    ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
 
     if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
         ggml_vk_print_tensor(tensor_clone, "tensor_clone");
@@ -11066,10 +11125,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
 }
 
-static void ggml_vk_check_results_1(ggml_tensor * tensor) {
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
+    ggml_tensor * tensor = cgraph->nodes[tensor_idx];
     if (tensor->op == GGML_OP_TRANSPOSE) {
         return;
     }
+    bool fused_rms_norm_mul = false;
+    if (ctx->num_additional_fused_ops == 1 &&
+        tensor->op == GGML_OP_RMS_NORM &&
+        cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
+        fused_rms_norm_mul = true;
+        tensor = cgraph->nodes[tensor_idx + 1];
+    }
+
     if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
         return;
     }

tests/test-backend-ops.cpp

Lines changed: 1 addition & 5 deletions
@@ -2583,10 +2583,6 @@ struct test_rms_norm_mul : public test_case {
         }
     }
 
-    double max_nmse_err() override {
-        return 1e-6;
-    }
-
    float grad_eps() override {
        return 1.0f;
    }
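
Dropping the max_nmse_err() override means test_rms_norm_mul falls back to test_case's default NMSE threshold, so the fused Vulkan result must track the CPU reference tightly. For context, the comparison metric is a normalized mean squared error along these lines (hedged sketch; see test-backend-ops.cpp for the actual implementation):

    // NMSE of backend output a[] against reference b[]: squared error
    // normalized by the reference's energy
    static double nmse(const float * a, const float * b, size_t n) {
        double err = 0.0, ref = 0.0;
        for (size_t i = 0; i < n; i++) {
            err += ((double) a[i] - b[i]) * ((double) a[i] - b[i]);
            ref += (double) b[i] * b[i];
        }
        return err / ref;
    }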
@@ -5058,7 +5054,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
         test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
-    for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
+    for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
         test_cases.emplace_back(new test_rms_norm_mul(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
 