@@ -369,9 +369,8 @@ struct vk_device_struct {
     bool subgroup_add;
     bool subgroup_shuffle;

-    bool atomic_float_add;
     bool add_rms_fusion;
-    uint32_t atomic_binding_alignment;
+    uint32_t partials_binding_alignment;

     bool integer_dot_product;

@@ -476,6 +475,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
     vk_pipeline pipeline_rms_norm_mul_f32;
+    vk_pipeline pipeline_rms_norm_partials_f32;
+    vk_pipeline pipeline_rms_norm_mul_partials_f32;
     vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_l2_norm_f32;

@@ -1170,13 +1171,13 @@ struct ggml_backend_vk_context {

     size_t semaphore_idx, event_idx;
     ggml_vk_garbage_collector gc;
-    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_atomic_add, prealloc_size_atomic_add_offset;
-    vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_atomic_add;
+    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset;
+    vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials;
     vk::Fence fence, almost_ready_fence;
     bool almost_ready_fence_pending {};
     // Set before op_add and unset after op_rms_norm to indicate that the add should
-    // use atomics to accumulate the square of the vector components
-    bool do_add_rms_atomic;
+    // write partial sums to accumulate the square of the vector components
+    bool do_add_rms_partials;

     vk_buffer buffer_pool[MAX_VK_BUFFERS];

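This is the core of the change: instead of the fused ADD shader atomically accumulating a single float per row, each workgroup now writes its own partial sum of squares into a preallocated slice of `prealloc_add_rms_partials`, and the following RMS_NORM reduces those partials. A CPU-side sketch of the two-pass scheme the shaders implement (the 128-lane workgroup width and the partials layout are illustrative assumptions, not code from the backend):

```cpp
#include <cstdint>
#include <vector>

static constexpr uint32_t WG_SIZE = 128; // assumed workgroup width, matching the 128-thread dispatch below

// Pass 1 (analogue of the fused ADD shader): compute the sum, and per
// "workgroup" accumulate the squares of the elements it produced.
static void add_and_write_partials(const float * a, const float * b, float * dst,
                                   uint32_t n, std::vector<float> & partials) {
    partials.assign((n + WG_SIZE - 1) / WG_SIZE, 0.0f);
    for (uint32_t i = 0; i < n; ++i) {
        dst[i] = a[i] + b[i];
        partials[i / WG_SIZE] += dst[i] * dst[i]; // no atomics: one owner per slot
    }
}

// Pass 2 (analogue of the rms_norm_partials shader): fold the partials
// instead of re-reading and re-squaring the whole row.
static float sum_of_squares(const std::vector<float> & partials) {
    float s = 0.0f;
    for (float p : partials) {
        s += p;
    }
    return s;
}
```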
@@ -2939,8 +2940,12 @@ static void ggml_vk_load_shaders(vk_device& device) {

     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

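Both new pipelines are built from the same `rms_norm_partials_f32` SPIR-V; only the specialization constants differ (`{0, 0}` vs `{0, 1}`), with the second constant selecting the fused-multiply variant, mirroring the existing `rms_norm_f32` pair. A minimal sketch of the Vulkan plumbing such a toggle compiles down to (illustrative only; the real work happens inside `ggml_vk_create_pipeline`):

```cpp
#include <vulkan/vulkan.hpp>
#include <cstdint>

int main() {
    // Two 32-bit spec constants, as in the {0, 1} (mul) variant above.
    const uint32_t spec_data[2] = {0, 1};
    const vk::SpecializationMapEntry entries[2] = {
        {/*constantID=*/0, /*offset=*/0,                sizeof(uint32_t)},
        {/*constantID=*/1, /*offset=*/sizeof(uint32_t), sizeof(uint32_t)},
    };
    // Passed via VkPipelineShaderStageCreateInfo::pSpecializationInfo when the
    // compute pipeline is created: one SPIR-V module, two specialized pipelines.
    vk::SpecializationInfo spec(2, entries, sizeof(spec_data), spec_data);
    (void)spec;
}
```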
@@ -3298,7 +3303,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->coopmat_support = false;
         device->integer_dot_product = false;
         bool bfloat16_support = false;
-        bool atomic_float_support = false;

         for (const auto& properties : ext_props) {
             if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -3338,8 +3342,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
                 !getenv("GGML_VK_DISABLE_BFLOAT16")) {
                 bfloat16_support = true;
 #endif
-            } else if (strcmp("VK_EXT_shader_atomic_float", properties.extensionName) == 0) {
-                atomic_float_support = true;
             }
         }

@@ -3556,14 +3558,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device_extensions.push_back("VK_KHR_shader_integer_dot_product");
         }

-        VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomic_float_features {};
-        atomic_float_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT;
-        if (atomic_float_support) {
-            last_struct->pNext = (VkBaseOutStructure *)&atomic_float_features;
-            last_struct = (VkBaseOutStructure *)&atomic_float_features;
-            device_extensions.push_back("VK_EXT_shader_atomic_float");
-        }
-
         vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);

         device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@@ -3575,7 +3569,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
 #endif

         device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
-        device->atomic_float_add = atomic_float_features.shaderBufferFloat32AtomicAdd;

         if (device->subgroup_size_control) {
             device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
@@ -3891,9 +3884,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;

         device->add_rms_fusion = !device->disable_fusion &&
-                                 device->subgroup_add &&
-                                 device->atomic_float_add;
-        device->atomic_binding_alignment =
+                                 device->subgroup_add;
+        device->partials_binding_alignment =
             std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);

         return device;
@@ -6927,7 +6919,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     switch (op) {
     case GGML_OP_ADD:
         {
-            if (ctx->do_add_rms_atomic) {
+            if (ctx->do_add_rms_partials) {
                 auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
                 return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
             } else {
@@ -7051,7 +7043,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_RMS_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
+            if (ctx->do_add_rms_partials) {
+                return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32;
+            } else {
+                return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
+            }
         }
         return nullptr;
     case GGML_OP_RMS_NORM_BACK:
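`GGML_OP_RMS_NORM` now picks from four pipelines along two independent axes: whether a multiply is fused in (`num_additional_fused_ops > 0`) and whether per-workgroup partials from a preceding fused ADD are available (`do_add_rms_partials`). Restated standalone, with a hypothetical enum in place of the real `vk_pipeline` handles:

```cpp
#include <cstdio>

enum rms_variant { RMS, RMS_MUL, RMS_PARTIALS, RMS_MUL_PARTIALS };

static rms_variant pick_rms_variant(bool fused_mul, bool use_partials) {
    if (use_partials) {
        return fused_mul ? RMS_MUL_PARTIALS : RMS_PARTIALS;
    }
    return fused_mul ? RMS_MUL : RMS;
}

int main() {
    // An add+rms_norm fusion followed by a mul lands on the mul+partials variant.
    printf("%d\n", pick_rms_variant(true, true)); // prints 3 (RMS_MUL_PARTIALS)
}
```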
@@ -7534,7 +7530,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             }
         } break;
     case GGML_OP_RMS_NORM:
-        if (ctx->do_add_rms_atomic) {
+        if (ctx->do_add_rms_partials) {
             // Run one element per thread, 128 threads per workgroup
             elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
         } else {
@@ -7688,8 +7684,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     }

     if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
-        vk_buffer d_A = ctx->prealloc_atomic_add ? ctx->prealloc_atomic_add : d_X;
-        size_t a_buf_offset = ctx->prealloc_atomic_add ? ctx->prealloc_size_atomic_add_offset : 0;
+        vk_buffer d_A = ctx->prealloc_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X;
+        size_t a_buf_offset = ctx->prealloc_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0;
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
             { vk_subbuffer{ d_X, x_buf_offset, x_sz },
@@ -7805,7 +7801,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        0.0f, 0.0f, ctx->do_add_rms_atomic,
+        0.0f, 0.0f, ctx->do_add_rms_partials,
     }, dryrun);
 }

@@ -8257,23 +8253,38 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
 }

+static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
+    const uint32_t ne = (uint32_t)node->ne[0];
+    const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0];
+    const uint32_t num_partials = CEIL_DIV(ne, denom);
+    return num_partials;
+}
+
+static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
+    const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node);
+    const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment);
+    return num_bytes;
+}
+
 static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);

+    uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
+
     ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        op_params[0], 0.0f, ctx->do_add_rms_atomic,
+        op_params[0], 0.0f, (int32_t)param3,
     }, dryrun);

-    if (ctx->do_add_rms_atomic) {
-        ctx->prealloc_size_atomic_add_offset += ctx->device->atomic_binding_alignment;
-        ctx->do_add_rms_atomic = false;
+    if (ctx->do_add_rms_partials) {
+        ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);
+        ctx->do_add_rms_partials = false;
     }
 }

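The two helpers above size one slice of the partials buffer: one 32-bit partial per ADD workgroup, rounded up to the storage-buffer offset alignment so each fused op can bind its slice at an aligned offset. A worked example of the arithmetic, assuming a 128-wide workgroup (`wg_denoms[0]`) and a 64-byte `minStorageBufferOffsetAlignment` (both illustrative values):

```cpp
#include <cstdint>
#include <cstdio>

#define CEIL_DIV(a, b)     (((a) + (b) - 1) / (b))
#define ROUNDUP_POW2(a, b) (((a) + (b) - 1) & ~((b) - 1)) // b must be a power of two

int main() {
    const uint32_t ne0   = 4096; // row length, e.g. a typical hidden dimension
    const uint32_t denom = 128;  // assumed workgroup width of the add_rms pipeline
    const uint32_t align = 64;   // assumed partials_binding_alignment

    const uint32_t num_partials = CEIL_DIV(ne0, denom);                  // 32 partials
    const uint32_t num_bytes    = ROUNDUP_POW2(num_partials * 4, align); // 128 bytes
    printf("%u partials -> %u bytes per fused add+rms_norm\n", num_partials, num_bytes);
}
```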
@@ -9552,13 +9563,13 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         }
         ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
     }
-    if (ctx->prealloc_atomic_add == nullptr || (ctx->prealloc_size_atomic_add > 0 && ctx->prealloc_atomic_add->size < ctx->prealloc_size_atomic_add)) {
-        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(atomic_add_size: " << ctx->prealloc_atomic_add << ")");
+    if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) {
+        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_add_rms_partials << ")");
         // Resize buffer
-        if (ctx->prealloc_atomic_add != nullptr) {
-            ggml_vk_destroy_buffer(ctx->prealloc_atomic_add);
+        if (ctx->prealloc_add_rms_partials != nullptr) {
+            ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
         }
-        ctx->prealloc_atomic_add = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_atomic_add);
+        ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials);
     }
 }

@@ -9622,9 +9633,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             ggml_nrows(cgraph->nodes[node_idx + 1]) == 1 &&
             ctx->device->add_rms_fusion) {
             if (dryrun) {
-                ctx->prealloc_size_atomic_add += ctx->device->atomic_binding_alignment;
+                ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
             }
-            ctx->do_add_rms_atomic = true;
+            ctx->do_add_rms_partials = true;
         }
         break;
     case GGML_OP_REPEAT:
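Sizing is two-phase: the dryrun pass grows `prealloc_size_add_rms_partials` by one slice per fused add+rms_norm, and the execution pass consumes the same slices in graph order by advancing `prealloc_size_add_rms_partials_offset` (see the `ggml_vk_rms_norm` hunk above). A minimal sketch of that reserve-then-consume pattern, with hypothetical names:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Bump-allocator analogue of the dryrun/execute accounting.
struct partials_plan {
    size_t total  = 0; // grown during the dryrun
    size_t offset = 0; // advanced during execution

    void reserve(size_t bytes) { total += bytes; }
    size_t consume(size_t bytes) {       // returns the binding offset for this op
        assert(offset + bytes <= total); // both passes must visit ops in the same order
        size_t off = offset;
        offset += bytes;
        return off;
    }
};

int main() {
    partials_plan plan;
    const std::vector<size_t> slices = {128, 128, 256}; // e.g. three fused ops
    for (size_t s : slices) plan.reserve(s);            // dryrun over the graph
    for (size_t s : slices) (void)plan.consume(s);      // real pass, same order
}
```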
@@ -9747,7 +9758,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
         ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
         if (node->op == GGML_OP_RMS_NORM) {
-            ctx->do_add_rms_atomic = false;
+            ctx->do_add_rms_partials = false;
         }
         return false;
     }
@@ -10663,9 +10674,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
     }

-    ctx->prealloc_size_atomic_add = 0;
-    ctx->prealloc_size_atomic_add_offset = 0;
-    ctx->do_add_rms_atomic = false;
+    ctx->prealloc_size_add_rms_partials = 0;
+    ctx->prealloc_size_add_rms_partials_offset = 0;
+    ctx->do_add_rms_partials = false;

     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -10727,16 +10738,16 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
     }

-    if (ctx->prealloc_size_atomic_add) {
+    if (ctx->prealloc_size_add_rms_partials) {
         if (ctx->compute_ctx.expired()) {
             compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
             ctx->compute_ctx = compute_ctx;
             ggml_vk_ctx_begin(ctx->device, compute_ctx);
         } else {
             compute_ctx = ctx->compute_ctx.lock();
         }
-        // initialize atomic sums to zero.
-        ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_atomic_add, 0, 0, ctx->prealloc_size_atomic_add);
+        // initialize partial sums to zero.
+        ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
     }

     // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
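`ggml_vk_buffer_memset_async` clears the whole partials buffer once per graph, before any fused ADD dispatch touches it. At the Vulkan level that amounts to a buffer fill recorded ahead of the compute work; a minimal Vulkan-Hpp sketch (assumes `cmd` is in the recording state and `buf` was created with transfer-dst usage):

```cpp
#include <vulkan/vulkan.hpp>

void zero_partials(vk::CommandBuffer cmd, vk::Buffer buf, vk::DeviceSize size) {
    // vkCmdFillBuffer repeats a 32-bit pattern; all-zero bits are also 0.0f,
    // so a single fill zero-initializes the float partials.
    cmd.fillBuffer(buf, /*dstOffset=*/0, size, /*data=*/0u);
}
```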