From 18703ad600cc68dbdb04d57434c876989a841d12 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Sun, 4 Aug 2024 17:28:08 +0200 Subject: [PATCH] vulkan : implement Stable Diffusion operators (#904) * Fix Vulkan repeat op * Implement Vulkan concat op * Delete old Vulkan shader generator * Implement Vulkan im2col op * Implement Vulkan unary gelu_quick op * Implement Vulkan group_norm op * Implement Vulkan timestep_embedding op * Implement Vulkan upscale op * Fix Vulkan vk_context tensor extra index issue * Fix Vulkan matmul shader parameter bug * Properly fix Vulkan matmul shader parameter bug * Add Vulkan ADD f16 + f32 -> f16 operator support * Implement Vulkan tanh op * Fix Vulkan group count too large Validation error on non-Nvidia GPUs * Throw error when too much memory is requested * Fix another Vulkan group count too large Validation error on non-Nvidia GPUs * Fix matmul MMQ condition * Implement Vulkan pad op * Fix Vulkan crash when tensor is used multiple times in a compute graph * Add Vulkan CONCAT f16 + f16 -> f16 op * Add Vulkan LEAKY_RELU op --- src/ggml-vulkan.cpp | 840 +++-- src/ggml_vk_generate_shaders.py | 3149 ------------------- src/vulkan-shaders/add.comp | 6 +- src/vulkan-shaders/clamp.comp | 8 +- src/vulkan-shaders/concat.comp | 35 + src/vulkan-shaders/copy.comp | 8 +- src/vulkan-shaders/div.comp | 6 +- src/vulkan-shaders/gelu.comp | 2 +- src/vulkan-shaders/gelu_quick.comp | 23 + src/vulkan-shaders/generic_binary_head.comp | 6 +- src/vulkan-shaders/generic_unary_head.comp | 4 + src/vulkan-shaders/group_norm.comp | 66 + src/vulkan-shaders/im2col.comp | 57 + src/vulkan-shaders/leaky_relu.comp | 22 + src/vulkan-shaders/mul.comp | 6 +- src/vulkan-shaders/norm.comp | 2 +- src/vulkan-shaders/pad.comp | 26 + src/vulkan-shaders/relu.comp | 2 +- src/vulkan-shaders/rms_norm.comp | 2 +- src/vulkan-shaders/scale.comp | 6 +- src/vulkan-shaders/silu.comp | 2 +- src/vulkan-shaders/soft_max.comp | 2 +- src/vulkan-shaders/square.comp | 8 +- src/vulkan-shaders/sum_rows.comp | 2 +- src/vulkan-shaders/tanh.comp | 21 + src/vulkan-shaders/timestep_embedding.comp | 41 + src/vulkan-shaders/types.comp | 4 +- src/vulkan-shaders/upscale.comp | 36 + src/vulkan-shaders/vulkan-shaders-gen.cpp | 49 +- 29 files changed, 1029 insertions(+), 3412 deletions(-) delete mode 100644 src/ggml_vk_generate_shaders.py create mode 100644 src/vulkan-shaders/concat.comp create mode 100644 src/vulkan-shaders/gelu_quick.comp create mode 100644 src/vulkan-shaders/group_norm.comp create mode 100644 src/vulkan-shaders/im2col.comp create mode 100644 src/vulkan-shaders/leaky_relu.comp create mode 100644 src/vulkan-shaders/pad.comp create mode 100644 src/vulkan-shaders/tanh.comp create mode 100644 src/vulkan-shaders/timestep_embedding.comp create mode 100644 src/vulkan-shaders/upscale.comp diff --git a/src/ggml-vulkan.cpp b/src/ggml-vulkan.cpp index fa68360b9..d7fea78d0 100644 --- a/src/ggml-vulkan.cpp +++ b/src/ggml-vulkan.cpp @@ -177,24 +177,33 @@ struct vk_device_struct { vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16; vk_pipeline pipeline_mul_f32; vk_pipeline pipeline_div_f32; - vk_pipeline pipeline_add_f32; + vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; + vk_pipeline pipeline_upscale_f32; vk_pipeline pipeline_scale_f32; vk_pipeline pipeline_sqr_f32; vk_pipeline pipeline_clamp_f32; + vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16; vk_pipeline pipeline_norm_f32; + vk_pipeline pipeline_group_norm_f32; vk_pipeline pipeline_rms_norm_f32; vk_pipeline pipeline_gelu_f32; + vk_pipeline pipeline_gelu_quick_f32; vk_pipeline pipeline_silu_f32; vk_pipeline pipeline_relu_f32; + vk_pipeline pipeline_leaky_relu_f32; + vk_pipeline pipeline_tanh_f32; vk_pipeline pipeline_diag_mask_inf_f32; vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16; vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; vk_pipeline pipeline_argsort_f32; vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; + vk_pipeline pipeline_timestep_embedding_f32; std::vector pipelines; @@ -320,7 +329,7 @@ struct vk_op_binary_push_constants { uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23; uint32_t d_offset; - float param1; float param2; + float param1; float param2; int32_t param3; }; struct vk_op_diag_mask_push_constants { @@ -358,6 +367,25 @@ struct vk_op_argsort_push_constants { int32_t order; }; +struct vk_op_im2col_push_constants { + uint32_t batch_offset; uint32_t offset_delta; + uint32_t IC; + uint32_t IW; uint32_t IH; + uint32_t OW; uint32_t OH; + uint32_t KW; uint32_t KH; + uint32_t pelements; + uint32_t CHW; + int32_t s0; int32_t s1; + int32_t p0; int32_t p1; + int32_t d0; int32_t d1; +}; + +struct vk_op_timestep_embedding_push_constants { + uint32_t nb1; + uint32_t dim; + uint32_t max_period; +}; + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -367,28 +395,32 @@ struct vk_staging_memcpy { size_t n; }; -struct vk_context { - size_t idx; +struct vk_op_upscale_push_constants { + uint32_t ne; uint32_t d_offset; + uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; + float sf0; float sf1; float sf2; float sf3; +}; +struct vk_context_struct { vk_submission * s; std::vector seqs; - ggml_tensor * exit_tensor; + int exit_tensor_idx; std::vector in_memcpys; std::vector out_memcpys; vk_queue * q; }; +typedef std::shared_ptr vk_context; +typedef std::weak_ptr vk_context_ref; struct ggml_tensor_extra_gpu { - size_t ctx_idx; - vk_buffer_ref buffer_gpu; uint64_t offset; void reset() { - ctx_idx = 0; buffer_gpu.reset(); offset = 0; } @@ -459,8 +491,10 @@ struct ggml_backend_vk_context { vk_buffer buffer_pool[MAX_VK_BUFFERS]; - vk_context * compute_ctx; - vk_context * transfer_ctx; + vk_context_ref compute_ctx; + vk_context_ref transfer_ctx; + + std::vector tensor_ctxs; }; #ifdef GGML_VULKAN_MEMORY_DEBUG @@ -510,12 +544,12 @@ static vk_instance_t vk_instance; static size_t vk_skip_checks; static size_t vk_output_tensor; -static void ggml_vk_print_tensor(ggml_backend * ctx, const ggml_tensor * tensor, const char * name); -static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * tensor); -static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor * tensor); +static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name); +static void ggml_vk_check_results_0(ggml_tensor * tensor); +static void ggml_vk_check_results_1(ggml_tensor * tensor); #endif -typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); +typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend); @@ -708,11 +742,11 @@ static vk_submission ggml_vk_create_submission(vk_device& device, vk_queue& q, s return s; } -static void ggml_vk_submit(vk_context * ctx, vk::Fence fence) { - VK_LOG_DEBUG("ggml_vk_submit(" << ctx->seqs.size() << ", " << fence << ")"); +static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { if (ctx->seqs.empty()) { return; } + VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")"); std::vector> tl_wait_vals; std::vector> tl_signal_vals; @@ -844,21 +878,17 @@ static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_ q.stage_flags = stage_flags; } -static vk_context * ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_queue& q) { - VK_LOG_DEBUG("ggml_vk_create_context()"); - ctx->gc.contexts.emplace_back(); - vk_context * result = &ctx->gc.contexts[ctx->gc.contexts.size() - 1]; - memset((void *) result, 0, sizeof(vk_context)); - result->idx = ctx->gc.contexts.size() - 1; +static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_queue& q) { + vk_context result = std::make_shared(); + VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")"); + ctx->gc.contexts.emplace_back(result); result->q = &q; return result; } -static vk_context * ggml_vk_create_temporary_context(vk_queue& q) { - VK_LOG_DEBUG("ggml_vk_create_temporary_context()"); - vk_context * result = new vk_context; - memset((void *) result, 0, sizeof(vk_context)); - result->idx = 0; +static vk_context ggml_vk_create_temporary_context(vk_queue& q) { + vk_context result = std::make_shared(); + VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")"); result->q = &q; return result; } @@ -915,6 +945,10 @@ static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_pr static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")"); + if (size > device->max_memory_allocation_size) { + throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit"); + } + std::lock_guard guard(device->mutex); vk_buffer buf = std::make_shared(); @@ -1027,7 +1061,7 @@ static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) { return { buf, 0, VK_WHOLE_SIZE }; } -static void ggml_vk_sync_buffers(vk_context * ctx) { +static void ggml_vk_sync_buffers(vk_context& ctx) { VK_LOG_DEBUG("ggml_vk_sync_buffers()"); const std::vector mem_barriers{ { { vk::AccessFlagBits::eMemoryRead | vk::AccessFlagBits::eMemoryWrite }, { vk::AccessFlagBits::eMemoryRead | vk::AccessFlagBits::eMemoryWrite } } }; @@ -1041,7 +1075,7 @@ static void ggml_vk_sync_buffers(vk_context * ctx) { ); } -static void ggml_vk_wait_events(vk_context * ctx, std::vector&& events) { +static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events) { VK_LOG_DEBUG("ggml_vk_wait_events()"); if (events.empty()) { return; @@ -1598,6 +1632,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -1605,20 +1640,31 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1); @@ -1634,6 +1680,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); } static vk_device ggml_vk_get_device(size_t idx) { @@ -2077,9 +2128,6 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->staging_size = 0; ctx->staging_offset = 0; - ctx->compute_ctx = nullptr; - ctx->transfer_ctx = nullptr; - #ifdef GGML_VULKAN_CHECK_RESULTS const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS"); vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks)); @@ -2112,7 +2160,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type } static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) { - VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline()"); + VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { return ctx->device->pipeline_matmul_f32; } @@ -2126,7 +2174,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte return ctx->device->pipeline_matmul_f16; } - GGML_ASSERT(src1_type == GGML_TYPE_F32); + if (src1_type != GGML_TYPE_F32) { + return nullptr; + } switch (src0_type) { case GGML_TYPE_Q4_0: @@ -2370,7 +2420,7 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo return s; } -static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, std::vector&& buffers, size_t push_constant_size, const void* push_constants, std::array elements) { +static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, std::vector&& buffers, size_t push_constant_size, const void* push_constants, std::array elements) { const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]); const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]); const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]); @@ -2410,7 +2460,7 @@ static void ggml_vk_end_submission(vk_submission& s, std::vector w s.signal_semaphores = std::move(signal_semaphores); } -static void ggml_vk_ctx_end(vk_context * ctx) { +static void ggml_vk_ctx_end(vk_context& ctx) { VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")"); if (ctx->s == nullptr) { return; @@ -2420,7 +2470,7 @@ static void ggml_vk_ctx_end(vk_context * ctx) { ctx->s = nullptr; } -static void ggml_vk_ctx_begin(vk_device& device, vk_context * subctx) { +static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) { VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")"); if (subctx->s != nullptr) { ggml_vk_ctx_end(subctx); @@ -2453,7 +2503,7 @@ static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) { } } -static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context * subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) { +static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) { VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")"); GGML_ASSERT(!ggml_is_contiguous(tensor)); // Buffer is already mapped @@ -2558,7 +2608,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont } } -static void ggml_vk_buffer_write_2d_async(vk_context * subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) { +static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) { VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")"); // Buffer is already mapped if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { @@ -2623,7 +2673,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context * subctx, vk_buffer& dst, s } } -static void ggml_vk_buffer_write_async(vk_context * subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) { +static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) { VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")"); return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, staging_buffer, staging_offset, sync_staging); } @@ -2638,7 +2688,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width); } } else { - vk_context * subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); ggml_vk_ctx_begin(dst->device, subctx); ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, nullptr, 0, true); ggml_vk_ctx_end(subctx); @@ -2650,8 +2700,6 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * ggml_vk_submit(subctx, dst->device->fence); VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences"); dst->device->device.resetFences({ dst->device->fence }); - - delete subctx; } } @@ -2660,12 +2708,14 @@ static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1); } -static void ggml_vk_buffer_read_2d_async(vk_context * subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) { +static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) { VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")"); GGML_ASSERT(width > 0); GGML_ASSERT(height > 0); GGML_ASSERT(src != nullptr); + // TODO: staging_offset is not used + // Check if dst is pinned memory vk_buffer buf = nullptr; size_t buf_offset; @@ -2714,18 +2764,18 @@ static void ggml_vk_buffer_read_2d_async(vk_context * subctx, vk_buffer& src, si deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys); } -static void ggml_vk_buffer_read_async(vk_context * subctx, vk_buffer& src, size_t offset, void * dst, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) { +static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) { return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, staging_buffer, staging_offset, sync_staging); } static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { - VK_LOG_DEBUG("ggml_vk_buffer_read(" << offset << ", " << size << ")"); + VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); memcpy(dst, (uint8_t *) src->ptr + offset, size); } else { - vk_context * subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); ggml_vk_ctx_begin(src->device, subctx); ggml_vk_buffer_read_async(subctx, src, offset, dst, size, nullptr, 0, true); ggml_vk_ctx_end(subctx); @@ -2737,12 +2787,10 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ for (auto& cpy : subctx->out_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); } - - delete subctx; } } -static void ggml_vk_buffer_copy_async(vk_context * ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { +static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")"); // Make sure both buffers are on same device GGML_ASSERT(src->device == dst->device); @@ -2756,15 +2804,13 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr if (src->device == dst->device) { VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")"); // Copy within the device - vk_context * subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); ggml_vk_ctx_begin(src->device, subctx); ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size); ggml_vk_ctx_end(subctx); ggml_vk_submit(subctx, src->device->fence); VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences"); src->device->device.resetFences({ src->device->fence }); - - delete subctx; } else { VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")"); // Copy device to device @@ -2783,7 +2829,7 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")"); - vk_context * subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); ggml_vk_ctx_begin(dst->device, subctx); subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); ggml_vk_ctx_end(subctx); @@ -2791,8 +2837,6 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz ggml_vk_submit(subctx, dst->device->fence); VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences"); dst->device->device.resetFences({ dst->device->fence }); - - delete subctx; } static uint32_t ggml_vk_guess_split_k(int m, int n, int k) { @@ -2855,7 +2899,7 @@ static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ct } static void ggml_vk_matmul( - ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, + ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, @@ -2879,7 +2923,7 @@ static void ggml_vk_matmul( } static void ggml_vk_matmul_id( - ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, + ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, @@ -2916,7 +2960,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, ggml_ GGML_ABORT("fatal error"); } -static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) { +static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) { VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), "; std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")"); const int tensor_type_size = ggml_type_size(tensor->type); @@ -2934,7 +2978,7 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, { ne, 1, 1 }); } -static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)"); @@ -3107,7 +3151,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su ); // NOLINT } -static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)"); @@ -3272,7 +3316,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); } -static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)"); @@ -3343,7 +3387,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); } -static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)"); @@ -3418,7 +3462,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); } -static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")"); if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1) { ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst); @@ -3431,7 +3475,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx, } } -static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { +static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; @@ -3618,7 +3662,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * ); // NOLINT } -static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { +static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; @@ -3794,7 +3838,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z }); } -static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { +static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")"); if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst); @@ -3803,8 +3847,8 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context * subct } } -static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - // guaranteed to be an integer due to the check in ggml_can_repeat +static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + VK_LOG_DEBUG("ggml_vk_op_repeat(" << src0 << ", " << src1 << ", " << dst << ")"); const uint64_t ne0 = dst->ne[0]; const uint64_t ne1 = dst->ne[1]; const uint64_t ne2 = dst->ne[2]; @@ -3825,6 +3869,7 @@ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx const uint64_t nb02 = src0->nb[2]; const uint64_t nb03 = src0->nb[3]; + // guaranteed to be an integer due to the check in ggml_can_repeat const uint64_t nr0 = ne0/ne00; const uint64_t nr1 = ne1/ne01; const uint64_t nr2 = ne2/ne02; @@ -3852,8 +3897,8 @@ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx for (uint64_t k1 = 0; k1 < ne01; k1++) { for (uint64_t i0 = 0; i0 < nr0; i0++) { copies.push_back({ - src_offset + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0, - dst_offset + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01, + src_offset + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01, + dst_offset + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0, ne00*nb0, }); } @@ -3874,11 +3919,6 @@ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { switch (op) { - case GGML_OP_ADD: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_add_f32; - } - return nullptr; case GGML_OP_GET_ROWS: GGML_ASSERT(src1->type == GGML_TYPE_I32); if (dst->type == GGML_TYPE_F16) { @@ -3888,6 +3928,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_get_rows_f32[src0->type]; } return nullptr; + case GGML_OP_ADD: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_add_f32; + } + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_add_f16_f32_f16; + } + return nullptr; case GGML_OP_MUL: if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_mul_f32; @@ -3898,6 +3946,22 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_div_f32; } return nullptr; + case GGML_OP_CONCAT: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_concat_f32; + } + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_concat_f16; + } + if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { + return ctx->device->pipeline_concat_i32; + } + return nullptr; + case GGML_OP_UPSCALE: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_upscale_f32; + } + return nullptr; case GGML_OP_SCALE: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_scale_f32; @@ -3913,6 +3977,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_clamp_f32; } return nullptr; + case GGML_OP_PAD: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_pad_f32; + } + return nullptr; case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: @@ -3922,6 +3991,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_norm_f32; } return nullptr; + case GGML_OP_GROUP_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_group_norm_f32; + } + return nullptr; case GGML_OP_RMS_NORM: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_rms_norm_f32; @@ -3939,11 +4013,21 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_gelu_f32; } break; + case GGML_UNARY_OP_GELU_QUICK: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_gelu_quick_f32; + } + break; case GGML_UNARY_OP_RELU: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_relu_f32; } break; + case GGML_UNARY_OP_TANH: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_tanh_f32; + } + break; default: break; } @@ -3995,6 +4079,24 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_sum_rows_f32; } return nullptr; + case GGML_OP_IM2COL: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_im2col_f32; + } + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_im2col_f32_f16; + } + return nullptr; + case GGML_OP_TIMESTEP_EMBEDDING: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_timestep_embedding_f32; + } + return nullptr; + case GGML_OP_LEAKY_RELU: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_leaky_relu_f32; + } + return nullptr; default: return nullptr; } @@ -4018,9 +4120,12 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_ADD: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_CLAMP: + case GGML_OP_PAD: return true; default: return false; @@ -4028,7 +4133,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { } template -static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc) { +static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc) { VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; if (src1 != nullptr) { std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; @@ -4124,7 +4229,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c vk_buffer d_D = extra->buffer_gpu.lock(); // Workaround for tiny tensor inputs on ROPE - if (use_src1 && y_sz > d_D->size) { + if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) { y_sz = VK_WHOLE_SIZE; } @@ -4173,13 +4278,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c if (op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))) { ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, 1); - switch (dst->op) { + switch (op) { case GGML_OP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_SOFT_MAX: case GGML_OP_SUM_ROWS: - elements = { (uint32_t)ggml_nrows(src0), 1, 1 }; - break; + { + const uint32_t nr = ggml_nrows(src0); + if (nr > 262144) { + elements = { 512, 512, CEIL_DIV(nr, 262144) }; + } else if (nr > 512) { + elements = { 512, CEIL_DIV(nr, 512), 1 }; + } else { + elements = { nr, 1, 1 }; + } + } break; + case GGML_OP_GROUP_NORM: + { + const uint32_t num_groups = dst->op_params[0]; + elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 }; + } break; case GGML_OP_DIAG_MASK_INF: case GGML_OP_ROPE: elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 }; @@ -4190,6 +4308,49 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c case GGML_OP_ARGSORT: elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; break; + case GGML_OP_IM2COL: + { + const bool is_2D = dst->op_params[6] == 1; + + const uint32_t IC = src1->ne[is_2D ? 2 : 1]; + + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t KW = src0->ne[0]; + + const uint32_t OH = is_2D ? dst->ne[2] : 1; + const uint32_t OW = dst->ne[1]; + + const uint32_t batch = src1->ne[3]; + + elements = { OW * KW * KH, OH, batch * IC }; + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + const uint32_t dim = dst->op_params[0]; + uint32_t half_ceil = (dim + 1) / 2; + elements = { half_ceil, (uint32_t)src0->ne[0], 1 }; + } break; + case GGML_OP_ADD: + case GGML_OP_DIV: + case GGML_OP_MUL: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_CPY: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_UNARY: + { + const uint32_t ne = ggml_nelements(dst); + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + } break; default: elements = { (uint32_t)ggml_nelements(src0), 1, 1 }; break; @@ -4216,7 +4377,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c if (use_src1) { subbuf_y = { d_Y, y_buf_offset, y_sz }; } else { - subbuf_y = { d_X, 0, d_X->size }; + subbuf_y = { d_X, 0, x_sz }; } ggml_vk_sync_buffers(subctx); @@ -4227,11 +4388,15 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c if (use_src2) { subbuf_z = { d_Z, z_buf_offset, z_sz }; } else { - subbuf_z = { d_X, 0, d_X->size }; + subbuf_z = { d_X, 0, x_sz }; } ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, subbuf_z, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else if (op == GGML_OP_IM2COL) { + // im2col uses only src1 and dst buffers + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); } else if (use_src2) { ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_Z, z_buf_offset, z_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); @@ -4249,8 +4414,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, ne02 * ne03); - switch (dst->op) { + switch (op) { case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: elements = { (uint32_t)ne01, 1, 1 }; break; @@ -4286,11 +4452,11 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c } } -static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_REPEAT, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f }); +static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {}); } -static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -4301,11 +4467,11 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context * subctx, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - 0.0f, 0.0f, + 0.0f, 0.0f, 0, }); } -static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -4316,11 +4482,11 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context * subctx, cons (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - 0.0f, 0.0f, + 0.0f, 0.0f, 0, }); } -static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -4331,11 +4497,11 @@ static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context * subctx, cons (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - 0.0f, 0.0f, + 0.0f, 0.0f, 0, }); } -static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -4346,11 +4512,44 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context * subctx, cons (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - 0.0f, 0.0f, + 0.0f, 0.0f, 0, }); } -static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + int * op_params = (int *)dst->op_params; + + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONCAT, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, op_params[0], + }); +} + +static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + + const float sf0 = (float)dst->ne[0] / src0->ne[0]; + const float sf1 = (float)dst->ne[1] / src0->ne[1]; + const float sf2 = (float)dst->ne[2] / src0->ne[2]; + const float sf3 = (float)dst->ne[3] / src0->ne[3]; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, { + (uint32_t)ggml_nelements(dst), 0, + (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3], + sf0, sf1, sf2, sf3, + }); +} + +static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { float * op_params = (float *)dst->op_params; const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -4364,7 +4563,7 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context * subctx, co }); } -static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -4377,7 +4576,7 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context * subctx, cons }); } -static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { float * op_params = (float *)dst->op_params; const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -4391,7 +4590,20 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context * subctx, co }); } -static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + }); +} + +static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -4406,27 +4618,37 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context * subctx, cons }); } -static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { float * op_params = (float *)dst->op_params; ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }); } -static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + int * op_params = (int *)dst->op_params; + + uint32_t num_groups = op_params[0]; + uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); + static const float eps = 1e-6f; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }); +} + +static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { float * op_params = (float *)dst->op_params; ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }); } -static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }); } -static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { int32_t * op_params = (int32_t *)dst->op_params; ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }); } -static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { float * op_params = (float *)dst->op_params; float scale = op_params[0]; @@ -4451,7 +4673,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, }); } -static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { +static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { const int n_dims = ((int32_t *) dst->op_params)[1]; // const int mode = ((int32_t *) dst->op_params)[2]; // const int n_ctx = ((int32_t *) dst->op_params)[3]; @@ -4475,7 +4697,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con }); } -static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { int32_t * op_params = (int32_t *)dst->op_params; uint32_t ncols = src0->ne[0]; @@ -4494,10 +4716,59 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, }); } -static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }); } +static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const int32_t s0 = dst->op_params[0]; + const int32_t s1 = dst->op_params[1]; + const int32_t p0 = dst->op_params[2]; + const int32_t p1 = dst->op_params[3]; + const int32_t d0 = dst->op_params[4]; + const int32_t d1 = dst->op_params[5]; + + const bool is_2D = dst->op_params[6] == 1; + + const uint32_t IC = src1->ne[is_2D ? 2 : 1]; + const uint32_t IH = is_2D ? src1->ne[1] : 1; + const uint32_t IW = src1->ne[0]; + + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t KW = src0->ne[0]; + + const uint32_t OH = is_2D ? dst->ne[2] : 1; + const uint32_t OW = dst->ne[1]; + + const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + const uint32_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 + + const uint32_t pelements = OW * KW * KH; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, { + batch_offset, offset_delta, + IC, IW, IH, OW, OH, KW, KH, + pelements, + IC * KH * KW, + s0, s1, p0, p1, d0, d1, + }); +} + +static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + const uint32_t dim = dst->op_params[0]; + const uint32_t max_period = dst->op_params[1]; + const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, { + nb1, dim, max_period, + }); +} + +static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + const float * op_params = (const float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }); +} + #ifdef GGML_VULKAN_RUN_TESTS static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) { if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) { @@ -4686,7 +4957,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch); ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch); - vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); for (size_t i = 0; i < num_it; i++) { ggml_vk_ctx_begin(ctx->device, subctx); ggml_vk_matmul( @@ -4894,7 +5165,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); - vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); ggml_vk_ctx_begin(ctx->device, subctx); const std::vector pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne }; ggml_vk_dispatch_pipeline(ctx, subctx, p, { { qx_buf, 0, qx_sz }, { x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1}); @@ -5027,7 +5298,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); ggml_vk_buffer_write(y_buf, 0, y, y_sz); - vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); for (size_t i = 0; i < num_it; i++) { ggml_vk_ctx_begin(ctx->device, subctx); ggml_vk_matmul( @@ -5175,7 +5446,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm const bool y_f32_kernel = use_src1 && src1->type == GGML_TYPE_F32 && !y_non_contig; - bool mmp = (use_src0 && use_src1 && src1_type == GGML_TYPE_F32) ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0_type, y_non_contig ? GGML_TYPE_F16 : src1->type) != nullptr : false; + bool mmp = (use_src0 && use_src1 && (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID)) ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type) != nullptr : false; const bool qx_needs_dequant = use_src0 && (!mmp || x_non_contig); const bool qy_needs_dequant = use_src1 && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig); @@ -5211,24 +5482,33 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_CLAMP: + case GGML_OP_PAD: case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_ROPE: case GGML_OP_ARGSORT: case GGML_OP_SUM_ROWS: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_LEAKY_RELU: break; case GGML_OP_UNARY: switch (ggml_get_unary_op(node)) { case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: break; default: return; @@ -5236,6 +5516,13 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm break; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: + if ( + x_sz > ctx->device->max_memory_allocation_size || + y_sz > ctx->device->max_memory_allocation_size || + d_sz > ctx->device->max_memory_allocation_size || + split_k_size > ctx->device->max_memory_allocation_size) { + GGML_ABORT("Requested preallocation size is too large"); + } if (ctx->prealloc_size_x < x_sz) { ctx->prealloc_size_x = x_sz; } @@ -5430,7 +5717,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } } -static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, bool last_node){ +static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, bool last_node){ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra; if (ggml_is_empty(node) || extra == nullptr) { @@ -5457,7 +5744,9 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod switch (ggml_get_unary_op(node)) { case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: break; default: return; @@ -5468,13 +5757,17 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_ADD: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_CLAMP: + case GGML_OP_PAD: case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: @@ -5483,6 +5776,9 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_MUL_MAT_ID: case GGML_OP_ARGSORT: case GGML_OP_SUM_ROWS: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_LEAKY_RELU: break; default: std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl; @@ -5490,102 +5786,137 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod return; } - if (ctx->compute_ctx == nullptr) { - ctx->compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); - ggml_vk_ctx_begin(ctx->device, ctx->compute_ctx); + vk_context compute_ctx; + + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); } switch (node->op) { case GGML_OP_REPEAT: - ggml_vk_repeat(ctx, ctx->compute_ctx, src0, src1, node); + ggml_vk_repeat(ctx, compute_ctx, src0, node); break; case GGML_OP_GET_ROWS: - ggml_vk_get_rows(ctx, ctx->compute_ctx, src0, src1, node); + ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node); break; case GGML_OP_ADD: - ggml_vk_add(ctx, ctx->compute_ctx, src0, src1, node); + ggml_vk_add(ctx, compute_ctx, src0, src1, node); break; case GGML_OP_MUL: - ggml_vk_mul(ctx, ctx->compute_ctx, src0, src1, node); + ggml_vk_mul(ctx, compute_ctx, src0, src1, node); break; case GGML_OP_DIV: - ggml_vk_div(ctx, ctx->compute_ctx, src0, src1, node); + ggml_vk_div(ctx, compute_ctx, src0, src1, node); + + break; + case GGML_OP_CONCAT: + ggml_vk_concat(ctx, compute_ctx, src0, src1, node); + + break; + case GGML_OP_UPSCALE: + ggml_vk_upscale(ctx, compute_ctx, src0, node); break; case GGML_OP_SCALE: - ggml_vk_scale(ctx, ctx->compute_ctx, src0, node); + ggml_vk_scale(ctx, compute_ctx, src0, node); break; case GGML_OP_SQR: - ggml_vk_sqr(ctx, ctx->compute_ctx, src0, node); + ggml_vk_sqr(ctx, compute_ctx, src0, node); break; case GGML_OP_CLAMP: - ggml_vk_clamp(ctx, ctx->compute_ctx, src0, node); + ggml_vk_clamp(ctx, compute_ctx, src0, node); + + break; + case GGML_OP_PAD: + ggml_vk_pad(ctx, compute_ctx, src0, node); break; case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: - ggml_vk_cpy(ctx, ctx->compute_ctx, src0, node); + ggml_vk_cpy(ctx, compute_ctx, src0, node); break; case GGML_OP_NORM: - ggml_vk_norm(ctx, ctx->compute_ctx, src0, node); + ggml_vk_norm(ctx, compute_ctx, src0, node); + + break; + case GGML_OP_GROUP_NORM: + ggml_vk_group_norm(ctx, compute_ctx, src0, node); break; case GGML_OP_RMS_NORM: - ggml_vk_rms_norm(ctx, ctx->compute_ctx, src0, node); + ggml_vk_rms_norm(ctx, compute_ctx, src0, node); break; case GGML_OP_UNARY: switch (ggml_get_unary_op(node)) { case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: - ggml_vk_unary(ctx, ctx->compute_ctx, src0, node); + case GGML_UNARY_OP_TANH: + ggml_vk_unary(ctx, compute_ctx, src0, node); break; default: return; } break; case GGML_OP_DIAG_MASK_INF: - ggml_vk_diag_mask_inf(ctx, ctx->compute_ctx, src0, node); + ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node); break; case GGML_OP_SOFT_MAX: - ggml_vk_soft_max(ctx, ctx->compute_ctx, src0, src1, node); + ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node); break; case GGML_OP_ROPE: - ggml_vk_rope(ctx, ctx->compute_ctx, src0, src1, src2, node); + ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node); break; case GGML_OP_ARGSORT: - ggml_vk_argsort(ctx, ctx->compute_ctx, src0, node); + ggml_vk_argsort(ctx, compute_ctx, src0, node); break; case GGML_OP_SUM_ROWS: - ggml_vk_sum_rows(ctx, ctx->compute_ctx, src0, node); + ggml_vk_sum_rows(ctx, compute_ctx, src0, node); + + break; + case GGML_OP_IM2COL: + ggml_vk_im2col(ctx, compute_ctx, src0, src1, node); + + break; + case GGML_OP_TIMESTEP_EMBEDDING: + ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node); + + break; + case GGML_OP_LEAKY_RELU: + ggml_vk_leaky_relu(ctx, compute_ctx, src0, node); break; case GGML_OP_MUL_MAT: - ggml_vk_mul_mat(ctx, ctx->compute_ctx, src0, src1, node); + ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node); break; case GGML_OP_MUL_MAT_ID: - ggml_vk_mul_mat_id(ctx, ctx->compute_ctx, src0, src1, src2, node); + ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node); break; default: return; } - extra->ctx_idx = ctx->compute_ctx->idx; + ctx->tensor_ctxs[node_idx] = compute_ctx; #ifdef GGML_VULKAN_CHECK_RESULTS // Force context reset on each node so that each tensor ends up in its own context @@ -5594,13 +5925,13 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod #endif if (last_node) { - ggml_vk_ctx_end(ctx->compute_ctx); - ctx->compute_ctx->exit_tensor = node; - ctx->compute_ctx = nullptr; + ggml_vk_ctx_end(compute_ctx); + compute_ctx->exit_tensor_idx = node_idx; + ctx->compute_ctx.reset(); } } -static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor){ +static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx){ ggml_tensor_extra_gpu * extra = nullptr; switch (tensor->op) { @@ -5608,13 +5939,17 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_GET_ROWS: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_CLAMP: + case GGML_OP_PAD: case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: @@ -5626,6 +5961,10 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_NONE: case GGML_OP_ARGSORT: case GGML_OP_SUM_ROWS: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_LEAKY_RELU: + case GGML_OP_REPEAT: extra = (ggml_tensor_extra_gpu *) tensor->extra; break; @@ -5633,7 +5972,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * switch (ggml_get_unary_op(tensor)) { case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: extra = (ggml_tensor_extra_gpu *) tensor->extra; break; default: @@ -5656,31 +5997,31 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")"); #ifdef GGML_VULKAN_CHECK_RESULTS - ggml_vk_check_results_0(ctx, tensor); + ggml_vk_check_results_0(tensor); #endif - vk_context& subctx = ctx->gc.contexts[extra->ctx_idx]; + vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock(); // Only run if ctx hasn't been submitted yet - if (!subctx.seqs.empty()) { + if (!subctx->seqs.empty()) { // Do staging buffer copies - for (auto& cpy : subctx.in_memcpys) { + for (auto& cpy : subctx->in_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); } - ggml_vk_submit(&subctx, ctx->fence); + ggml_vk_submit(subctx, ctx->fence); } - if (tensor == subctx.exit_tensor) { + if (tensor_idx == subctx->exit_tensor_idx) { VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences"); ctx->device->device.resetFences({ ctx->fence }); // Do staging buffer copies - for (auto& cpy : subctx.out_memcpys) { + for (auto& cpy : subctx->out_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); } - subctx.in_memcpys.clear(); - subctx.out_memcpys.clear(); + subctx->in_memcpys.clear(); + subctx->out_memcpys.clear(); } return true; @@ -5725,8 +6066,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { ctx->staging_offset = 0; - ctx->compute_ctx = nullptr; - ctx->transfer_ctx = nullptr; + ctx->tensor_ctxs.clear(); ctx->gc.contexts.clear(); } @@ -6063,15 +6403,20 @@ GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, g ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; - if (ctx->transfer_ctx == nullptr) { + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { // Initialize new transfer context - ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); - ggml_vk_ctx_begin(ctx->device, ctx->transfer_ctx); + transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = ctx->transfer_ctx.lock(); } vk_buffer buf = extra->buffer_gpu.lock(); - ggml_vk_buffer_write_async(ctx->transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset); + ggml_vk_buffer_write_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset); } GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { @@ -6081,15 +6426,20 @@ GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, c ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; - if (ctx->transfer_ctx == nullptr) { + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { // Initialize new transfer context - ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); - ggml_vk_ctx_begin(ctx->device, ctx->transfer_ctx); + transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = ctx->transfer_ctx.lock(); } vk_buffer buf = extra->buffer_gpu.lock(); - ggml_vk_buffer_read_async(ctx->transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset); + ggml_vk_buffer_read_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset); } GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { @@ -6099,16 +6449,21 @@ GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, c ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra; ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; - if (ctx->transfer_ctx == nullptr) { + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { // Initialize new transfer context - ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); - ggml_vk_ctx_begin(ctx->device, ctx->transfer_ctx); + transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = ctx->transfer_ctx.lock(); } vk_buffer src_buf = src_extra->buffer_gpu.lock(); vk_buffer dst_buf = dst_extra->buffer_gpu.lock(); - ggml_vk_buffer_copy_async(ctx->transfer_ctx, dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src)); + ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src)); return true; } @@ -6118,25 +6473,27 @@ GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, c GGML_CALL static void ggml_backend_vk_synchronize(ggml_backend_t backend) { VK_LOG_DEBUG("ggml_backend_vk_synchronize()"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; - if(ctx->transfer_ctx == nullptr) { + if(ctx->transfer_ctx.expired()) { return; } - ggml_vk_ctx_end(ctx->transfer_ctx); + vk_context transfer_ctx = ctx->transfer_ctx.lock(); + + ggml_vk_ctx_end(transfer_ctx); - for (auto& cpy : ctx->transfer_ctx->in_memcpys) { + for (auto& cpy : transfer_ctx->in_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); } - ggml_vk_submit(ctx->transfer_ctx, ctx->fence); + ggml_vk_submit(transfer_ctx, ctx->fence); VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences"); ctx->device->device.resetFences({ ctx->fence }); - for (auto& cpy : ctx->transfer_ctx->out_memcpys) { + for (auto& cpy : transfer_ctx->out_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); } - ctx->transfer_ctx = nullptr; + ctx->transfer_ctx.reset(); } static bool ggml_vk_is_empty(ggml_tensor * node) { @@ -6159,8 +6516,11 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen last_node -= 1; } + // Reserve tensor context space for all nodes + ctx->tensor_ctxs.resize(cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_vk_build_graph(ctx,cgraph->nodes[i], i == last_node); + ggml_vk_build_graph(ctx, cgraph->nodes[i], i, i == last_node); } for (int i = 0; i < cgraph->n_nodes; i++) { @@ -6170,13 +6530,17 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen continue; } - bool ok = ggml_vk_compute_forward(ctx, node); + bool ok = ggml_vk_compute_forward(ctx, node, i); if (!ok) { - fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + if (node->op == GGML_OP_UNARY) { + std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; + } else { + std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl; + } } #ifdef GGML_VULKAN_CHECK_RESULTS else { - ggml_vk_check_results_1(ctx, node); + ggml_vk_check_results_1(node); } #endif GGML_ASSERT(ok); @@ -6196,8 +6560,10 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: return ggml_is_contiguous(op->src[0]); default: return false; @@ -6270,11 +6636,11 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const } return false; } break; - // case GGML_OP_REPEAT: - // { - // ggml_type src0_type = op->src[0]->type; - // return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; - // } break; + case GGML_OP_REPEAT: + { + ggml_type src0_type = op->src[0]->type; + return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; + } break; case GGML_OP_ROPE: return ggml_is_contiguous(op->src[0]); case GGML_OP_NONE: @@ -6283,18 +6649,25 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: case GGML_OP_ADD: case GGML_OP_MUL: case GGML_OP_DIV: - case GGML_OP_RMS_NORM: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_CLAMP: + case GGML_OP_PAD: case GGML_OP_CONT: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_ARGSORT: case GGML_OP_SUM_ROWS: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_LEAKY_RELU: return true; default: return false; @@ -6509,10 +6882,12 @@ static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * d } } -static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tensor * tensor, const char * name) { +static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) { void * tensor_data = tensor->data; - if (ggml_backend_buffer_is_vk(tensor->buffer)) { + const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer); + + if (is_gpu) { const size_t tensor_size = ggml_nbytes(tensor); tensor_data = malloc(tensor_size); @@ -6533,13 +6908,10 @@ static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tenso std::cerr << std::endl << "Result:" << std::endl; ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); std::cerr << std::endl; - std::cerr << std::endl << "Result:" << std::endl; - ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 1, 0); - std::cerr << std::endl; std::vector done; ggml_vk_print_graph_origin(tensor, done); - if (ggml_backend_buffer_is_vk(tensor->buffer)) { + if (is_gpu) { free(tensor_data); } } @@ -6548,8 +6920,8 @@ void * comp_result; size_t comp_size; size_t comp_nb[GGML_MAX_DIMS]; size_t check_counter = 0; -static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * tensor) { - if (tensor->op == GGML_OP_TRANSPOSE) { +static void ggml_vk_check_results_0(ggml_tensor * tensor) { + if (tensor->op == GGML_OP_TRANSPOSE) { return; } @@ -6565,7 +6937,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * ggml_tensor * src2 = tensor->src[2]; struct ggml_init_params iparams = { - /*.mem_size =*/ 1024*1024*1024, + /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false, }; @@ -6624,7 +6996,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * } if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { - ggml_vk_print_tensor(ctx, src0, "src0"); + ggml_vk_print_tensor(src0, "src0"); } } if (src1 != nullptr) { @@ -6666,23 +7038,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * } if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { - ggml_vk_print_tensor(ctx, src1, "src1"); - std::cerr << "TENSOR CHECK: " << ggml_op_name(src1_clone->op) << " (check " << check_counter << ")" << std::endl; - std::cerr << "src1_clone=" << tensor << " src1_clone->type: " << ggml_type_name(src1_clone->type) << " ne0=" << src1_clone->ne[0] << " nb0=" << src1_clone->nb[0] << " ne1=" << src1_clone->ne[1] << " nb1=" << src1_clone->nb[1] << " ne2=" << src1_clone->ne[2] << " nb2=" << src1_clone->nb[2] << " ne3=" << src1_clone->ne[3] << " nb3=" << src1_clone->nb[3] << std::endl; - if (src1->src[0] != nullptr) { - std::cerr << "src1->src[0]=" << src1->src[0] << " op=" << ggml_op_name(src1->src[0]->op) << " type=" << ggml_type_name(src1->src[0]->type) << " ne0=" << src1->src[0]->ne[0] << " nb0=" << src1->src[0]->nb[0] << " ne1=" << src1->src[0]->ne[1] << " nb1=" << src1->src[0]->nb[1] << " ne2=" << src1->src[0]->ne[2] << " nb2=" << src1->src[0]->nb[2] << " ne3=" << src1->src[0]->ne[3] << " nb3=" << src1->src[0]->nb[3] << std::endl; - } - if (src1->src[1] != nullptr) { - std::cerr << "src1->src[1]=" << src1->src[1] << " op=" << ggml_op_name(src1->src[1]->op) << " type=" << ggml_type_name(src1->src[1]->type) << " ne0=" << src1->src[1]->ne[0] << " nb0=" << src1->src[1]->nb[0] << " ne1=" << src1->src[1]->ne[1] << " nb1=" << src1->src[1]->nb[1] << " ne2=" << src1->src[1]->ne[2] << " nb2=" << src1->src[1]->nb[2] << " ne3=" << src1->src[1]->ne[3] << " nb3=" << src1->src[1]->nb[3] << std::endl; - } - std::cerr << std::endl << "Result:" << std::endl; - ggml_vk_print_tensor_area(src1_clone, src1_clone->data, 5, 5, 0, 0); - std::cerr << std::endl; - std::cerr << std::endl << "Result:" << std::endl; - ggml_vk_print_tensor_area(src1_clone, src1_clone->data, 5, 5, 1, 0); - std::cerr << std::endl; - std::vector done; - ggml_vk_print_graph_origin(src1_clone, done); + ggml_vk_print_tensor(src1, "src1"); } } if (src2 != nullptr) { @@ -6724,23 +7080,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * } if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { - ggml_vk_print_tensor(ctx, src2, "src2"); - std::cerr << "TENSOR CHECK: " << ggml_op_name(src2_clone->op) << " (check " << check_counter << ")" << std::endl; - std::cerr << "src2_clone=" << tensor << " src2_clone->type: " << ggml_type_name(src2_clone->type) << " ne0=" << src2_clone->ne[0] << " nb0=" << src2_clone->nb[0] << " ne1=" << src2_clone->ne[1] << " nb1=" << src2_clone->nb[1] << " ne2=" << src2_clone->ne[2] << " nb2=" << src2_clone->nb[2] << " ne3=" << src2_clone->ne[3] << " nb3=" << src2_clone->nb[3] << std::endl; - if (src2->src[0] != nullptr) { - std::cerr << "src2->src[0]=" << src2->src[0] << " op=" << ggml_op_name(src2->src[0]->op) << " type=" << ggml_type_name(src2->src[0]->type) << " ne0=" << src2->src[0]->ne[0] << " nb0=" << src2->src[0]->nb[0] << " ne1=" << src2->src[0]->ne[1] << " nb1=" << src2->src[0]->nb[1] << " ne2=" << src2->src[0]->ne[2] << " nb2=" << src2->src[0]->nb[2] << " ne3=" << src2->src[0]->ne[3] << " nb3=" << src2->src[0]->nb[3] << std::endl; - } - if (src2->src[1] != nullptr) { - std::cerr << "src2->src[1]=" << src2->src[1] << " op=" << ggml_op_name(src2->src[1]->op) << " type=" << ggml_type_name(src2->src[1]->type) << " ne0=" << src2->src[1]->ne[0] << " nb0=" << src2->src[1]->nb[0] << " ne1=" << src2->src[1]->ne[1] << " nb1=" << src2->src[1]->nb[1] << " ne2=" << src2->src[1]->ne[2] << " nb2=" << src2->src[1]->nb[2] << " ne3=" << src2->src[1]->ne[3] << " nb3=" << src2->src[1]->nb[3] << std::endl; - } - std::cerr << std::endl << "Result:" << std::endl; - ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 0, 0); - std::cerr << std::endl; - std::cerr << std::endl << "Result:" << std::endl; - ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 1, 0); - std::cerr << std::endl; - std::vector done; - ggml_vk_print_graph_origin(src2_clone, done); + ggml_vk_print_tensor(src2, "src2"); } } @@ -6752,16 +7092,24 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone); } else if (tensor->op == GGML_OP_DIV) { tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone); + } else if (tensor->op == GGML_OP_CONCAT) { + tensor_clone = ggml_concat(ggml_ctx, src0_clone, src1_clone, *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_UPSCALE) { + tensor_clone = ggml_upscale_ext(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); } else if (tensor->op == GGML_OP_SCALE) { tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]); } else if (tensor->op == GGML_OP_SQR) { tensor_clone = ggml_sqr(ggml_ctx, src0_clone); } else if (tensor->op == GGML_OP_CLAMP) { tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + } else if (tensor->op == GGML_OP_PAD) { + tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]); } else if (tensor->op == GGML_OP_ADD) { tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone); } else if (tensor->op == GGML_OP_NORM) { tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); + } else if (tensor->op == GGML_OP_GROUP_NORM) { + tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params); } else if (tensor->op == GGML_OP_RMS_NORM) { tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_SOFT_MAX) { @@ -6777,12 +7125,12 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * const int mode = ((int32_t *) tensor->op_params)[2]; //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3]; const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4]; - float freq_base = ((float *) tensor->op_params)[5]; - float freq_scale = ((float *) tensor->op_params)[6]; - float ext_factor = ((float *) tensor->op_params)[7]; - float attn_factor = ((float *) tensor->op_params)[8]; - float beta_fast = ((float *) tensor->op_params)[9]; - float beta_slow = ((float *) tensor->op_params)[10]; + const float freq_base = ((float *) tensor->op_params)[5]; + const float freq_scale = ((float *) tensor->op_params)[6]; + const float ext_factor = ((float *) tensor->op_params)[7]; + const float attn_factor = ((float *) tensor->op_params)[8]; + const float beta_fast = ((float *) tensor->op_params)[9]; + const float beta_slow = ((float *) tensor->op_params)[10]; tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); } else if (tensor->op == GGML_OP_UNARY) { switch (ggml_get_unary_op(tensor)) { @@ -6792,9 +7140,15 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_UNARY_OP_GELU: tensor_clone = ggml_gelu(ggml_ctx, src0_clone); break; + case GGML_UNARY_OP_GELU_QUICK: + tensor_clone = ggml_gelu_quick(ggml_ctx, src0_clone); + break; case GGML_UNARY_OP_RELU: tensor_clone = ggml_relu(ggml_ctx, src0_clone); break; + case GGML_UNARY_OP_TANH: + tensor_clone = ggml_tanh(ggml_ctx, src0_clone); + break; default: std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); @@ -6823,6 +7177,23 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params); } else if (tensor->op == GGML_OP_SUM_ROWS) { tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone); + } else if (tensor->op == GGML_OP_IM2COL) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t p0 = tensor->op_params[2]; + const int32_t p1 = tensor->op_params[3]; + const int32_t d0 = tensor->op_params[4]; + const int32_t d1 = tensor->op_params[5]; + + const bool is_2D = tensor->op_params[6] == 1; + tensor_clone = ggml_im2col(ggml_ctx, src0_clone, src1_clone, s0, s1, p0, p1, d0, d1, is_2D, tensor->type); + } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { + const int32_t dim = tensor->op_params[0]; + const int32_t max_period = tensor->op_params[1]; + tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period); + } else if (tensor->op == GGML_OP_LEAKY_RELU) { + const float * op_params = (const float *)tensor->op_params; + tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false); } else { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); @@ -6834,7 +7205,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8); if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { - ggml_vk_print_tensor(ctx, tensor_clone, "tensor_clone"); + ggml_vk_print_tensor(tensor_clone, "tensor_clone"); } comp_size = ggml_nbytes(tensor_clone); @@ -6851,9 +7222,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * } ggml_free(ggml_ctx); + + VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")"); } -static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor * tensor) { +static void ggml_vk_check_results_1(ggml_tensor * tensor) { if (tensor->op == GGML_OP_TRANSPOSE) { return; } @@ -6977,11 +7350,6 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor * std::cerr << std::endl << "Correct:" << std::endl; ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0); std::cerr << std::endl; - std::cerr << std::endl << "Result:" << std::endl; - ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 1, 0); - std::cerr << std::endl << "Correct:" << std::endl; - ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 1, 0); - std::cerr << std::endl; std::vector done; ggml_vk_print_graph_origin(tensor, done); } @@ -7018,5 +7386,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor * if (ggml_backend_buffer_is_vk(tensor->buffer)) { free(tensor_data); } + + VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")"); } #endif diff --git a/src/ggml_vk_generate_shaders.py b/src/ggml_vk_generate_shaders.py deleted file mode 100644 index 7ebf6003f..000000000 --- a/src/ggml_vk_generate_shaders.py +++ /dev/null @@ -1,3149 +0,0 @@ -#!/usr/bin/env python - -import logging -import argparse -import asyncio -import os -import sys -from tempfile import gettempdir, NamedTemporaryFile - -logger = logging.getLogger("ggml-vk-generate-shaders") - -shader_f32 = """ -#define FLOAT_TYPE float -""" -shader_f16 = """ -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require -#define FLOAT_TYPE float16_t -""" -shader_int8_ext = """ -#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require -""" - -# Type-specific defines -shader_f32_defines = """ -#define QUANT_K 1 -#define QUANT_R 1 - -#define A_TYPE float -""" -shader_f16_defines = """ -#define QUANT_K 1 -#define QUANT_R 1 - -#define A_TYPE float16_t -""" -shader_q4_0_defines = """ -#define QUANT_K 32 -#define QUANT_R 2 - -struct block_q4_0 -{ - float16_t d; - uint8_t qs[16]; -}; - -#define A_TYPE block_q4_0 -""" -shader_q4_1_defines = """ -#define QUANT_K 32 -#define QUANT_R 2 - -struct block_q4_1 -{ - float16_t d; - float16_t m; - uint8_t qs[16]; -}; - -#define A_TYPE block_q4_1 -""" -shader_q5_0_defines = """ -#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require -#define QUANT_K 32 -#define QUANT_R 2 - -struct block_q5_0 -{ - float16_t d; - uint16_t qh[2]; - uint8_t qs[16]; -}; - -#define A_TYPE block_q5_0 -""" -shader_q5_1_defines = """ -#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require -#define QUANT_K 32 -#define QUANT_R 2 - -struct block_q5_1 -{ - float16_t d; - float16_t m; - uint qh; - uint8_t qs[16]; -}; - -#define A_TYPE block_q5_1 -""" -shader_q8_0_defines = """ -#define QUANT_K 32 -#define QUANT_R 1 - -struct block_q8_0 -{ - float16_t d; - int8_t qs[32]; -}; - -#define A_TYPE block_q8_0 -""" - -# K-quants -shader_q2_K_defines = """ -#define QUANT_K 256 - -struct block_q2_K -{ - uint8_t scales[QUANT_K/16]; - uint8_t qs[QUANT_K/4]; - f16vec2 d; -}; - -#define A_TYPE block_q2_K -""" -shader_q3_K_defines = """ -#define QUANT_K 256 - -struct block_q3_K -{ - uint8_t hmask[QUANT_K/8]; - uint8_t qs[QUANT_K/4]; - uint8_t scales[12]; - float16_t d; -}; - -#define A_TYPE block_q3_K -""" -shader_q4_K_defines = """ -#define QUANT_K 256 - -struct block_q4_K -{ - f16vec2 d; - uint8_t scales[3*QUANT_K/64]; - uint8_t qs[QUANT_K/2]; -}; - -#define A_TYPE block_q4_K -""" -shader_q5_K_defines = """ -#define QUANT_K 256 - -struct block_q5_K -{ - f16vec2 d; - uint8_t scales[12]; - uint8_t qh[QUANT_K/8]; - uint8_t qs[QUANT_K/2]; -}; - -#define A_TYPE block_q5_K -""" -shader_q6_K_defines = """ -#define QUANT_K 256 - -struct block_q6_K -{ - uint8_t ql[QUANT_K/2]; - uint8_t qh[QUANT_K/4]; - int8_t scales[QUANT_K/16]; - float16_t d; -}; - -#define A_TYPE block_q6_K -""" - -# Dequant functions -shader_float_dequant_func = """ -vec2 dequantize(uint ib, uint iqs, uint a_offset) { - return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); -} -""" - -shader_q4_0_dequant_func = """ -vec2 dequantize(uint ib, uint iqs, uint a_offset) { - const float d = float(data_a[a_offset + ib].d); - const uint vui = uint(data_a[a_offset + ib].qs[iqs]); - return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d; -} -""" - -shader_q4_1_dequant_func = """ -vec2 dequantize(uint ib, uint iqs, uint a_offset) { - const float d = float(data_a[a_offset + ib].d); - const float m = float(data_a[a_offset + ib].m); - const uint vui = uint(data_a[a_offset + ib].qs[iqs]); - return vec2(vui & 0xF, vui >> 4) * d + m; -} -""" - -shader_q5_0_dequant_func = """ -vec2 dequantize(uint ib, uint iqs, uint a_offset) { - const float d = float(data_a[a_offset + ib].d); - const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0]; - const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); - const uint vui = uint(data_a[a_offset + ib].qs[iqs]); - return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d; -} -""" - -shader_q5_1_dequant_func = """ -vec2 dequantize(uint ib, uint iqs, uint a_offset) { - const float d = float(data_a[a_offset + ib].d); - const float m = float(data_a[a_offset + ib].m); - const uint uint_qh = data_a[a_offset + ib].qh; - const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); - const uint vui = uint(data_a[a_offset + ib].qs[iqs]); - return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m; -} -""" - -shader_q8_0_dequant_func = """ -vec2 dequantize(uint ib, uint iqs, uint a_offset) { - const float d = float(data_a[a_offset + ib].d); - return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d; -} -""" - -# MULMAT - -mulmat_head = """#version 450 - -#extension GL_EXT_control_flow_attributes : enable -#extension GL_EXT_shader_16bit_storage : require - -#ifdef MUL_MAT_ID -#extension GL_EXT_buffer_reference2 : require -#extension GL_EXT_nonuniform_qualifier : require -#extension GL_EXT_scalar_block_layout : require -#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require - -#define EXPERT_COUNT 8 -#endif - -#ifndef LOAD_VEC_A -#define LOAD_VEC_A 1 -#endif -#ifndef LOAD_VEC_B -#define LOAD_VEC_B 1 -#endif -""" - -mulmat_body1 = """ -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; - -#ifdef MUL_MAT_ID -layout (binding = 3) readonly buffer IDS {int data_ids[];}; -#endif - -layout (push_constant) uniform parameter -{ - uint M; - uint N; - uint K; - uint stride_a; - uint stride_b; - uint stride_d; - uint k_split; - - uint ne02; - uint ne12; - uint broadcast2; - uint broadcast3; - - uint batch_stride_a; - uint batch_stride_b; - uint batch_stride_d; - -#ifdef MUL_MAT_ID - uint expert_stride_a; - uint expert_stride_b0; - uint expert_stride_b1; - uint expert_stride_d; - - uint ids_stride; - - uint n_as; - uint nei0; - uint nei1; - uint nbi1; - uint ne11; -#endif -} p; - -layout (constant_id = 1) const uint BM = 64; -layout (constant_id = 2) const uint BN = 64; -layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant -layout (constant_id = 4) const uint WM = 32; -layout (constant_id = 5) const uint WN = 32; -layout (constant_id = 6) const uint WMITER = 2; -layout (constant_id = 7) const uint TM = 4; -layout (constant_id = 8) const uint TN = 2; -layout (constant_id = 9) const uint WARP = 32; - -shared FLOAT_TYPE buf_a[BM * (BK+1)]; -shared FLOAT_TYPE buf_b[BN * (BK+1)]; - -#ifdef MUL_MAT_ID -shared u8vec2 rowids[2048]; -#endif - -void main() { -#ifdef MUL_MAT_ID - const uint batch_idx = gl_GlobalInvocationID.z / p.n_as; - const uint expert_idx = gl_GlobalInvocationID.z % p.n_as; -#else - const uint batch_idx = gl_GlobalInvocationID.z; -#endif - - const uint i13 = batch_idx / p.ne12; - const uint i12 = batch_idx % p.ne12; - - const uint i03 = i13 / p.broadcast3; - const uint i02 = i12 / p.broadcast2; - - const uint batch_idx_a = i03 * p.ne02 + i02; - - const uint blocks_m = (p.M + BM - 1) / BM; - const uint ir = gl_WorkGroupID.x % blocks_m; - const uint ik = gl_WorkGroupID.x / blocks_m; - const uint ic = gl_WorkGroupID.y; - - const uint warp_i = gl_LocalInvocationID.x / WARP; - const uint warp_r = warp_i % (BM / WM); - const uint warp_c = warp_i / (BM / WM); - - const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); - const uint WSUBM = WM / WMITER; - const uint WSUBN = WN / WNITER; - - const uint tiw = gl_LocalInvocationID.x % WARP; - const uint tiwr = tiw % (WSUBM / TM); - const uint tiwc = tiw / (WSUBM / TM); - - const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); - const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); - const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B); - const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B); - - const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK; - const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK; - -#ifdef MUL_MAT_ID - uint _ne1 = 0; - for (uint ii1 = 0; ii1 < p.nei1; ii1++) { - for (uint ii0 = 0; ii0 < p.nei0; ii0++) { - if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { - rowids[_ne1] = u8vec2(ii0, ii1); - _ne1++; - } - } - } - - const u8vec2 id = rowids[ir * BN + ic]; -#endif - - const uint start_k = ik * p.k_split; - const uint end_k = min(p.K, (ik + 1) * p.k_split); - - uint pos_a = ( -#ifdef MUL_MAT_ID - expert_idx * p.expert_stride_a + -#endif - batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC_A; - uint pos_b = ( -#ifdef MUL_MAT_ID - id.y * p.expert_stride_b1 + - (id.x % p.ne11) * p.expert_stride_b0 + -#endif - batch_idx * p.batch_stride_b + - ic * BN * p.stride_b + start_k) / LOAD_VEC_B; - - float sums[WMITER * TM * WNITER * TN]; - FLOAT_TYPE cache_a[WMITER * TM]; - FLOAT_TYPE cache_b[WNITER * TN]; - - [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { - sums[i] = 0.0f; - } - - [[unroll]] for (uint block = start_k; block < end_k; block += BK) { - [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {""" - -mulmat_load_scalar = """ -#if LOAD_VEC_A == 8 - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; - buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); - buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); - buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); - buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w); - buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x); - buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y); - buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z); - buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); -#elif LOAD_VEC_A == 4 - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; - buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); - buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); - buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); - buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); -#else - if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { - buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); - } else { - buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f); - } -#endif -""" - -mulmat_load_q4_0 = """ - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; - - const uint ib = idx / 16; - const uint iqs = idx & 0xF; - - const float d = float(data_a[ib].d); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d; - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);""" - -mulmat_load_q4_1 = """ - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; - - const uint ib = idx / 16; - const uint iqs = idx & 0xF; - - const float d = float(data_a[ib].d); - const float m = float(data_a[ib].m); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m; - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);""" - -mulmat_load_q5_0 = """ - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; - - const uint ib = idx / 16; - const uint iqs = idx & 0xF; - - const float d = float(data_a[ib].d); - const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; - const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d; - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);""" - -mulmat_load_q5_1 = """ - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; - - const uint ib = idx / 16; - const uint iqs = idx & 0xF; - - const float d = float(data_a[ib].d); - const float m = float(data_a[ib].m); - const uint uint_qh = data_a[ib].qh; - const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m; - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);""" - -mulmat_load_q8_0 = """ - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 16; - const uint iqs = (idx & 0xF) * 2; - - const float d = float(data_a[ib].d); - const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d; - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);""" - - -mulmat_load_q2_K = """ - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 - - const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 - const uint scalesi = iqs / 8; // 0..15 - const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 - - const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]); - const uint scales = data_a[ib].scales[scalesi]; - const vec2 d = vec2(data_a[ib].d); - - const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);""" - -mulmat_load_q3_K = """ - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 - - const uint n = iqs / 64; // 0,1 - const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 - const uint hmi = (iqs % 16) * 2; // 0,2,4..30 - const uint j = (iqs % 64) / 4; // 0..3 - const uint is = iqs / 8; // 0..15 - const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 - const uint qsshift = halfsplit * 2; // 0,2,4,6 - const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 - - const int8_t us = int8_t(is < 4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) : - is < 8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) : - is < 12 ? (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) : - (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4)); - const float dl = float(data_a[ib].d) * float(us - 32); - - buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4))); - buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));""" - -mulmat_load_q4_K = """ - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 - - const uint n = iqs / 32; // 0,1,2,3 - const uint b = (iqs % 32) / 16; // 0,1 - const uint is = 2 * n + b; // 0..7 - const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 - - const vec2 loadd = vec2(data_a[ib].d); - - uint8_t sc; - uint8_t mbyte; - if (is < 4) { - sc = uint8_t(data_a[ib].scales[is ] & 63); - mbyte = uint8_t(data_a[ib].scales[is + 4] & 63); - } else { - sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4)); - mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4)); - } - const float d = loadd.x * sc; - const float m = loadd.y * mbyte; - - buf_a[buf_idx ] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) - m); - buf_a[buf_idx + 1] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) - m);""" - -mulmat_load_q5_K = """ - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 - - const uint n = iqs / 32; // 0,1,2,3 - const uint b = (iqs % 32) / 16; // 0,1 - const uint is = 2 * n + b; // 0..7 - const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 - const uint qhi = (iqs % 16) * 2; // 0,2,4..30 - - const uint8_t hm = uint8_t(1 << (iqs / 16)); - - const vec2 loadd = vec2(data_a[ib].d); - - uint8_t sc; - uint8_t mbyte; - if (is < 4) { - sc = uint8_t(data_a[ib].scales[is ] & 63); - mbyte = uint8_t(data_a[ib].scales[is + 4] & 63); - } else { - sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4)); - mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4)); - } - const float d = loadd.x * sc; - const float m = loadd.y * mbyte; - - buf_a[buf_idx ] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0)) - m); - buf_a[buf_idx + 1] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0)) - m);""" - -mulmat_load_q6_K = """ - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 - - const uint n = iqs / 64; // 0,1 - const uint b = (iqs % 64) / 32; // 0,1 - const uint is_b = (iqs % 16) / 8; // 0,1 - const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 - const uint is = 8 * n + qhshift + is_b; // 0..15 - const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 - const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 - - const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]); - - buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32)); - buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));""" - -mulmat_body2 = """ - } - [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { -#if LOAD_VEC_B == 8 - const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; - const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; - buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); - buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); - buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); - buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w); - buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x); - buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y); - buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z); - buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w); -#elif LOAD_VEC_B == 4 - const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; - const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; - buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); - buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y); - buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); - buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); -#else - if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { - buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); - } else { - buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f); - } -#endif - } - - barrier(); - - pos_a += BK / LOAD_VEC_A; - pos_b += BK / LOAD_VEC_B; - - for (uint i = 0; i < BK; i++) { - // Load from shared into cache - [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { - [[unroll]] for (uint j = 0; j < TM; j++) { - cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i]; - } - } - [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { - [[unroll]] for (uint j = 0; j < TN; j++) { - cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i]; - } - } - - [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { - [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { - [[unroll]] for (uint cc = 0; cc < TN; cc++) { - [[unroll]] for (uint cr = 0; cr < TM; cr++) { - sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]); - } - } - } - } - } - - barrier(); - } - - const uint dr = ir * BM + warp_r * WM; - const uint dc = ic * BN + warp_c * WN; - - const uint offsets = -#ifdef MUL_MAT_ID - expert_idx * p.expert_stride_d + -#endif - batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; - - [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { - [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { - - const uint dr_warp = dr + wsir * WSUBM + tiwr * TM; - const uint dc_warp = dc + wsic * WSUBN + tiwc * TN; - [[unroll]] for (uint cc = 0; cc < TN; cc++) { - [[unroll]] for (uint cr = 0; cr < TM; cr++) { - if (dr_warp + cr < p.M && dc_warp + cc < p.N) { - data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); - } - } - } - } - } -} -""" - -mulmat_split_k_reduce_src = """#version 450 - -#extension GL_EXT_control_flow_attributes : enable - -layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {float data_a[];}; -layout (binding = 1) writeonly buffer D {float data_d[];}; - -layout (push_constant) uniform parameter { - uint ne; - uint k_num; -} p; - -void main() { - const uint idx = gl_GlobalInvocationID.x; - - if (idx >= p.ne) { - return; - } - - float result = 0.0f; - - [[unroll]] for (uint i = 0; i < p.k_num; i++) { - result += data_a[i * p.ne + idx]; - } - - data_d[idx] = result; -} -""" - -# DEQUANT SHADER -dequant_head = """#version 450 - -#extension GL_EXT_control_flow_attributes : require -#extension GL_EXT_shader_16bit_storage : require - -layout (push_constant) uniform parameter -{ - uint M; - uint K; - uint stride_a; - uint stride_b; - uint nel; -} p; -""" - -dequant_f32_body = """ -layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {float data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; - -void main() { - const uint i = gl_GlobalInvocationID.x * 16; - - if (i >= p.nel) { - return; - } - - [[unroll]] for (uint l = 0; l < 16; l++) { - data_b[i + l] = D_TYPE(data_a[i + l]); - } -} -""" - -dequant_q4_0_body = """ -layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {block_q4_0 data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; - -void main() { - const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; - - const uint tid = gl_LocalInvocationID.x % 64; - const uint il = tid/32; - const uint ir = tid%32; - const uint ib = 32*i + ir; - if (ib >= p.nel / 32) { - return; - } - - const uint b_idx = 1024*i + 32*ir + 8*il; - - const float d = float(data_a[ib].d); - const float dm = -8.0f * d; - - const uint q_idx = 8*il; - - [[unroll]] for (uint l = 0; l < 8; ++l) { - data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + dm); - data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + dm); - } -} -""" - -dequant_q4_1_body = """ -layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {block_q4_1 data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; - -void main() { - const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; - - const uint tid = gl_LocalInvocationID.x % 64; - const uint il = tid/32; - const uint ir = tid%32; - const uint ib = 32*i + ir; - if (ib >= p.nel / 32) { - return; - } - - const uint b_idx = 1024*i + 32*ir + 8*il; - - const float d = float(data_a[ib].d); - const float m = float(data_a[ib].m); - - const uint q_idx = 8*il; - - [[unroll]] for (uint l = 0; l < 8; ++l) { - data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m); - data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m); - } -} -""" - -dequant_q5_0_body = """ -layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {block_q5_0 data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; - -void main() { - const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; - - const uint tid = gl_LocalInvocationID.x % 64; - const uint il = tid/32; - const uint ir = tid%32; - const uint ib = 32*i + ir; - if (ib >= p.nel / 32) { - return; - } - - const uint b_idx = 1024*i + 32*ir + 8*il; - - const float d = float(data_a[ib].d); - const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; - - const uint q_idx = 8*il; - - [[unroll]] for (uint l = 0; l < 8; ++l) { - const uint iqs = q_idx + l; - const uint vui = uint(data_a[ib].qs[iqs]); - data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f)); - data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f)); - } -} -""" - -dequant_q5_1_body = """ -layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {block_q5_1 data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; - -void main() { - const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; - - const uint tid = gl_LocalInvocationID.x % 64; - const uint il = tid/32; - const uint ir = tid%32; - const uint ib = 32*i + ir; - if (ib >= p.nel / 32) { - return; - } - - const uint b_idx = 1024*i + 32*ir + 8*il; - - const float d = float(data_a[ib].d); - const float m = float(data_a[ib].m); - const uint qh = data_a[ib].qh; - - const uint q_idx = 8*il; - - [[unroll]] for (uint l = 0; l < 8; ++l) { - const uint iqs = q_idx + l; - const uint vui = uint(data_a[ib].qs[iqs]); - data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m); - data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m); - } -} -""" - -dequant_q8_0_body = """ -layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {block_q8_0 data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; - -void main() { - const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; - - const uint tid = gl_LocalInvocationID.x % 64; - const uint il = tid/32; - const uint ir = tid%32; - const uint ib = 32*i + ir; - if (ib >= p.nel / 32) { - return; - } - - const uint b_idx = 1024*i + 32*ir + 16*il; - - const float d = float(data_a[ib].d); - - const uint q_idx = 16*il; - - [[unroll]] for (uint l = 0; l < 16; l += 2) { - data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]); - data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]); - } -} -""" - -# K-quants -dequant_q2_K_body = """ -layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; - -void main() { - [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { - const uint i = gl_WorkGroupID.x * 256 + wgy; - if (i >= p.M * p.K / QUANT_K) { - return; - } - - const uint tid = gl_LocalInvocationID.x; - const uint ip = tid / 32; - const uint il = tid - 32 * ip; - const uint is = 8 * ip + il / 16; - - const uint y_idx = i * QUANT_K + 128 * ip + il; - - const uint ql_idx = 32 * ip + il; - const uint8_t qs = data_a[i].qs[32 * ip + il]; - - FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); - FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); - data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4)); - data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4)); - data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4)); - data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4)); - } -} -""" -dequant_q3_K_body = """ -layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; - -void main() { - [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { - const uint i = uint(gl_WorkGroupID.x * 256 + wgy); - if (i >= p.M * p.K / QUANT_K) { - return; - } - - const uint r = gl_LocalInvocationID.x / 4; - const uint tid = r / 2; - const uint is0 = r % 2; - const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4); - const uint n = tid / 4; - const uint j = tid - 4*n; - - const uint8_t m = uint8_t(1 << (4*n + j)); - const uint is = 8*n + 2*j + is0; - const uint shift = 2*j; - - const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) : - is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) : - is < 12 ? (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) : - (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4)); - const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d); - const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32); - - const uint y_idx = i * QUANT_K + 128 * n + 32 * j; - const uint qs_idx = 32*n; - - for (uint l = l0; l < l0 + 4; ++l) { - data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4))); - } - } -} -""" -dequant_q4_K_body = """ -layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; - -void main() { - [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { - const uint i = gl_WorkGroupID.x * 256 + wgy; - if (i >= p.M * p.K / QUANT_K) { - return; - } - - const uint tid = gl_LocalInvocationID.x; - const uint il = tid / 8; - const uint ir = tid % 8; - const uint is = 2 * il; - const uint n = 4; - - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); - - const uint y_idx = i * QUANT_K + 64 * il + n * ir; - const uint qs_idx = 32*il + n * ir; - - uint8_t sc; - uint8_t m; - if (is < 4) { - sc = uint8_t(data_a[i].scales[is] & 63); - m = uint8_t(data_a[i].scales[is + 4] & 63); - } else { - sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4)); - m = uint8_t((data_a[i].scales[is + 4] >> 4) | ((data_a[i].scales[is ] >> 6) << 4)); - } - const FLOAT_TYPE d1 = dall * sc; - const FLOAT_TYPE m1 = dmin * m; - - if (is < 4) { - sc = uint8_t(data_a[i].scales[is + 1] & 63); - m = uint8_t(data_a[i].scales[is + 5] & 63); - } else { - sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4)); - m = uint8_t((data_a[i].scales[is + 5] >> 4) | ((data_a[i].scales[is + 1] >> 6) << 4)); - } - const FLOAT_TYPE d2 = dall * sc; - const FLOAT_TYPE m2 = dmin * m; - - [[unroll]] for (uint l = 0; l < n; ++l) { - data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] & 0xF) - m1); - data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] >> 4) - m2); - } - } -} -""" -dequant_q5_K_body = """ -layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; - -void main() { - [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { - const uint i = gl_WorkGroupID.x * 256 + wgy; - if (i >= p.M * p.K / QUANT_K) { - return; - } - - const uint tid = gl_LocalInvocationID.x; - const uint il = tid / 16; - const uint ir = tid % 16; - const uint is = 2 * il; - - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); - - const uint y_idx = i * QUANT_K + 64 * il + 2 * ir; - const uint qs_idx = 32*il + 2 * ir; - const uint qh_idx = 2 * ir; - - uint8_t sc; - uint8_t m; - if (is < 4) { - sc = uint8_t(data_a[i].scales[is] & 63); - m = uint8_t(data_a[i].scales[is + 4] & 63); - } else { - sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4)); - m = uint8_t((data_a[i].scales[is + 4] >> 4) | ((data_a[i].scales[is ] >> 6) << 4)); - } - const FLOAT_TYPE d1 = dall * sc; - const FLOAT_TYPE m1 = dmin * m; - - if (is < 4) { - sc = uint8_t(data_a[i].scales[is + 1] & 63); - m = uint8_t(data_a[i].scales[is + 5] & 63); - } else { - sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4)); - m = uint8_t((data_a[i].scales[is + 5] >> 4) | ((data_a[i].scales[is + 1] >> 6) << 4)); - } - const FLOAT_TYPE d2 = dall * sc; - const FLOAT_TYPE m2 = dmin * m; - - const uint8_t hm1 = uint8_t(1 << (2 * il )); - const uint8_t hm2 = uint8_t(1 << (2 * il + 1)); - data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx ] & 0xF) + (((data_a[i].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1); - data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] & 0xF) + (((data_a[i].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1); - data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx ] >> 4) + (((data_a[i].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2); - data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] >> 4) + (((data_a[i].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2); - } -} -""" -dequant_q6_K_body = """ -layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; - -void main() { - [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { - const uint i = gl_WorkGroupID.x * 256 + wgy; - if (i >= p.M * p.K / QUANT_K) { - return; - } - const uint tid = gl_LocalInvocationID.x; - const uint ip = tid / 32; - const uint il = tid - 32 * ip; - const uint is = 8 * ip + il / 16; - - const uint y_idx = i * QUANT_K + 128 * ip + il; - - const uint ql_idx = 64 * ip + il; - const uint8_t qh = data_a[i].qh[32 * ip + il]; - - const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d); - - data_b[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))); - data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))); - data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))); - data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))); - } -} -""" - -# Mul Mat Vec -mul_mat_vec_head = """#version 450 - -#extension GL_EXT_control_flow_attributes : enable -#extension GL_EXT_shader_16bit_storage : require -#extension GL_EXT_shader_8bit_storage : require - -#ifdef MUL_MAT_ID -#define EXPERT_COUNT 8 -#endif -""" - - -mul_mat_vec_layout = """ -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; -#ifdef MUL_MAT_ID -layout (binding = 3) readonly buffer IDS {int data_ids[];}; -#endif - -layout (push_constant) uniform parameter -{ - uint ncols; - uint stride_a; - uint stride_b; - uint stride_d; - - uint ne02; - uint ne12; - uint broadcast2; - uint broadcast3; - - uint batch_stride_a; - uint batch_stride_b; - uint batch_stride_d; - -#ifdef MUL_MAT_ID - uint expert_stride_a; - uint expert_stride_b0; - uint expert_stride_b1; - uint expert_stride_d0; - uint expert_stride_d1; - - uint ne11; - uint nei0; - uint nbi1; - uint n_as; -#endif -} p; -""" - -mul_mat_vec_body = """ -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (constant_id = 0) const uint BLOCK_SIZE = 32; - -shared FLOAT_TYPE tmp[BLOCK_SIZE]; - -void main() { - const uint row = gl_WorkGroupID.x; - const uint tid = gl_LocalInvocationID.x; - const uint batch_idx = gl_GlobalInvocationID.y; -#ifdef MUL_MAT_ID - const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; - const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; -#endif - - const uint i13 = batch_idx / p.ne12; - const uint i12 = batch_idx % p.ne12; - - const uint i03 = i13 / p.broadcast3; - const uint i02 = i12 / p.broadcast2; - - const uint batch_idx_a = i03 * p.ne02 + i02; - -#ifdef MUL_MAT_ID - const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; -#endif - - const uint a_offset = -#ifdef MUL_MAT_ID - expert_id * p.expert_stride_a + -#endif - batch_idx_a * p.batch_stride_a; - const uint b_offset = -#ifdef MUL_MAT_ID - (expert_idx0 % p.ne11) * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_b; - const uint d_offset = -#ifdef MUL_MAT_ID - expert_idx0 * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_d; - - const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; - - tmp[tid] = FLOAT_TYPE(0.0f); - - [[unroll]] for (uint i = 0; i < p.ncols/BLOCK_SIZE; i += 2) { - const uint col = i*BLOCK_SIZE + 2*tid; - const uint ib = (row*p.ncols + col)/QUANT_K; // block index - const uint iqs = (col%QUANT_K)/QUANT_R; // quant index - const uint iybs = col - col%QUANT_K; // y block start index - - vec2 v = dequantize(ib, iqs, a_offset / QUANT_K); - - // matrix multiplication - tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[b_offset + iybs + iqs]) + - FLOAT_TYPE(v.y) * FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]); - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - if (tid == 0) { - data_d[d_offset + row] = D_TYPE(tmp[0]); - } -} -""" - -# K-quants -mul_mat_vec_q2_K_body = """ -layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; - -shared FLOAT_TYPE tmp[32]; - -void main() { - const uint row = gl_WorkGroupID.x; - const uint batch_idx = gl_GlobalInvocationID.y; -#ifdef MUL_MAT_ID - const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; - const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; -#endif - - const uint i13 = batch_idx / p.ne12; - const uint i12 = batch_idx % p.ne12; - - const uint i03 = i13 / p.broadcast3; - const uint i02 = i12 / p.broadcast2; - - const uint batch_idx_a = i03 * p.ne02 + i02; - -#ifdef MUL_MAT_ID - const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; -#endif - - const uint a_offset = -#ifdef MUL_MAT_ID - expert_id * p.expert_stride_a + -#endif - batch_idx_a * p.batch_stride_a; - const uint b_offset = -#ifdef MUL_MAT_ID - (expert_idx0 % p.ne11) * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_b; - const uint d_offset = -#ifdef MUL_MAT_ID - expert_idx0 * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_d; - - const uint num_blocks_per_row = p.ncols / QUANT_K; - const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - - const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - - const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 - - const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = tid - step*v_im; // 0...15 or 0...7 - - const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15 - const uint q_offset = 32*v_im + l0; - const uint s_offset = 8*v_im; - const uint y_offset = 128*v_im + l0; - - tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp - - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - const uint y_idx = i * QUANT_K + y_offset; - - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y); - - FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); - FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - sum1 += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3) - + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3); - sum2 += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF) - + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF); - } - tmp[16 * ix + tid] += dall * sum1 - dmin * sum2; - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (uint s = 16; s > 0; s >>= 1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - if (tid == 0) { - data_d[d_offset + row] = D_TYPE(tmp[0]); - } -} -""" -mul_mat_vec_q3_K_body = """ -layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; - -shared FLOAT_TYPE tmp[32]; - -void main() { - const uint row = gl_WorkGroupID.x; - const uint batch_idx = gl_GlobalInvocationID.y; -#ifdef MUL_MAT_ID - const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; - const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; -#endif - - const uint i13 = batch_idx / p.ne12; - const uint i12 = batch_idx % p.ne12; - - const uint i03 = i13 / p.broadcast3; - const uint i02 = i12 / p.broadcast2; - - const uint batch_idx_a = i03 * p.ne02 + i02; - -#ifdef MUL_MAT_ID - const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; -#endif - - const uint a_offset = -#ifdef MUL_MAT_ID - expert_id * p.expert_stride_a + -#endif - batch_idx_a * p.batch_stride_a; - const uint b_offset = -#ifdef MUL_MAT_ID - (expert_idx0 % p.ne11) * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_b; - const uint d_offset = -#ifdef MUL_MAT_ID - expert_idx0 * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_d; - - const uint num_blocks_per_row = p.ncols / QUANT_K; - const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - - const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - - const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 - - const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = tid - step*v_im; // 0...15 or 0...7 - - const uint8_t m = uint8_t(1 << (4 * v_im)); - - const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15 - const uint q_offset = 32*v_im + l0; - const uint y_offset = 128*v_im + l0; - - tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp - - const uint s_shift = 4 * v_im; - - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - const uint y_idx = i * QUANT_K + y_offset; - - const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - sum += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)) - + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)); - } - tmp[16 * ix + tid] += d * sum; - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (uint s = 16; s > 0; s >>= 1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - if (tid == 0) { - data_d[d_offset + row] = D_TYPE(tmp[0]); - } -} -""" -mul_mat_vec_q4_K_body = """ -layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; - -shared FLOAT_TYPE tmp[32]; - -void main() { - const uint row = gl_WorkGroupID.x; - const uint batch_idx = gl_GlobalInvocationID.y; -#ifdef MUL_MAT_ID - const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; - const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; -#endif - - const uint i13 = batch_idx / p.ne12; - const uint i12 = batch_idx % p.ne12; - - const uint i03 = i13 / p.broadcast3; - const uint i02 = i12 / p.broadcast2; - - const uint batch_idx_a = i03 * p.ne02 + i02; - -#ifdef MUL_MAT_ID - const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; -#endif - - const uint a_offset = -#ifdef MUL_MAT_ID - expert_id * p.expert_stride_a + -#endif - batch_idx_a * p.batch_stride_a; - const uint b_offset = -#ifdef MUL_MAT_ID - (expert_idx0 % p.ne11) * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_b; - const uint d_offset = -#ifdef MUL_MAT_ID - expert_idx0 * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_d; - - const uint num_blocks_per_row = p.ncols / QUANT_K; - const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - - const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - - const uint step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 - - const uint il = tid/step; // 0...3 - const uint ir = tid - step*il; // 0...7 or 0...3 - const uint n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 - - const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const uint v_in = il % 2; - - const uint l0 = n * (2 * ir + v_in); // 0...15 - const uint q_offset = 32*v_im + l0; - const uint y_offset = 64*v_im + l0; - - tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp - - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - const uint y1_idx = i * QUANT_K + y_offset; - const uint y2_idx = y1_idx + 128; - - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y); - - const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f); - const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f); - const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f); - const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f); - const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2)); - const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2)); - const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2)); - const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2)); - -#if K_QUANTS_PER_ITERATION == 2 - const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf); - const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf); - const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] & 0xf); - const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] & 0xf); - const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4); - const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4); - const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] >> 4); - const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] >> 4); - const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf); - const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf); - const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] & 0xf); - const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] & 0xf); - const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4); - const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4); - const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] >> 4); - const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] >> 4); - - const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1 + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3); - const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_5 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7); - const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx]) * q4_8 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_9 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * q4_10 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11); - const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_12 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_13 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * q4_14 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15); - const FLOAT_TYPE smin = FLOAT_TYPE( - FLOAT_TYPE(data_b[b_offset + y1_idx ]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx ]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7 - + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7 - + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * sc7 - + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7 - ); - tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin); -#else - const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf); - const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf); - const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4); - const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4); - const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf); - const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf); - const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4); - const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4); - - const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx ]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1); - const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3); - const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx ]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_5); - const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7); - const FLOAT_TYPE smin = FLOAT_TYPE( - FLOAT_TYPE(data_b[b_offset + y1_idx]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7 - + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7 - ); - - tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin); -#endif - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (uint s = 16; s > 0; s >>= 1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - if (tid == 0) { - data_d[d_offset + row] = D_TYPE(tmp[0]); - } -} -""" -mul_mat_vec_q5_K_body = """ -layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; - -shared FLOAT_TYPE tmp[32]; - -void main() { - const uint row = gl_WorkGroupID.x; - const uint batch_idx = gl_GlobalInvocationID.y; -#ifdef MUL_MAT_ID - const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; - const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; -#endif - - const uint i13 = batch_idx / p.ne12; - const uint i12 = batch_idx % p.ne12; - - const uint i03 = i13 / p.broadcast3; - const uint i02 = i12 / p.broadcast2; - - const uint batch_idx_a = i03 * p.ne02 + i02; - -#ifdef MUL_MAT_ID - const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; -#endif - - const uint a_offset = -#ifdef MUL_MAT_ID - expert_id * p.expert_stride_a + -#endif - batch_idx_a * p.batch_stride_a; - const uint b_offset = -#ifdef MUL_MAT_ID - (expert_idx0 % p.ne11) * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_b; - const uint d_offset = -#ifdef MUL_MAT_ID - expert_idx0 * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_d; - - const uint num_blocks_per_row = p.ncols / QUANT_K; - const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - - const uint tid = gl_LocalInvocationID.x/2; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%2; // 0 or 0, 1 - - const uint il = tid/4; // 0...3 - const uint ir = tid - 4*il; // 0...7 or 0...3 - - const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const uint v_in = il % 2; - - const uint l0 = 4*ir + 2*v_in; // 0...15 - const uint q_offset = 32*v_im + l0; - const uint y_offset = 64*v_im + l0; - - const uint8_t hm1 = uint8_t(1 << (2*v_im)); - const uint8_t hm2 = uint8_t(hm1 << 4); - - tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp - - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) { - const uint y1_idx = i * QUANT_K + y_offset; - const uint y2_idx = y1_idx + 128; - - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y); - - const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f); - const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f); - const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f); - const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f); - const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2)); - const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2)); - const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2)); - const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2)); - - const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf); - const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf); - const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] & 0xf); - const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] & 0xf); - const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4); - const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4); - const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] >> 4); - const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] >> 4); - const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf); - const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf); - const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] & 0xf); - const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] & 0xf); - const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4); - const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4); - const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] >> 4); - const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] >> 4); - - const FLOAT_TYPE sx = FLOAT_TYPE( - FLOAT_TYPE(data_b[b_offset + y1_idx ]) * (q4_0 + (((data_a[ib0 + i].qh[l0 ] & hm1) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * (q4_1 + (((data_a[ib0 + i].qh[l0 + 1] & hm1) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) * (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0)) - ); - const FLOAT_TYPE sy = FLOAT_TYPE( - FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * (q4_4 + (((data_a[ib0 + i].qh[l0 ] & (hm1 << 1)) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * (q4_5 + (((data_a[ib0 + i].qh[l0 + 1] & (hm1 << 1)) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) * (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0)) - ); - const FLOAT_TYPE sz = FLOAT_TYPE( - FLOAT_TYPE(data_b[b_offset + y2_idx ]) * (q4_8 + (((data_a[ib0 + i].qh[l0 ] & hm2) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * (q4_9 + (((data_a[ib0 + i].qh[l0 + 1] & hm2) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) * (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0)) - ); - const FLOAT_TYPE sw = FLOAT_TYPE( - FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * (q4_12 + (((data_a[ib0 + i].qh[l0 ] & (hm2 << 1)) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * (q4_13 + (((data_a[ib0 + i].qh[l0 + 1] & (hm2 << 1)) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) * (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)) - + FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0)) - ); - const FLOAT_TYPE smin = FLOAT_TYPE( - (FLOAT_TYPE(data_b[b_offset + y1_idx]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17])) * sc2 + (FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49])) * sc3 - + (FLOAT_TYPE(data_b[b_offset + y2_idx]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17])) * sc6 + (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7 - ); - tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin); - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (uint s = 16; s > 0; s >>= 1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - if (tid == 0) { - data_d[d_offset + row] = D_TYPE(tmp[0]); - } -} -""" -mul_mat_vec_q6_K_body = """ -layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; - -shared FLOAT_TYPE tmp[32]; - -void main() { - const uint row = gl_WorkGroupID.x; - const uint batch_idx = gl_GlobalInvocationID.y; -#ifdef MUL_MAT_ID - const uint expert_idx1 = gl_GlobalInvocationID.z / p.nei0; - const uint expert_idx0 = gl_GlobalInvocationID.z % p.nei0; -#endif - - const uint i13 = batch_idx / p.ne12; - const uint i12 = batch_idx % p.ne12; - - const uint i03 = i13 / p.broadcast3; - const uint i02 = i12 / p.broadcast2; - - const uint batch_idx_a = i03 * p.ne02 + i02; - -#ifdef MUL_MAT_ID - const uint expert_id = data_ids[expert_idx1 * p.nbi1 + expert_idx0]; -#endif - - const uint a_offset = -#ifdef MUL_MAT_ID - expert_id * p.expert_stride_a + -#endif - batch_idx_a * p.batch_stride_a; - const uint b_offset = -#ifdef MUL_MAT_ID - (expert_idx0 % p.ne11) * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_b; - const uint d_offset = -#ifdef MUL_MAT_ID - expert_idx0 * p.expert_stride_b0 + - expert_idx1 * p.expert_stride_b1 + -#endif - batch_idx * p.batch_stride_d; - - const uint num_blocks_per_row = p.ncols / QUANT_K; - const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - - const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - - const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 - - const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = tid - step*v_im; // 0...15 or 0...7 - -#if K_QUANTS_PER_ITERATION == 1 - const uint l0 = v_in; // 0...15 - const uint is = 0; -#else - const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 - const uint is = v_in / 4; -#endif - - const uint ql_offset = 64*v_im + l0; - const uint qh_offset = 32*v_im + l0; - const uint s_offset = 8*v_im + is; - const uint y_offset = 128*v_im + l0; - - tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp - - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - const uint y_idx = i * QUANT_K + y_offset; - - const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - -#if K_QUANTS_PER_ITERATION == 1 - FLOAT_TYPE sum = FLOAT_TYPE(data_b[b_offset + y_idx + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32); - tmp[16 * ix + tid] += sum; -#else - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 4; ++l) { - sum += FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32) - + FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32); - } - tmp[16 * ix + tid] += sum; -#endif - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (uint s = 16; s > 0; s >>= 1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - if (tid == 0) { - data_d[d_offset + row] = D_TYPE(tmp[0]); - } -} -""" - -mul_mat_p021_src = """#version 450 - -#extension GL_EXT_control_flow_attributes : enable -#extension GL_EXT_shader_16bit_storage : require - -#define BLOCK_SIZE 32 -#define FLOAT_TYPE float - -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; - -layout (push_constant) uniform parameter -{ - uint ncols_x; - uint nrows_x; - uint nchannels_x; - uint nchannels_y; - uint b_offset; - uint d_offset; -} p; - -shared FLOAT_TYPE tmp[BLOCK_SIZE]; - -void main() { - const uint tid = gl_LocalInvocationID.x; - const uint row_x = gl_GlobalInvocationID.y; - const uint channel = gl_GlobalInvocationID.z; - const uint channel_x = channel / (p.nchannels_y / p.nchannels_x); - - const uint nrows_y = p.ncols_x; - const uint nrows_dst = p.nrows_x; - const uint row_dst = row_x; - - tmp[tid] = FLOAT_TYPE(0.0f); - - for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { - const uint col_x = col_x0 + tid; - - if (col_x >= p.ncols_x) { - break; - } - - // x is transposed and permuted - const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); - - const uint row_y = col_x; - - // y is not transposed but permuted - const uint iy = channel*nrows_y + row_y; - - tmp[tid] += xi * FLOAT_TYPE(data_b[iy]); - } - - // dst is not transposed and not permuted - const uint idst = channel*nrows_dst + row_dst; - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - - if (tid == 0) { - dst[idst] = tmp[0]; - } -} -""" - - -mul_mat_nc_src = """#version 450 - -#extension GL_EXT_control_flow_attributes : enable -#extension GL_EXT_shader_16bit_storage : require - -#define BLOCK_SIZE 32 -#define FLOAT_TYPE float - -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; - -layout (push_constant) uniform parameter -{ - uint ncols_x; - uint nrows_x; - uint row_stride_x; - uint channel_stride_x; - uint channel_x_divisor; - uint b_offset; - uint d_offset; -} p; - -shared FLOAT_TYPE tmp[BLOCK_SIZE]; - -void main() { - const uint tid = gl_LocalInvocationID.x; - const uint row_x = gl_GlobalInvocationID.y; - const uint channel = gl_GlobalInvocationID.z; - const uint channel_x = channel / p.channel_x_divisor; - - const uint nrows_y = p.ncols_x; - const uint nrows_dst = p.nrows_x; - const uint row_dst = row_x; - - const uint idst = channel*nrows_dst + row_dst; - - tmp[tid] = 0.0f; - - for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { - const uint col_x = col_x0 + tid; - - if (col_x >= p.ncols_x) { - break; - } - - const uint row_y = col_x; - - const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; - const uint iy = channel*nrows_y + row_y; - - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); - - tmp[tid] += xi * FLOAT_TYPE(data_b[iy]); - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - - if (tid == 0) { - dst[idst] = tmp[0]; - } -} -""" - -generic_head = """ -#version 450 - -#extension GL_EXT_shader_16bit_storage : require - -layout (push_constant) uniform parameter -{ - uint KX; - uint KY; - float param1; - float param2; -} p; -""" - -generic_unary_op_head = """#version 450 - -#extension GL_EXT_shader_16bit_storage : require - -layout (push_constant) uniform parameter -{ - uint ne; - uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; - uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; - uint d_offset; - float param1; float param2; -} p;""" - -generic_unary_op_layout = """ -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};""" - -generic_unary_op_funcs = """ -uint src0_idx(uint idx) { - const uint i03 = idx / (p.ne02*p.ne01*p.ne00); - const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; - const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00); - const uint i02_offset = i02*p.ne01*p.ne00; - const uint i01 = (idx - i03_offset - i02_offset) / p.ne00; - const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; - return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; -} - -uint dst_idx(uint idx) { - const uint i13 = idx / (p.ne12*p.ne11*p.ne10); - const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; - const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10); - const uint i12_offset = i12*p.ne11*p.ne10; - const uint i11 = (idx - i13_offset - i12_offset) / p.ne10; - const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; - return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; -}""" - -generic_unary_op_main = """ -void main() { - if (gl_GlobalInvocationID.x >= p.ne) { - return; - } -""" - -generic_unary_op_combined = f"{generic_unary_op_head}\n{generic_unary_op_layout}\n{generic_unary_op_funcs}\n{generic_unary_op_main}" - -generic_binary_op_head = """#version 450 - -#extension GL_EXT_shader_16bit_storage : require - -layout (push_constant) uniform parameter -{ - uint ne; - uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; - uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; - uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23; - uint d_offset; - float param1; float param2; -} p;""" - -generic_binary_op_layout = """ -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};""" - -generic_binary_op_funcs = """ -uint src0_idx(uint idx) { - const uint i03 = idx / (p.ne02*p.ne01*p.ne00); - const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; - const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00); - const uint i02_offset = i02*p.ne01*p.ne00; - const uint i01 = (idx - i03_offset - i02_offset) / p.ne00; - const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; - return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; -} - -uint src1_idx(uint idx) { - const uint i03 = idx / (p.ne02*p.ne01*p.ne00); - const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; - const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00); - const uint i02_offset = i02*p.ne01*p.ne00; - const uint i01 = (idx - i03_offset - i02_offset) / p.ne00; - const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; - - return (i03 % p.ne13)*p.nb13 + (i02 % p.ne12)*p.nb12 + (i01 % p.ne11)*p.nb11 + (i00 % p.ne10)*p.nb10; -} - -uint dst_idx(uint idx) { - const uint i23 = idx / (p.ne22*p.ne21*p.ne20); - const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20; - const uint i22 = (idx - i23_offset) / (p.ne21*p.ne20); - const uint i22_offset = i22*p.ne21*p.ne20; - const uint i21 = (idx - i23_offset - i22_offset) / p.ne20; - const uint i20 = idx - i23_offset - i22_offset - i21*p.ne20; - return i23*p.nb23 + i22*p.nb22 + i21*p.nb21 + i20*p.nb20; -}""" - -generic_binary_op_main = """ -void main() { - if (gl_GlobalInvocationID.x >= p.ne) { - return; - } -""" - -generic_binary_op_combined = f"{generic_binary_op_head}\n{generic_binary_op_layout}\n{generic_binary_op_funcs}\n{generic_binary_op_main}" - -# MUL F32 -mul_body = """ - data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)])); -} -""" - -# ADD -add_body = """ - data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) + FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)])); -} -""" - -# SCALE -scale_body = """ - data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(p.param1)); -} -""" - -# SQR -sqr_body = """ - const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]); - data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val * val); -} -""" - -# CLAMP -clamp_body = """ - const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]); - data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val)); -} -""" - -# CPY -cpy_end = """ - data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]); -} -""" -# Causes an optimization error otherwise -cpy_f16_f16_end = """ - data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = data_a[src0_idx(gl_GlobalInvocationID.x)]; -} -""" - -# GET_ROWS -get_rows_float_body = """ -void main() { - const uint i00 = gl_GlobalInvocationID.x; - const uint i10 = gl_GlobalInvocationID.y; - const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; - const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; - - if (i00 >= p.ne00) { - return; - } - - const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; - - const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03; - const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23; - -#ifndef OPTIMIZATION_ERROR_WORKAROUND - data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]); -#else - data_d[d_offset + i00] = data_a[a_offset + i00]; -#endif -} -""" - -get_rows_body = """ -void main() { - const uint i00 = (gl_GlobalInvocationID.x)*2; - const uint i10 = gl_GlobalInvocationID.y; - const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; - const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; - - if (i00 >= p.ne00) { - return; - } - - const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; - - const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03; - const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23; - - const uint ib = a_offset + i00/QUANT_K; // block index - const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index - const uint iybs = i00 - i00%QUANT_K; // dst block start index - const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; - - vec2 v = dequantize(ib, iqs, 0); - - data_d[d_offset + iybs + iqs ] = D_TYPE(v.x); - data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y); -} -""" - -# UNARY -gelu_body = """ -#extension GL_EXT_control_flow_attributes : enable - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - -void main() { - const float GELU_COEF_A = 0.044715f; - const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - const uint i = gl_GlobalInvocationID.x; - - if (i >= p.KX) { - return; - } - - const float xi = float(data_a[i]); - const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi); - data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1))); -} -""" - -silu_body = """ -#extension GL_EXT_control_flow_attributes : enable - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - -void main() { - const uint i = gl_GlobalInvocationID.x; - - if (i >= p.KX) { - return; - } - - const float xi = float(data_a[i]); - data_d[i] = D_TYPE(xi / (1.0f + exp(-xi))); -} -""" - -relu_body = """ -#extension GL_EXT_control_flow_attributes : enable - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - -void main() { - const uint i = gl_GlobalInvocationID.x; - - if (i >= p.KX) { - return; - } - - data_d[i] = max(float(data_a[i]), 0); -} -""" - -# DIAG_MASK_INF -diag_mask_inf_head = """#version 450 - -#extension GL_EXT_shader_16bit_storage : require - -layout (push_constant) uniform parameter -{ - uint ncols; - uint rows_per_channel; - uint n_past; -} p; -""" -diag_mask_inf_body = """ -#extension GL_EXT_control_flow_attributes : enable - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - -void main() { - const uint col = gl_GlobalInvocationID.y; - const uint row = gl_GlobalInvocationID.x; - - if (col >= p.ncols) { - return; - } - - const uint i = row*p.ncols + col; - if (col > p.n_past + row % p.rows_per_channel) { - data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000)); - } else { - data_d[i] = D_TYPE(data_a[i]); - } -} -""" - -# NORMS -norm_body = """ -#extension GL_EXT_control_flow_attributes : enable -#define BLOCK_SIZE 512 - -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - -shared vec2 sum[BLOCK_SIZE]; - -void main() { - const uint row = gl_WorkGroupID.x; - const uint tid = gl_LocalInvocationID.x; - - sum[tid] = vec2(0.0f, 0.0f); - - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - const float xi = float(data_a[row*p.KX + col]); - sum[tid].x += xi; - sum[tid].y += xi * xi; - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { - if (tid < s) { - sum[tid] += sum[tid + s]; - } - barrier(); - } - - const float mean = sum[0].x / p.KX; - const float var = sum[0].y / p.KX - mean * mean; - const float inv_std = inversesqrt(var + p.param1); - - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std); - } -} -""" - -rms_norm_body = """ -#extension GL_EXT_control_flow_attributes : enable -#define BLOCK_SIZE 512 - -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - -shared FLOAT_TYPE sum[BLOCK_SIZE]; - -void main() { - const uint row = gl_WorkGroupID.x; - const uint tid = gl_LocalInvocationID.x; - - sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp - - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); - sum[tid] += xi * xi; - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { - if (tid < s) { - sum[tid] += sum[tid + s]; - } - barrier(); - } - - const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX); - const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); - - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); - } -} -""" - -# SOFT_MAX -soft_max_head = """ -#version 450 - -#extension GL_EXT_shader_16bit_storage : require - -layout (push_constant) uniform parameter -{ - uint KX; - uint KY; - uint KZ; - float scale; - float max_bias; - float m0; - float m1; - uint n_head_log2; -} p; -""" - -soft_max_body = """ -#extension GL_EXT_control_flow_attributes : enable -#define BLOCK_SIZE 512 - -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; -layout (binding = 2) readonly buffer Z {C_TYPE data_c[];}; -layout (binding = 3) buffer D {D_TYPE data_d[];}; - -shared FLOAT_TYPE vals[BLOCK_SIZE]; - -void main() { - const uint tid = gl_LocalInvocationID.x; - const uint rowx = gl_WorkGroupID.x; - const uint rowy = rowx % p.KY; - - float slope = 0.0f; - - // ALiBi - if (p.max_bias > 0.0f) { - const uint h = rowx/p.KY; // head index - - const float base = h < p.n_head_log2 ? p.m0 : p.m1; - const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; - - slope = pow(base, exp); - } - - // Find max - vals[tid] = uintBitsToFloat(0xFF800000); - - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - vals[tid] = max(vals[tid], FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) + (p.KZ > 0 ? slope * FLOAT_TYPE(data_c[col]) : 0.0f)); - } - - barrier(); - [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { - if (tid < s) { - vals[tid] = max(vals[tid], vals[tid + s]); - } - barrier(); - } - - const FLOAT_TYPE max_val = vals[0]; - barrier(); - - // Sum up values - vals[tid] = FLOAT_TYPE(0.0f); - - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - const uint i = rowx * p.KX + col; - const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); - vals[tid] += val; - data_d[i] = D_TYPE(val); - } - - barrier(); - [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { - if (tid < s) { - vals[tid] += vals[tid + s]; - } - barrier(); - } - - const D_TYPE divisor = D_TYPE(vals[0]); - - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - data_d[rowx*p.KX + col] /= divisor; - } -} -""" - -# ROPE -rope_src = """ -#version 450 - -#extension GL_EXT_shader_16bit_storage : require - -layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer Y {int data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; - -layout (push_constant) uniform parameter { - uint ncols; - float freq_scale; - uint p_delta_rows; - float freq_base; - float ext_factor; - float attn_factor; - float corr_dims[4]; -} p; - -float rope_yarn_ramp(const float low, const float high, const uint i0) { - const float y = (i0 / 2 - low) / max(0.001f, high - low); - return 1.0f - min(1.0f, max(0.0f, y)); -} - -void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) { - float mscale = p.attn_factor; - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = p.freq_scale * theta_extrap; - float theta = theta_interp; - if (p.ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor; - theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - - // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale); - } - cos_theta = cos(theta) * mscale; - sin_theta = sin(theta) * mscale; -} - -void main() { - const uint col = gl_GlobalInvocationID.y * 2; - const uint row = gl_GlobalInvocationID.x; - - if (col >= p.ncols) { - return; - } - - const uint i = row*p.ncols + col; - const uint i2 = row/p.p_delta_rows; - - const int pos = data_b[i2]; - const float theta_base = pos * pow(p.freq_base, -float(col)/p.ncols); - - float cos_theta, sin_theta; - rope_yarn(theta_base, col, cos_theta, sin_theta); - - const float x0 = float(data_a[i + 0]); - const float x1 = float(data_a[i + 1]); - - data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta); -} -""" - -rope_neox_src = """ -#version 450 - -#extension GL_EXT_shader_16bit_storage : require - -layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer Y {int data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; - -layout (push_constant) uniform parameter { - uint ncols; - uint ndims; - float freq_scale; - uint p_delta_rows; - float freq_base; - float ext_factor; - float attn_factor; - float corr_dims[4]; - float theta_scale; - float inv_ndims; -} p; - -float rope_yarn_ramp(const float low, const float high, const uint i0) { - const float y = (i0 / 2 - low) / max(0.001f, high - low); - return 1.0f - min(1.0f, max(0.0f, y)); -} - -void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) { - float mscale = p.attn_factor; - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = p.freq_scale * theta_extrap; - float theta = theta_interp; - if (p.ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor; - theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - - // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale); - } - cos_theta = cos(theta) * mscale; - sin_theta = sin(theta) * mscale; -} - -void main() { - const uint col = gl_GlobalInvocationID.y * 2; - const uint row = gl_GlobalInvocationID.x; - - if (col >= p.ncols) { - return; - } - - const uint ib = col / p.ndims; - const uint ic = col % p.ndims; - - if (ib > 0) { - const uint i = row*p.ncols + ib*p.ndims + ic; - - data_d[i + 0] = data_a[i + 0]; - data_d[i + 1] = data_a[i + 1]; - - return; - } - - const uint i = row*p.ncols + ib*p.ndims + ic/2; - const uint i2 = row/p.p_delta_rows; - - const float cur_rot = p.inv_ndims * ic - ib; - - const int pos = data_b[i2]; - const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f); - - float cos_theta, sin_theta; - rope_yarn(theta_base, uint(cur_rot), cos_theta, sin_theta); - - const float x0 = float(data_a[i + 0]); - const float x1 = float(data_a[i + p.ndims/2]); - - data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[i + p.ndims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); -} -""" - -argsort_src = """ -#version 450 - -#extension GL_EXT_shader_16bit_storage : require - -layout(local_size_x = 1024, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) buffer D {int data_d[];}; - -layout (push_constant) uniform parameter { - uint ncols; - bool ascending; -} p; - -void swap(uint idx0, uint idx1) { - int tmp = data_d[idx0]; - data_d[idx0] = data_d[idx1]; - data_d[idx1] = tmp; -} - -void main() { - // bitonic sort - const int col = int(gl_LocalInvocationID.x); - const uint row = gl_WorkGroupID.y; - - if (col >= p.ncols) { - return; - } - - const uint a_idx = row * p.ncols; - const uint d_idx = row * p.ncols; - - // initialize indices - if (col < p.ncols) { - data_d[col] = col; - } - barrier(); - - for (uint k = 2; k <= p.ncols; k *= 2) { - for (uint j = k / 2; j > 0; j /= 2) { - const uint ixj = col ^ j; - if (ixj > col) { - if ((col & k) == 0) { - if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]]) { - swap(d_idx + col, d_idx + ixj); - } - } else { - if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]]) { - swap(d_idx + col, d_idx + ixj); - } - } - } - barrier(); - } - } -} -""" - -GLSLC = "glslc" - -VK_NUM_TYPES = 16 - -GGML_TYPE_F32 = 0 -GGML_TYPE_F16 = 1 -GGML_TYPE_Q4_0 = 2 -GGML_TYPE_Q4_1 = 3 -GGML_TYPE_Q5_0 = 6 -GGML_TYPE_Q5_1 = 7 -GGML_TYPE_Q8_0 = 8 -GGML_TYPE_Q8_1 = 9 -GGML_TYPE_Q2_K = 10 -GGML_TYPE_Q3_K = 11 -GGML_TYPE_Q4_K = 12 -GGML_TYPE_Q5_K = 13 -GGML_TYPE_Q6_K = 14 -GGML_TYPE_Q8_K = 15 - - -type_names = { - GGML_TYPE_F32: "f32", - GGML_TYPE_F16: "f16", - GGML_TYPE_Q4_0: "q4_0", - GGML_TYPE_Q4_1: "q4_1", - GGML_TYPE_Q5_0: "q5_0", - GGML_TYPE_Q5_1: "q5_1", - GGML_TYPE_Q8_0: "q8_0", - GGML_TYPE_Q8_1: "q8_1", - GGML_TYPE_Q2_K: "q2_K", - GGML_TYPE_Q3_K: "q3_K", - GGML_TYPE_Q4_K: "q4_K", - GGML_TYPE_Q5_K: "q5_K", - GGML_TYPE_Q6_K: "q6_K", - GGML_TYPE_Q8_K: "q8_K", -} - -K_QUANTS_PER_ITERATION = 2 - -ASYNCIO_CONCURRENCY = 64 - -output_dir = gettempdir() - -lock = asyncio.Lock() -shader_fnames = [] - - -async def string_to_spv(name, code, defines, fp16=True): - f = NamedTemporaryFile(mode="w", delete=False) - f.write(code) - f.flush() - - name = f"{name}{'_fp32' if not fp16 else ''}" - fname = os.path.join(output_dir, f"{name}.comp") - - cmd = [GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", f.name, "-o", fname] - - cmd.extend([f"-D{key}={value}" for key, value in defines.items()]) - - proc = await asyncio.create_subprocess_exec(*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) - - stdout, stderr = await proc.communicate() - - stdout = stdout.decode() - error = stderr.decode() - - if proc.returncode: - # Generate preprocessed code - cmd = [GLSLC, "-E", f.name] - cmd.extend([f"-D{key}={value}" for key, value in defines.items()]) - - proc = await asyncio.create_subprocess_exec(*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) - - stdout, stderr = await proc.communicate() - - logger.info(" ".join(cmd)) - - if proc.returncode: - raise RuntimeError(f"{name=} {f.name=} {stdout=} {stderr=}") - - preprocessed_code = stdout.decode() - - cmd.extend([f"-D{key}={value}" for key, value in defines.items()]) - code_with_lines = "\n".join([f"{i + 1}: {line}" for i, line in enumerate(preprocessed_code.splitlines())]) - logger.error(f"cannot compile {name}\n\n{code_with_lines}\n\n{error}") - f.close() - os.remove(f.name) - sys.exit(proc.returncode) - - f.close() - os.remove(f.name) - - async with lock: - shader_fnames.append((name, fname)) - - -async def main(): - logger.info("ggml_vulkan: Generating and compiling shaders to SPIR-V") - - tasks = [] - - stream = [] - - for fp16 in (False, True): - # mulmat - if fp16: - shader_float_type = shader_f16 - load_vec = "8" - vec_type_f16 = "f16mat2x4" - vec_type = "mat2x4" - else: - shader_float_type = shader_f32 - load_vec = "4" - vec_type_f16 = "f16vec4" - vec_type = "vec4" - - stream.clear() - stream.extend((mulmat_head, shader_float_type, mulmat_body1, mulmat_load_scalar, mulmat_body2)) - tasks.append(string_to_spv("matmul_f32", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f32_aligned", "".join(stream), {"LOAD_VEC_A": 1, "LOAD_VEC_B": load_vec, "A_TYPE": "float", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - tasks.append(string_to_spv("matmul_f16", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f16_aligned", "".join(stream), {"LOAD_VEC_A": 1, "LOAD_VEC_B": load_vec, "A_TYPE": "float16_t", "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16)) - - tasks.append(string_to_spv("matmul_f16_f32", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_f16_f32_aligned", "".join(stream), {"LOAD_VEC_A": 1, "LOAD_VEC_B": load_vec, "A_TYPE": "float16_t", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - stream.clear() - stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_0_defines, mulmat_body1, mulmat_load_q4_0, mulmat_body2)) - tasks.append(string_to_spv("matmul_q4_0_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q4_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_q4_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - stream.clear() - stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_1_defines, mulmat_body1, mulmat_load_q4_1, mulmat_body2)) - tasks.append(string_to_spv("matmul_q4_1_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q4_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_q4_1_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - stream.clear() - stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_0_defines, mulmat_body1, mulmat_load_q5_0, mulmat_body2)) - tasks.append(string_to_spv("matmul_q5_0_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q5_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_q5_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - stream.clear() - stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_1_defines, mulmat_body1, mulmat_load_q5_1, mulmat_body2)) - tasks.append(string_to_spv("matmul_q5_1_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q5_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_q5_1_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - stream.clear() - stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q8_0_defines, mulmat_body1, mulmat_load_q8_0, mulmat_body2)) - tasks.append(string_to_spv("matmul_q8_0_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q8_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_q8_0_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q8_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - stream.clear() - stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q2_K_defines, mulmat_body1, mulmat_load_q2_K, mulmat_body2)) - tasks.append(string_to_spv("matmul_q2_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q2_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_q2_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q2_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - stream.clear() - stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q3_K_defines, mulmat_body1, mulmat_load_q3_K, mulmat_body2)) - tasks.append(string_to_spv("matmul_q3_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q3_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_q3_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q3_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - stream.clear() - stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_K_defines, mulmat_body1, mulmat_load_q4_K, mulmat_body2)) - tasks.append(string_to_spv("matmul_q4_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q4_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_q4_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - stream.clear() - stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_K_defines, mulmat_body1, mulmat_load_q5_K, mulmat_body2)) - tasks.append(string_to_spv("matmul_q5_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q5_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_q5_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - stream.clear() - stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q6_K_defines, mulmat_body1, mulmat_load_q6_K, mulmat_body2)) - tasks.append(string_to_spv("matmul_q6_k_f32", "".join(stream), {"LOAD_VEC_A": 2, "A_TYPE": "block_q6_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - tasks.append(string_to_spv("matmul_q6_k_f32_aligned", "".join(stream), {"LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q6_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # MUL_MAT_ID - # stream.clear() - # stream.extend((mulmat_head, shader_float_type, mulmat_body1, mulmat_load_scalar, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_f32", "".join(stream), {"MUL_MAT_ID": "1", "A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # tasks.append(string_to_spv("matmul_id_f16", "".join(stream), {"MUL_MAT_ID": "1", "A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_f16_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16)) - - # tasks.append(string_to_spv("matmul_id_f16_f32", "".join(stream), {"MUL_MAT_ID": "1", "A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_f16_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_0_defines, mulmat_body1, mulmat_load_q4_0, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q4_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q4_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_1_defines, mulmat_body1, mulmat_load_q4_1, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q4_1_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q4_1_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_0_defines, mulmat_body1, mulmat_load_q5_0, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q5_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q5_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_1_defines, mulmat_body1, mulmat_load_q5_1, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q5_1_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_1", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q5_1_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_1", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q8_0_defines, mulmat_body1, mulmat_load_q8_0, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q8_0_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q8_0", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q8_0_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q8_0", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q2_K_defines, mulmat_body1, mulmat_load_q2_K, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q2_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q2_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q2_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q2_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q3_K_defines, mulmat_body1, mulmat_load_q3_K, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q3_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q3_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q3_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q3_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q4_K_defines, mulmat_body1, mulmat_load_q4_K, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q4_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q4_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q4_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q4_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q5_K_defines, mulmat_body1, mulmat_load_q5_K, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q5_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q5_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q5_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q5_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # stream.clear() - # stream.extend((mulmat_head, shader_int8_ext, shader_float_type, shader_q6_K_defines, mulmat_body1, mulmat_load_q6_K, mulmat_body2)) - # tasks.append(string_to_spv("matmul_id_q6_k_f32", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "A_TYPE": "block_q6_K", "B_TYPE": "float", "D_TYPE": "float"}, fp16)) - # tasks.append(string_to_spv("matmul_id_q6_k_f32_aligned", "".join(stream), {"MUL_MAT_ID": "1", "LOAD_VEC_A": 2, "LOAD_VEC_B": load_vec, "A_TYPE": "block_q6_K", "B_TYPE": vec_type, "D_TYPE": "float"}, fp16)) - - # Shaders where precision is needed, so no fp16 version - - # mul mat vec - for i in range(0, VK_NUM_TYPES): - stream.clear() - stream.extend((mul_mat_vec_head, shader_int8_ext, shader_f32)) - - if i == GGML_TYPE_F16: - stream.extend((shader_f16_defines, mul_mat_vec_layout, shader_float_dequant_func, mul_mat_vec_body)) - elif i == GGML_TYPE_Q4_0: - stream.extend((shader_q4_0_defines, mul_mat_vec_layout, shader_q4_0_dequant_func, mul_mat_vec_body)) - elif i == GGML_TYPE_Q4_1: - stream.extend((shader_q4_1_defines, mul_mat_vec_layout, shader_q4_1_dequant_func, mul_mat_vec_body)) - elif i == GGML_TYPE_Q5_0: - stream.extend((shader_q5_0_defines, mul_mat_vec_layout, shader_q5_0_dequant_func, mul_mat_vec_body)) - elif i == GGML_TYPE_Q5_1: - stream.extend((shader_q5_1_defines, mul_mat_vec_layout, shader_q5_1_dequant_func, mul_mat_vec_body)) - elif i == GGML_TYPE_Q8_0: - stream.extend((shader_q8_0_defines, mul_mat_vec_layout, shader_q8_0_dequant_func, mul_mat_vec_body)) - elif i == GGML_TYPE_Q2_K: - stream.extend((shader_q2_K_defines, mul_mat_vec_layout, mul_mat_vec_q2_K_body)) - elif i == GGML_TYPE_Q3_K: - stream.extend((shader_q3_K_defines, mul_mat_vec_layout, mul_mat_vec_q3_K_body)) - elif i == GGML_TYPE_Q4_K: - stream.extend((shader_q4_K_defines, mul_mat_vec_layout, mul_mat_vec_q4_K_body)) - elif i == GGML_TYPE_Q5_K: - stream.extend((shader_q5_K_defines, mul_mat_vec_layout, mul_mat_vec_q5_K_body)) - elif i == GGML_TYPE_Q6_K: - stream.extend((shader_q6_K_defines, mul_mat_vec_layout, mul_mat_vec_q6_K_body)) - else: - continue - - tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f32_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION})) - tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f16_f32", "".join(stream), {"B_TYPE": "float16_t", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION})) - - # tasks.append(string_to_spv(f"mul_mat_vec_id_{type_names[i]}_f32", "".join(stream), {"MUL_MAT_ID": "1", "B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION})) - - # Dequant shaders - for i in range(0, VK_NUM_TYPES): - stream.clear() - - stream.extend((dequant_head, shader_int8_ext, shader_f32)) - - if i == GGML_TYPE_F32: - stream.append(dequant_f32_body) - elif i == GGML_TYPE_Q4_0: - stream.extend((shader_q4_0_defines, dequant_q4_0_body)) - elif i == GGML_TYPE_Q4_1: - stream.extend((shader_q4_1_defines, dequant_q4_1_body)) - elif i == GGML_TYPE_Q5_0: - stream.extend((shader_q5_0_defines, dequant_q5_0_body)) - elif i == GGML_TYPE_Q5_1: - stream.extend((shader_q5_1_defines, dequant_q5_1_body)) - elif i == GGML_TYPE_Q8_0: - stream.extend((shader_q8_0_defines, dequant_q8_0_body)) - elif i == GGML_TYPE_Q2_K: - stream.extend((shader_q2_K_defines, dequant_q2_K_body)) - elif i == GGML_TYPE_Q3_K: - stream.extend((shader_q3_K_defines, dequant_q3_K_body)) - elif i == GGML_TYPE_Q4_K: - stream.extend((shader_q4_K_defines, dequant_q4_K_body)) - elif i == GGML_TYPE_Q5_K: - stream.extend((shader_q5_K_defines, dequant_q5_K_body)) - elif i == GGML_TYPE_Q6_K: - stream.extend((shader_q6_K_defines, dequant_q6_K_body)) - else: - continue - - tasks.append(string_to_spv(f"dequant_{type_names[i]}", "".join(stream), {"D_TYPE": "float16_t"})) - - # get_rows - for i in range(0, VK_NUM_TYPES): - stream.clear() - stream.extend((generic_binary_op_head, shader_int8_ext, shader_f32)) - optimization_workaround = False - - if i == GGML_TYPE_F32: - stream.extend((shader_f32_defines, generic_binary_op_layout, generic_binary_op_funcs, get_rows_float_body)) - elif i == GGML_TYPE_F16: - stream.extend((shader_f16_defines, generic_binary_op_layout, generic_binary_op_funcs, get_rows_float_body)) - optimization_workaround = True - elif i == GGML_TYPE_Q4_0: - stream.extend((shader_q4_0_defines, generic_binary_op_layout, shader_q4_0_dequant_func, generic_binary_op_funcs, get_rows_body)) - elif i == GGML_TYPE_Q4_1: - stream.extend((shader_q4_1_defines, generic_binary_op_layout, shader_q4_1_dequant_func, generic_binary_op_funcs, get_rows_body)) - elif i == GGML_TYPE_Q5_0: - stream.extend((shader_q5_0_defines, generic_binary_op_layout, shader_q5_0_dequant_func, generic_binary_op_funcs, get_rows_body)) - elif i == GGML_TYPE_Q5_1: - stream.extend((shader_q5_1_defines, generic_binary_op_layout, shader_q5_1_dequant_func, generic_binary_op_funcs, get_rows_body)) - elif i == GGML_TYPE_Q8_0: - stream.extend((shader_q8_0_defines, generic_binary_op_layout, shader_q8_0_dequant_func, generic_binary_op_funcs, get_rows_body)) - else: - continue - - if optimization_workaround: - tasks.append(string_to_spv(f"get_rows_{type_names[i]}", "".join(stream), {"B_TYPE": "int", "D_TYPE": "float16_t", "OPTIMIZATION_ERROR_WORKAROUND": "1"})) - else: - tasks.append(string_to_spv(f"get_rows_{type_names[i]}", "".join(stream), {"B_TYPE": "int", "D_TYPE": "float16_t"})) - tasks.append(string_to_spv(f"get_rows_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "int", "D_TYPE": "float"})) - - tasks.append(string_to_spv("mul_mat_vec_p021_f16_f32", mul_mat_p021_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"})) - tasks.append(string_to_spv("mul_mat_vec_nc_f16_f32", mul_mat_nc_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"})) - - # Norms - tasks.append(string_to_spv("norm_f32", f"{generic_head}\n{shader_f32}\n{norm_body}", {"A_TYPE": "float", "D_TYPE": "float"})) - tasks.append(string_to_spv("rms_norm_f32", f"{generic_head}\n{shader_f32}\n{rms_norm_body}", {"A_TYPE": "float", "D_TYPE": "float"})) - - tasks.append(string_to_spv("cpy_f32_f32", f"{generic_unary_op_combined}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float"})) - tasks.append(string_to_spv("cpy_f32_f16", f"{generic_unary_op_combined}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float16_t"})) - tasks.append(string_to_spv("cpy_f16_f16", f"{generic_unary_op_combined}\n{cpy_f16_f16_end}", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) - - tasks.append(string_to_spv("add_f32", f"{generic_binary_op_combined}\n{add_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"})) - - tasks.append(string_to_spv("split_k_reduce", mulmat_split_k_reduce_src, {})) - tasks.append(string_to_spv("mul_f32", f"{generic_binary_op_combined}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"})) - - tasks.append(string_to_spv("scale_f32", f"{generic_unary_op_combined}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"})) - - tasks.append(string_to_spv("sqr_f32", f"{generic_unary_op_combined}\n{sqr_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"})) - - tasks.append(string_to_spv("clamp_f32", f"{generic_unary_op_combined}\n{clamp_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"})) - - tasks.append(string_to_spv("gelu_f32", f"{generic_head}\n{shader_f32}\n{gelu_body}", {"A_TYPE": "float", "D_TYPE": "float"})) - tasks.append(string_to_spv("silu_f32", f"{generic_head}\n{shader_f32}\n{silu_body}", {"A_TYPE": "float", "D_TYPE": "float"})) - tasks.append(string_to_spv("relu_f32", f"{generic_head}\n{shader_f32}\n{relu_body}", {"A_TYPE": "float", "D_TYPE": "float"})) - - tasks.append(string_to_spv("diag_mask_inf_f32", f"{diag_mask_inf_head}\n{shader_f32}\n{diag_mask_inf_body}", {"A_TYPE": "float", "D_TYPE": "float"})) - - tasks.append(string_to_spv("soft_max_f32", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "C_TYPE": "float", "D_TYPE": "float"})) - tasks.append(string_to_spv("soft_max_f32_f16", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float16_t", "C_TYPE": "float16_t", "D_TYPE": "float"})) - - tasks.append(string_to_spv("rope_f32", rope_src, {"A_TYPE": "float", "D_TYPE": "float"})) - tasks.append(string_to_spv("rope_f16", rope_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) - - tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"})) - tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) - - tasks.append(string_to_spv("argsort_f32", argsort_src, {"A_TYPE": "float"})) - - # Helper to decorate tasks with semaphore acquisition. - async def withSemaphore(sem, task): - async with sem: - return await task - - # Run tasks concurrently guarded by a concurrency limit. - sem = asyncio.Semaphore(ASYNCIO_CONCURRENCY) - await asyncio.gather(*(withSemaphore(sem, task) for task in tasks)) - - with open("ggml-vulkan-shaders.hpp", "w") as f: - f.write("#include \n\n") - for name, path in sorted(shader_fnames): - - with open(path, "rb") as spv: - counter = 0 - newline_counter = 0 - f.write(f"unsigned char {name}_data[] = {{\n") - for val in spv.read(): - f.write(f"0x{val:02x},") - newline_counter += 1 - counter += 1 - if newline_counter >= 12: - newline_counter = 0 - f.write("\n") - f.write("\n};\n") - f.write(f"const uint64_t {name}_len = {counter};\n\n") - os.remove(path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="GGML Vulkan Shader Generator") - - parser.add_argument("--glslc", help="Path to glslc") - parser.add_argument("--verbose", action="store_true", help="increase output verbosity") - - args = parser.parse_args() - - logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) - - if args.glslc: - GLSLC = args.glslc - - asyncio.run(main()) \ No newline at end of file diff --git a/src/vulkan-shaders/add.comp b/src/vulkan-shaders/add.comp index 8475b0119..3974845d6 100644 --- a/src/vulkan-shaders/add.comp +++ b/src/vulkan-shaders/add.comp @@ -4,9 +4,11 @@ #include "generic_binary_head.comp" void main() { - if (gl_GlobalInvocationID.x >= p.ne) { + const uint idx = get_idx(); + + if (idx >= p.ne) { return; } - data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) + FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)])); + data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[src1_idx(idx)])); } diff --git a/src/vulkan-shaders/clamp.comp b/src/vulkan-shaders/clamp.comp index ca272e227..7071302a4 100644 --- a/src/vulkan-shaders/clamp.comp +++ b/src/vulkan-shaders/clamp.comp @@ -4,10 +4,12 @@ #include "generic_unary_head.comp" void main() { - if (gl_GlobalInvocationID.x >= p.ne) { + const uint idx = get_idx(); + + if (idx >= p.ne) { return; } - const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]); - data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val)); + const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]); + data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val)); } diff --git a/src/vulkan-shaders/concat.comp b/src/vulkan-shaders/concat.comp new file mode 100644 index 000000000..08ab5514b --- /dev/null +++ b/src/vulkan-shaders/concat.comp @@ -0,0 +1,35 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + const int dim = p.param3; + + if (idx >= p.ne) { + return; + } + + const uint i3 = idx / (p.ne22*p.ne21*p.ne20); + const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20; + const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20); + const uint i2_offset = i2*p.ne21*p.ne20; + const uint i1 = (idx - i3_offset - i2_offset) / p.ne20; + const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20; + + uint o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? p.ne02 : p.ne03)); + + const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; + const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10; + const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20; + + const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; + +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]); +#else + data_d[p.d_offset + dst_idx] = is_src0 ? data_a[src0_idx] : data_b[src1_idx]; +#endif +} diff --git a/src/vulkan-shaders/copy.comp b/src/vulkan-shaders/copy.comp index efb55876e..c26917c0f 100644 --- a/src/vulkan-shaders/copy.comp +++ b/src/vulkan-shaders/copy.comp @@ -4,13 +4,15 @@ #include "generic_unary_head.comp" void main() { - if (gl_GlobalInvocationID.x >= p.ne) { + const uint idx = get_idx(); + + if (idx >= p.ne) { return; } #ifndef OPTIMIZATION_ERROR_WORKAROUND - data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]); + data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx(idx)]); #else - data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = data_a[src0_idx(gl_GlobalInvocationID.x)]; + data_d[p.d_offset + dst_idx(idx)] = data_a[src0_idx(idx)]; #endif } diff --git a/src/vulkan-shaders/div.comp b/src/vulkan-shaders/div.comp index 8ee4bfc73..8cfce58b1 100644 --- a/src/vulkan-shaders/div.comp +++ b/src/vulkan-shaders/div.comp @@ -4,9 +4,11 @@ #include "generic_binary_head.comp" void main() { - if (gl_GlobalInvocationID.x >= p.ne) { + const uint idx = get_idx(); + + if (idx >= p.ne) { return; } - data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) / FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)])); + data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) / FLOAT_TYPE(data_b[src1_idx(idx)])); } diff --git a/src/vulkan-shaders/gelu.comp b/src/vulkan-shaders/gelu.comp index 9fe807cce..4cc7a68ca 100644 --- a/src/vulkan-shaders/gelu.comp +++ b/src/vulkan-shaders/gelu.comp @@ -13,7 +13,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; void main() { const float GELU_COEF_A = 0.044715f; const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - const uint i = gl_GlobalInvocationID.x; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; if (i >= p.KX) { return; diff --git a/src/vulkan-shaders/gelu_quick.comp b/src/vulkan-shaders/gelu_quick.comp new file mode 100644 index 000000000..e6e6fcfd2 --- /dev/null +++ b/src/vulkan-shaders/gelu_quick.comp @@ -0,0 +1,23 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const float GELU_QUICK_COEF = -1.702f; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x)))); +} diff --git a/src/vulkan-shaders/generic_binary_head.comp b/src/vulkan-shaders/generic_binary_head.comp index ab45d2564..b6beaff1c 100644 --- a/src/vulkan-shaders/generic_binary_head.comp +++ b/src/vulkan-shaders/generic_binary_head.comp @@ -7,7 +7,7 @@ layout (push_constant) uniform parameter uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23; uint d_offset; - float param1; float param2; + float param1; float param2; int param3; } p; layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; @@ -16,6 +16,10 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +uint get_idx() { + return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; +} + uint src0_idx(uint idx) { const uint i03 = idx / (p.ne02*p.ne01*p.ne00); const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; diff --git a/src/vulkan-shaders/generic_unary_head.comp b/src/vulkan-shaders/generic_unary_head.comp index de08de7cd..eacdefc7d 100644 --- a/src/vulkan-shaders/generic_unary_head.comp +++ b/src/vulkan-shaders/generic_unary_head.comp @@ -14,6 +14,10 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; +uint get_idx() { + return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; +} + uint src0_idx(uint idx) { const uint i03 = idx / (p.ne02*p.ne01*p.ne00); const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; diff --git a/src/vulkan-shaders/group_norm.comp b/src/vulkan-shaders/group_norm.comp new file mode 100644 index 000000000..5ad9b28da --- /dev/null +++ b/src/vulkan-shaders/group_norm.comp @@ -0,0 +1,66 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared float tmp[BLOCK_SIZE]; + +void main() { + const uint group_size = p.KX; + const float eps = p.param1; + + const uint tid = gl_LocalInvocationID.x; + const uint start = gl_WorkGroupID.x * group_size + tid; + const uint end = start + group_size; + + tmp[tid] = 0.0f; + + // Calculate mean + [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { + tmp[tid] += float(data_a[col]); + } + + // tmp up partial tmps and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + const float mean = tmp[0] / group_size; + barrier(); + tmp[tid] = 0.0f; + + // Calculate variance + [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { + const float xi = float(data_a[col]) - mean; + data_d[col] = D_TYPE(xi); + tmp[tid] += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + const float variance = tmp[0] / group_size; + const float scale = inversesqrt(variance + eps); + + [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { + data_d[col] *= D_TYPE(scale); + } +} diff --git a/src/vulkan-shaders/im2col.comp b/src/vulkan-shaders/im2col.comp new file mode 100644 index 000000000..4d48610a3 --- /dev/null +++ b/src/vulkan-shaders/im2col.comp @@ -0,0 +1,57 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint batch_offset; uint offset_delta; + uint IC; + uint IW; uint IH; + uint OW; uint OH; + uint KW; uint KH; + uint pelements; + uint CHW; + int s0; int s1; + int p0; int p1; + int d0; int d1; +} p; + +#include "types.comp" + +#define BLOCK_SIZE 256 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.x; + if (i >= p.pelements) { + return; + } + + const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); + const uint kx = i / ksize; + const uint kd = kx * ksize; + const uint ky = (i - kd) / p.OW; + const uint ix = i % p.OW; + + const uint oh = gl_GlobalInvocationID.y; + const uint batch = gl_GlobalInvocationID.z / p.IC; + const uint ic = gl_GlobalInvocationID.z % p.IC; + + const uint iiw = ix * p.s0 + kx * p.d0 - p.p0; + const uint iih = oh * p.s1 + ky * p.d1 - p.p1; + + const uint offset_dst = + ((batch * p.OH + oh) * p.OW + ix) * p.CHW + + (ic * (p.KW * p.KH) + ky * p.KW + kx); + + if (iih < 0 || iih >= p.IH || iiw < 0 || iiw >= p.IW) { + data_d[offset_dst] = D_TYPE(0.0f); + } else { + const uint offset_src = ic * p.offset_delta + batch * p.batch_offset; + data_d[offset_dst] = D_TYPE(data_a[offset_src + iih * p.IW + iiw]); + } +} diff --git a/src/vulkan-shaders/leaky_relu.comp b/src/vulkan-shaders/leaky_relu.comp new file mode 100644 index 000000000..d90a99aea --- /dev/null +++ b/src/vulkan-shaders/leaky_relu.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float val = float(data_a[i]); + data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1); +} diff --git a/src/vulkan-shaders/mul.comp b/src/vulkan-shaders/mul.comp index bbb0aa1d2..bfb61c92d 100644 --- a/src/vulkan-shaders/mul.comp +++ b/src/vulkan-shaders/mul.comp @@ -4,9 +4,11 @@ #include "generic_binary_head.comp" void main() { - if (gl_GlobalInvocationID.x >= p.ne) { + const uint idx = get_idx(); + + if (idx >= p.ne) { return; } - data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)])); + data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(data_b[src1_idx(idx)])); } diff --git a/src/vulkan-shaders/norm.comp b/src/vulkan-shaders/norm.comp index 803dbdcb3..6627a50bd 100644 --- a/src/vulkan-shaders/norm.comp +++ b/src/vulkan-shaders/norm.comp @@ -14,7 +14,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; shared vec2 sum[BLOCK_SIZE]; void main() { - const uint row = gl_WorkGroupID.x; + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; sum[tid] = vec2(0.0f, 0.0f); diff --git a/src/vulkan-shaders/pad.comp b/src/vulkan-shaders/pad.comp new file mode 100644 index 000000000..a465cd52b --- /dev/null +++ b/src/vulkan-shaders/pad.comp @@ -0,0 +1,26 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (idx >= p.ne) { + return; + } + + const uint i3 = idx / (p.ne12*p.ne11*p.ne10); + const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10; + const uint i2 = (idx - i3_offset) / (p.ne11*p.ne10); + const uint i2_offset = i2*p.ne11*p.ne10; + const uint i1 = (idx - i3_offset - i2_offset) / p.ne10; + const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; + + const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; + const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10; + + const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; + + data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : 0.0f); +} diff --git a/src/vulkan-shaders/relu.comp b/src/vulkan-shaders/relu.comp index 7e5baa5b8..52a19b62a 100644 --- a/src/vulkan-shaders/relu.comp +++ b/src/vulkan-shaders/relu.comp @@ -11,7 +11,7 @@ layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; void main() { - const uint i = gl_GlobalInvocationID.x; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; if (i >= p.KX) { return; diff --git a/src/vulkan-shaders/rms_norm.comp b/src/vulkan-shaders/rms_norm.comp index cfd08d345..b554400ba 100644 --- a/src/vulkan-shaders/rms_norm.comp +++ b/src/vulkan-shaders/rms_norm.comp @@ -14,7 +14,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; shared FLOAT_TYPE sum[BLOCK_SIZE]; void main() { - const uint row = gl_WorkGroupID.x; + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp diff --git a/src/vulkan-shaders/scale.comp b/src/vulkan-shaders/scale.comp index 510cb7237..5cd2f668d 100644 --- a/src/vulkan-shaders/scale.comp +++ b/src/vulkan-shaders/scale.comp @@ -4,9 +4,11 @@ #include "generic_unary_head.comp" void main() { - if (gl_GlobalInvocationID.x >= p.ne) { + const uint idx = get_idx(); + + if (idx >= p.ne) { return; } - data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(p.param1)); + data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(p.param1)); } diff --git a/src/vulkan-shaders/silu.comp b/src/vulkan-shaders/silu.comp index 15920f06e..4d36f88e0 100644 --- a/src/vulkan-shaders/silu.comp +++ b/src/vulkan-shaders/silu.comp @@ -11,7 +11,7 @@ layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; void main() { - const uint i = gl_GlobalInvocationID.x; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; if (i >= p.KX) { return; diff --git a/src/vulkan-shaders/soft_max.comp b/src/vulkan-shaders/soft_max.comp index 1b8419c7c..0bd51ecab 100644 --- a/src/vulkan-shaders/soft_max.comp +++ b/src/vulkan-shaders/soft_max.comp @@ -28,7 +28,7 @@ shared FLOAT_TYPE vals[BLOCK_SIZE]; void main() { const uint tid = gl_LocalInvocationID.x; - const uint rowx = gl_WorkGroupID.x; + const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint rowy = rowx % p.KY; float slope = 1.0f; diff --git a/src/vulkan-shaders/square.comp b/src/vulkan-shaders/square.comp index 8dd19333d..1fa118c99 100644 --- a/src/vulkan-shaders/square.comp +++ b/src/vulkan-shaders/square.comp @@ -4,10 +4,12 @@ #include "generic_unary_head.comp" void main() { - if (gl_GlobalInvocationID.x >= p.ne) { + const uint idx = get_idx(); + + if (idx >= p.ne) { return; } - const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]); - data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val * val); + const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]); + data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val * val); } diff --git a/src/vulkan-shaders/sum_rows.comp b/src/vulkan-shaders/sum_rows.comp index ce2f1e2f3..961e5ffa1 100644 --- a/src/vulkan-shaders/sum_rows.comp +++ b/src/vulkan-shaders/sum_rows.comp @@ -14,7 +14,7 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32; shared FLOAT_TYPE tmp[BLOCK_SIZE]; void main() { - const uint row = gl_WorkGroupID.x; + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint col = gl_LocalInvocationID.x; tmp[col] = FLOAT_TYPE(0.0f); diff --git a/src/vulkan-shaders/tanh.comp b/src/vulkan-shaders/tanh.comp new file mode 100644 index 000000000..74630dc7f --- /dev/null +++ b/src/vulkan-shaders/tanh.comp @@ -0,0 +1,21 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + data_d[i] = D_TYPE(tanh(data_a[i])); +} diff --git a/src/vulkan-shaders/timestep_embedding.comp b/src/vulkan-shaders/timestep_embedding.comp new file mode 100644 index 000000000..79e065a93 --- /dev/null +++ b/src/vulkan-shaders/timestep_embedding.comp @@ -0,0 +1,41 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint nb1; + uint dim; + uint max_period; +} p; + +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 256 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_WorkGroupID.y; + const uint j = gl_GlobalInvocationID.x; + const uint d_offset = i * p.nb1; + + if (p.dim % 2 != 0 && j == ((p.dim + 1) / 2)) { + data_d[d_offset + p.dim] = 0.f; + } + + const uint half_dim = p.dim / 2; + if (j >= half_dim) { + return; + } + + const float timestep = float(data_a[i]); + const float freq = float(exp(-log(p.max_period) * j / half_dim)); + const float arg = timestep * freq; + data_d[d_offset + j] = D_TYPE(cos(arg)); + data_d[d_offset + j + half_dim] = D_TYPE(sin(arg)); +} diff --git a/src/vulkan-shaders/types.comp b/src/vulkan-shaders/types.comp index d24c172ca..21dce72fc 100644 --- a/src/vulkan-shaders/types.comp +++ b/src/vulkan-shaders/types.comp @@ -6,7 +6,7 @@ #define QUANT_K 1 #define QUANT_R 1 -#ifndef LOAD_VEC_A +#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 #define A_TYPE float #elif LOAD_VEC_A == 4 #define A_TYPE vec4 @@ -19,7 +19,7 @@ #define QUANT_K 1 #define QUANT_R 1 -#ifndef LOAD_VEC_A +#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 #define A_TYPE float16_t #elif LOAD_VEC_A == 4 #define A_TYPE f16vec4 diff --git a/src/vulkan-shaders/upscale.comp b/src/vulkan-shaders/upscale.comp new file mode 100644 index 000000000..511a086ea --- /dev/null +++ b/src/vulkan-shaders/upscale.comp @@ -0,0 +1,36 @@ +#version 450 + +layout (push_constant) uniform parameter +{ + uint ne; uint d_offset; + uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; + float sf0; float sf1; float sf2; float sf3; +} p; + +#include "types.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (idx >= p.ne) { + return; + } + + const uint i10 = idx % p.ne10; + const uint i11 = (idx / p.ne10) % p.ne11; + const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12; + const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13; + + const uint i00 = uint(i10 / p.sf0); + const uint i01 = uint(i11 / p.sf1); + const uint i02 = uint(i12 / p.sf2); + const uint i03 = uint(i13 / p.sf3); + + data_d[p.d_offset + idx] = D_TYPE(data_a[i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]); +} diff --git a/src/vulkan-shaders/vulkan-shaders-gen.cpp b/src/vulkan-shaders/vulkan-shaders-gen.cpp index c5be3754b..258a1933f 100644 --- a/src/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/src/vulkan-shaders/vulkan-shaders-gen.cpp @@ -269,9 +269,12 @@ void matmul_shaders(std::vector>& tasks, bool fp16, bool matmu for (const auto& tname : type_names) { std::string data_a_key = "DATA_A_" + to_uppercase(tname); + // For unaligned, load one at a time for f32/f16, or two at a time for quants + std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2"; + // For aligned matmul loads std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2"; tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16); + string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16); })); tasks.push_back(std::async(std::launch::async, [=] { string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16); @@ -340,6 +343,9 @@ void process_shaders(std::vector>& tasks) { tasks.push_back(std::async(std::launch::async, [=] { string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); })); + tasks.push_back(std::async(std::launch::async, [=] { + string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + })); tasks.push_back(std::async(std::launch::async, [=] { string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); })); @@ -357,6 +363,9 @@ void process_shaders(std::vector>& tasks) { tasks.push_back(std::async(std::launch::async, [] { string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); })); + tasks.push_back(std::async(std::launch::async, [] { + string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + })); tasks.push_back(std::async(std::launch::async, [] { string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); @@ -382,15 +391,42 @@ void process_shaders(std::vector>& tasks) { string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); })); + tasks.push_back(std::async(std::launch::async, [] { + string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + })); + + tasks.push_back(std::async(std::launch::async, [] { + string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + })); + tasks.push_back(std::async(std::launch::async, [] { + string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + })); + tasks.push_back(std::async(std::launch::async, [] { + string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}); + })); + + tasks.push_back(std::async(std::launch::async, [] { + string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + })); + tasks.push_back(std::async(std::launch::async, [] { string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); })); + tasks.push_back(std::async(std::launch::async, [] { + string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + })); tasks.push_back(std::async(std::launch::async, [] { string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); })); tasks.push_back(std::async(std::launch::async, [] { string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); })); + tasks.push_back(std::async(std::launch::async, [] { + string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + })); + tasks.push_back(std::async(std::launch::async, [] { + string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + })); tasks.push_back(std::async(std::launch::async, [] { string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -424,6 +460,17 @@ void process_shaders(std::vector>& tasks) { tasks.push_back(std::async(std::launch::async, [=] { string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); })); + + tasks.push_back(std::async(std::launch::async, [=] { + string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + })); + tasks.push_back(std::async(std::launch::async, [=] { + string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); + })); + + tasks.push_back(std::async(std::launch::async, [=] { + string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + })); } void write_output_files() {