diff --git a/src/core/include/openvino/op/scaled_dot_product_attention.hpp b/src/core/include/openvino/op/scaled_dot_product_attention.hpp index 0ec687194dd0d4..93e55a18205ac6 100644 --- a/src/core/include/openvino/op/scaled_dot_product_attention.hpp +++ b/src/core/include/openvino/op/scaled_dot_product_attention.hpp @@ -50,6 +50,10 @@ class OPENVINO_API ScaledDotProductAttention : public Op { return m_causal; } + void set_causal(bool causal) { + m_causal = causal; + } + private: bool m_causal = false; }; diff --git a/src/plugins/intel_gpu/include/intel_gpu/op/indirect_sdpa.hpp b/src/plugins/intel_gpu/include/intel_gpu/op/indirect_sdpa.hpp new file mode 100644 index 00000000000000..18c41cf2c12349 --- /dev/null +++ b/src/plugins/intel_gpu/include/intel_gpu/op/indirect_sdpa.hpp @@ -0,0 +1,78 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "intel_gpu/op/sdpa.hpp" +#include "openvino/core/node.hpp" +#include "openvino/core/partial_shape.hpp" +#include "openvino/op/op.hpp" + +namespace ov { +namespace intel_gpu { +namespace op { + +class IndirectSDPA : public ov::intel_gpu::op::SDPA { +public: + OPENVINO_OP("IndirectSDPA", "gpu_opset"); + + IndirectSDPA() = default; + + IndirectSDPA(const ov::Output& Q, + const ov::Output& K, + const ov::Output& V, + const ov::Output& beam_table, + const bool is_causal, + const int64_t indirect_axis, + const std::vector& order_q, + const std::vector& order_k, + const std::vector& order_v, + const std::vector& order_out, + const ov::element::Type output_type = ov::element::undefined); + + IndirectSDPA(const ov::Output& Q, + const ov::Output& K, + const ov::Output& V, + const ov::Output& attn_mask, + const ov::Output& beam_table, + const bool is_causal, + const int64_t indirect_axis, + const std::vector& order_q, + const std::vector& order_k, + const std::vector& order_v, + const std::vector& order_out, + const ov::element::Type output_type = ov::element::undefined); + + IndirectSDPA(const ov::Output& Q, + const ov::Output& K, + const ov::Output& V, + const ov::Output& attn_mask, + const ov::Output& scale, + const ov::Output& beam_table, + const bool is_causal, + const int64_t indirect_axis, + const std::vector& order_q, + const std::vector& order_k, + const std::vector& order_v, + const std::vector& order_out, + const ov::element::Type output_type = ov::element::undefined); + + bool visit_attributes(ov::AttributeVisitor &visitor) override; + void validate_and_infer_types() override; + + std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override; + + ov::element::Type get_output_type() const { return m_output_type; } + + int64_t get_indirect_axis() const { return m_indirect_axis; } + + using ov::intel_gpu::op::SDPA::default_order; + +protected: + int64_t m_indirect_axis = -1; +}; + +} // namespace op +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp index 7979870275d240..a20017540379af 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp @@ -285,3 +285,4 @@ REGISTER_FACTORY(internal, IndirectGemm); REGISTER_FACTORY(internal, Convolution); REGISTER_FACTORY(internal, Placeholder); REGISTER_FACTORY(internal, SDPA); +REGISTER_FACTORY(internal, IndirectSDPA); diff --git 
a/src/plugins/intel_gpu/include/intel_gpu/primitives/scaled_dot_product_attention.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/scaled_dot_product_attention.hpp index f4f32a6af37d87..4cfbe21a67c7ad 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/scaled_dot_product_attention.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/scaled_dot_product_attention.hpp @@ -19,6 +19,7 @@ struct scaled_dot_product_attention : public primitive_base inputs, bool is_causal, + int64_t indirect_axis = -1, const std::vector& input_q_transpose_order = {}, const std::vector& input_k_transpose_order = {}, const std::vector& input_v_transpose_order = {}, @@ -26,17 +27,23 @@ struct scaled_dot_product_attention : public primitive_base 3) - , has_scale_input(inputs.size() > 4) + , indirect_axis(indirect_axis) , input_q_transpose_order(input_q_transpose_order) , input_k_transpose_order(input_k_transpose_order) , input_v_transpose_order(input_v_transpose_order) - , output_transpose_order(output_transpose_order) {} + , output_transpose_order(output_transpose_order) { + auto data_inputs_num = inputs.size(); + if (indirect_axis != -1) + data_inputs_num--; + has_attn_mask_input = data_inputs_num > 3; + has_scale_input = data_inputs_num > 4; + } bool is_causal = false; bool has_attn_mask_input = false; bool has_scale_input = false; + int64_t indirect_axis = -1; std::vector input_q_transpose_order; std::vector input_k_transpose_order; @@ -48,6 +55,7 @@ struct scaled_dot_product_attention : public primitive_base> is_causal; ib >> has_attn_mask_input; ib >> has_scale_input; + ib >> indirect_axis; ib >> input_q_transpose_order; ib >> input_k_transpose_order; ib >> input_v_transpose_order; diff --git a/src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp b/src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp index 992e5174b47eb9..3ec28d1e32a1df 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp @@ -129,6 +129,7 @@ class debug_configuration { std::vector forced_impl_types; // Force implementation type either ocl or onednn int max_kernels_per_batch; // Maximum number of kernels in a batch during compiling kernels int impls_cache_capacity; // The maximum number of entries in the kernel impl cache + int enable_sdpa; // Allows to control SDPA decomposition int disable_async_compilation; // Disable async compilation int disable_winograd_conv; // Disable Winograd conv int disable_dynamic_impl; // Disable dynamic implementation diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp index d60098aca74588..364c9418f10b28 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp @@ -2,34 +2,188 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "primitive_base.hpp" +#include "multi_stage_primitive.hpp" #include "scaled_dot_product_attention_inst.h" +#include "kv_cache_inst.h" + #include "sdpa/sdpa_kernel_selector.h" #include "sdpa/sdpa_kernel_base.h" namespace cldnn { namespace ocl { -struct scaled_dot_product_attention_impl : typed_primitive_impl_ocl { - using parent = typed_primitive_impl_ocl; + +// SDPA impl may create 2 versions of the kernel internally +// 1. Default SDPA kernels +// 2. 
SDPA kernels with indirect access to one of the inputs
+// This feature is used to avoid a perf drop when we create a single kernel which checks the batch size at runtime
+// Can be reverted once performance of the kernel is improved
+struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_product_attention> {
+    using parent = multi_stage_primitive<scaled_dot_product_attention>;
     using parent::parent;
     using kernel_selector_t = kernel_selector::sdpa_kernel_selector;
     using kernel_params_t = kernel_selector::sdpa_params;

     DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::scaled_dot_product_attention_impl)

+    const uint32_t default_sdpa = 0;
+    const uint32_t indirect_sdpa = 1;
+
     std::unique_ptr<primitive_impl> clone() const override {
         return make_unique<scaled_dot_product_attention_impl>(*this);
     }

+    scaled_dot_product_attention_impl() = default;
+
+    scaled_dot_product_attention_impl(const std::vector<kernel_selector::kernel_data>& kd) : parent(kd) {
+        this->can_reuse_memory = true;
+    }
+
     void load(BinaryInputBuffer& ib) override {
         parent::load(ib);
         if (is_dynamic()) {
             auto& kernel_selector = kernel_selector_t::Instance();
-            auto kernel_impl = kernel_selector.GetImplementation(_kernel_data.kernelName);
-            kernel_impl->GetUpdateDispatchDataFunc(_kernel_data);
+            auto kernel_impl = kernel_selector.GetImplementation(_kernels_data[default_sdpa].kernelName);
+            kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[default_sdpa]);
+            if (_kernels_data.size() == 2) {
+                auto bt_kernel_impl = kernel_selector.GetImplementation(_kernels_data[indirect_sdpa].kernelName);
+                bt_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[indirect_sdpa]);
+            }
         }
     }

+protected:
+    // TODO: the current implementation is expected to use the same kernel version for both the indirect and default paths;
+    // given that, we may assume both kernels have exactly the same number of intermediate buffers and the same buffer sizes
+    // (update_dispatch_data is called for both kernels too), so memory allocations are not doubled during the
+    // reallocate_if_needed() call
+    std::vector<layout> get_internal_buffer_layouts_impl() const override {
+        std::vector<layout> layouts;
+        if (_kernels_data.size() > 0) {
+            auto dtype = from_data_type(_kernels_data[0].internalBufferDataType);
+            const auto bpp = data_type_traits::size_of(dtype);
+            for (auto size : _kernels_data[0].internalBufferSizes) {
+                layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flatten to x channel)
+                                       {1, 1, 1, (tensor::value_type)(size / bpp)}};
+                layouts.push_back(inbuf_layout);
+            }
+        }
+
+        return layouts;
+    }
+
+    static size_t get_beam_table_id(std::shared_ptr<const scaled_dot_product_attention> primitive) {
+        GPU_DEBUG_TRACE << "get_beam_table_id " << primitive->input_size() - 1 << "\n";
+        return primitive->input_size() - 1;
+    }
+
+    static bool has_indirect_inputs(const kernel_impl_params& impl_param) {
+        const auto& desc = impl_param.typed_desc<scaled_dot_product_attention>();
+        return desc->indirect_axis != -1;
+    }
+
+    kernel_arguments_data get_arguments(const scaled_dot_product_attention_inst& instance, size_t stage) const override {
+        kernel_arguments_data args;
+
+        auto inputs_num = instance.inputs_memory_count();
+        if (instance.has_indirect_inputs() && stage == default_sdpa)
+            inputs_num--;
+
+        for (size_t i = 0; i < inputs_num; i++) {
+            args.inputs.push_back(instance.input_memory_ptr(i));
+        }
+
+        if (instance.has_fused_primitives()) {
+            size_t count = instance.get_fused_mem_count();
+            for (size_t i = 0; i < count; i++) {
+                args.fused_op_inputs.push_back(instance.fused_memory(i));
+            }
+        }
+
+        for (size_t i = 0; i < instance.outputs_memory_count(); i++) {
+            args.outputs.push_back(instance.output_memory_ptr(i));
+        }
+
+        args.shape_info =
instance.shape_info_memory_ptr(); + + return args; + } + + void set_arguments_impl(scaled_dot_product_attention_inst& instance) override {} + + event::ptr execute_stage(const std::vector& events, scaled_dot_product_attention_inst& instance, size_t stage) { + stream& stream = instance.get_network().get_stream(); + std::vector tmp_events(events); + std::vector all_events; + size_t kernel_offset = 0; + + for (size_t s = 0; s < stage; s++) { + kernel_offset += _kernels_data[s].kernels.size(); + } + for (size_t kd_idx = 0; kd_idx < _kernels_data[stage].kernels.size(); ++kd_idx) { + if (_kernels_data[stage].kernels[kd_idx].skip_execution) + continue; + + size_t idx_final = kernel_offset + kd_idx; + // If any user of the desc's users is CPU implementation or network's output, set desc as a output event (event won't be nullptr) + bool needs_completion_event = instance.needs_completion_event(); + + auto& params = _kernels_data[stage].kernels[kd_idx].params; + auto args = get_arguments(instance, stage); + args.scalars = ¶ms.scalars; + + for (size_t i = 0; i < instance.get_intermediates_memories().size(); i++) + args.intermediates.push_back(instance.get_intermediates_memories()[i]); + + stream.set_arguments(*_kernels[idx_final], _kernels_data[stage].kernels[kd_idx].params, args); + + const auto& gws = params.workGroups.global; + const auto& lws = params.workGroups.local; + + GPU_DEBUG_TRACE_DETAIL << "Enqueue stage " << stage << " kernel " << idx_final << ": gws=[" << gws[0] << ", " << gws[1] << ", " << gws[2] << "] " + << "lws=[" << lws[0] << ", " << lws[1] << ", " << lws[2] << "]" + << (needs_completion_event ? " has_completion_event=true" : "") << std::endl; + + auto ev = stream.enqueue_kernel(*_kernels[idx_final], params, args, tmp_events, needs_completion_event); + if (_kernels_data[stage].needs_sub_kernels_sync) { + tmp_events = {ev}; + } + all_events.push_back(ev); + } + + return aggregate_events(all_events, stream, all_events.size() > 1); + } + + bool need_indirect_load(const scaled_dot_product_attention_inst& instance) const { + auto desc = instance.get_typed_desc(); + + if (!instance.has_indirect_inputs()) + return false; + + const auto& params = *instance.get_impl_params(); + const auto indirect_axis = desc->indirect_axis; + if (params.input_layouts[get_beam_table_id(desc)].get_partial_shape()[indirect_axis].get_length() == 1) + return false; + + const auto& deps = instance.dependencies(); + + const auto indirect_dep_idx = 1; + const auto& indirect_dep = deps[indirect_dep_idx].first; + if (dynamic_cast(indirect_dep) == nullptr) { + return true; + } + + auto state_layout = indirect_dep->get_impl_params()->get_input_layout(0); + bool is_prefill = state_layout.count() == 0; + return !is_prefill; + } + + event::ptr execute_impl(const std::vector& events, scaled_dot_product_attention_inst& instance) override { + if (need_indirect_load(instance)) + return execute_stage(events, instance, indirect_sdpa); + else + return execute_stage(events, instance, default_sdpa); + } + static kernel_selector::sdpa_configuration get_sdpa_configuration(const kernel_impl_params& impl_param) { kernel_selector::sdpa_configuration config; @@ -44,16 +198,16 @@ struct scaled_dot_product_attention_impl : typed_primitive_impl_ocl(); - const auto query_shape = transpose_pshape(impl_param.get_input_layout(0).get_partial_shape(), prim->input_q_transpose_order); - const auto key_shape = transpose_pshape(impl_param.get_input_layout(1).get_partial_shape(), prim->input_k_transpose_order); - const auto value_shape = 
transpose_pshape(impl_param.get_input_layout(2).get_partial_shape(), prim->input_v_transpose_order); + const auto& desc = impl_param.typed_desc(); + const auto query_shape = transpose_pshape(impl_param.get_input_layout(0).get_partial_shape(), desc->input_q_transpose_order); + const auto key_shape = transpose_pshape(impl_param.get_input_layout(1).get_partial_shape(), desc->input_k_transpose_order); + const auto value_shape = transpose_pshape(impl_param.get_input_layout(2).get_partial_shape(), desc->input_v_transpose_order); OPENVINO_ASSERT(key_shape == value_shape, "[GPU] The shapes of key and value inputs are expected to be equal"); for (size_t i = 0; i < query_shape.size(); ++i) { if (query_shape[i].is_static() && key_shape[i].is_static() && value_shape[i].is_static()) { if (query_shape[i].get_length() > key_shape[i].get_length()) { - config.broadcast_axis = prim->input_k_transpose_order[i]; + config.broadcast_axis = desc->input_k_transpose_order[i]; config.group_size = query_shape[i].get_length() / key_shape[i].get_length(); } } @@ -62,44 +216,73 @@ struct scaled_dot_product_attention_impl : typed_primitive_impl_oclis_causal; + config.is_causal = desc->is_causal; return config; } - static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_dynamic) { +public: + static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_dynamic, bool indirect = false) { + const auto& desc = impl_param.typed_desc(); auto params = get_default_params(impl_param, is_dynamic); - const auto inputs_num = impl_param.input_layouts.size(); - params.inputs.resize(inputs_num); - for (size_t i = 0; i < inputs_num; i++) { + auto data_inputs_num = impl_param.input_layouts.size(); + if (has_indirect_inputs(impl_param)) + data_inputs_num--; + + params.inputs.resize(data_inputs_num); + for (size_t i = 0; i < data_inputs_num; i++) { params.inputs[i] = convert_data_tensor(impl_param.get_input_layout(i)); } params.conf = get_sdpa_configuration(impl_param); - const auto& prim = impl_param.typed_desc(); - params.input0_order = prim->input_q_transpose_order; - params.input1_order = prim->input_k_transpose_order; - params.input2_order = prim->input_v_transpose_order; - params.output_order = prim->output_transpose_order; + params.input0_order = desc->input_q_transpose_order; + params.input1_order = desc->input_k_transpose_order; + params.input2_order = desc->input_v_transpose_order; + params.output_order = desc->output_transpose_order; + + if (indirect && has_indirect_inputs(impl_param)) { + params.beam_table = convert_data_tensor(impl_param.get_input_layout(get_beam_table_id(desc))); + params.indirect_axis = desc->indirect_axis; + } params.set_dynamic_shape_offsets(); + // Need to adjust sdpa kernel offset to consider beam table input + if (has_indirect_inputs(impl_param)) { + auto out_offset = params.outputs[0].get_dynamic_shape_offset(); + if (indirect) + params.beam_table.SetDynamicShapeOffset(out_offset); + + params.outputs[0].SetDynamicShapeOffset(out_offset + kernel_selector::DataTensor::max_rank()); + } + return params; } static std::unique_ptr create(const typed_program_node& arg, const kernel_impl_params& impl_param) { + std::vector kernels_data; auto sdpa_kernel_params = get_kernel_params(impl_param, impl_param.is_dynamic()); - auto& sdpa_kernel_selector = kernel_selector_t::Instance(); - auto kd = sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params); + auto& kernel_selector = kernel_selector_t::Instance(); + 
kernels_data.push_back(kernel_selector.get_best_kernel(sdpa_kernel_params)); - return cldnn::make_unique(kd); + if (has_indirect_inputs(impl_param)) { + auto indirect_kernel_params = get_kernel_params(impl_param, impl_param.is_dynamic(), true); + kernels_data.push_back(kernel_selector.get_best_kernel(indirect_kernel_params)); + } + + return cldnn::make_unique(kernels_data); } void update_dispatch_data(const kernel_impl_params& impl_param) override { auto kernel_params = get_kernel_params(impl_param, true); - (_kernel_data.update_dispatch_data_func)(kernel_params, _kernel_data); + (_kernels_data[default_sdpa].update_dispatch_data_func)(kernel_params, _kernels_data[default_sdpa]); + + if (_kernels_data.size() == 2) { + auto kernel_params = get_kernel_params(impl_param, true); + (_kernels_data[indirect_sdpa].update_dispatch_data_func)(kernel_params, _kernels_data[indirect_sdpa]); + } } }; diff --git a/src/plugins/intel_gpu/src/graph/include/scaled_dot_product_attention_inst.h b/src/plugins/intel_gpu/src/graph/include/scaled_dot_product_attention_inst.h index cecb2a0f609550..ef75fd2f31e20f 100644 --- a/src/plugins/intel_gpu/src/graph/include/scaled_dot_product_attention_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/scaled_dot_product_attention_inst.h @@ -32,6 +32,9 @@ class typed_primitive_inst : public typed_primitiv static std::vector calc_output_layouts(scaled_dot_product_attention_node const& /*node*/, const kernel_impl_params& impl_param); static layout calc_output_layout(scaled_dot_product_attention_node const& node, kernel_impl_params const& impl_param); static std::string to_string(scaled_dot_product_attention_node const& node); + bool has_indirect_inputs() const { + return get_typed_desc()->indirect_axis != -1; + } typed_primitive_inst(network& network, scaled_dot_product_attention_node const& desc); }; diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl index 14cef4010c6bea..0f355fff1afd67 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl @@ -96,6 +96,24 @@ inline uint FUNC(get_input2_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint #endif } +#ifdef BEAM_TABLE_TYPE +inline uint FUNC(get_bt_index_nt)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { +#if BEAM_TABLE_SIMPLE + return GET_DATA_INDEX_6D_SAFE(BEAM_TABLE, b, f, w, z, y, x); +#else +# error sdpa_ref.cl : Unsupported beam table format +#endif +} + +inline uint FUNC(get_bt_index_key)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { + return FUNC_CALL(get_bt_index_nt)(OPTIONAL_SHAPE_INFO_TENSOR INPUT1_DIMS_ORDER); +} + +inline uint FUNC(get_bt_index_value)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { + return FUNC_CALL(get_bt_index_nt)(OPTIONAL_SHAPE_INFO_TENSOR INPUT2_DIMS_ORDER); +} +#endif + #define VALUE_BLOCK_READ(ptr, offset) BLOCK_READN(INPUT2_TYPE, 1, ptr, offset) #define SUBGROUPS_PER_WG (HEAD_SIZE / SUBGROUP_SIZE) @@ -117,6 +135,9 @@ KERNEL(sdpa_opt)( const __global INPUT4_TYPE* scale, #endif __global OUTPUT_TYPE* output, +#ifdef BEAM_TABLE_TYPE + const __global BEAM_TABLE_TYPE* beam_table, +#endif __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums, __global SOFTMAX_ACCUMULATOR_TYPE* max_logits, __global OUTPUT_TYPE* tmp_out @@ -125,12 +146,7 @@ KERNEL(sdpa_opt)( const uint batch_idx = get_global_id(0); const uint b0_idx = batch_idx / NUM_HEADS; /* 
BATCH dim */ const uint b1_idx = batch_idx % NUM_HEADS; /* HEADS_NUM dim */ - -#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 - const uint target_seq_idx = (uint)get_global_id(1) * TARGET_SEQ_LEN_BLOCK_SIZE; -#else const uint target_seq_idx = get_global_id(1); -#endif const uint lid = get_local_id(2); const uint head_size_idx = lid; @@ -173,12 +189,7 @@ KERNEL(sdpa_opt)( // Query input loading to SLM #define QUERY_STEP_LOCAL SUBGROUP_SIZE * SUBGROUPS_PER_WG uint query_local_offset = sgid * SUBGROUP_SIZE + sglid; - -#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 - const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); -#else const uint seq_idx_end = 1; -#endif #ifdef INPUT0_DIMS_ORDER uint query_offset = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, target_seq_idx, (sgid * SUBGROUP_SIZE)); uint query_offset_next_seq = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, target_seq_idx + 1, (sgid * SUBGROUP_SIZE)); @@ -207,9 +218,14 @@ KERNEL(sdpa_opt)( // HEAD_SIZE / SUBGROUPS_PER_WG times in the loop and saves the result to the qk_local SLM buffer for (uint seq_len = sgid; seq_len < partition_seq_len; seq_len += (HEAD_SIZE / SUBGROUP_SIZE)) { #ifdef INPUT1_DIMS_ORDER - uint key_offset = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + seq_len, 0); +#ifdef BEAM_TABLE_TYPE + const uint b_idx = beam_table[FUNC_CALL(get_bt_index_key)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + seq_len, 0)]; #else - uint key_offset = INPUT1_GET_INDEX(b0_idx, b1_idx, start_partition_idx + seq_len, 0); + const uint b_idx = b0_idx; +#endif + const uint key_offset = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b_idx, b1_idx, 0, 0, start_partition_idx + seq_len, 0); +#else + const uint key_offset = INPUT1_GET_INDEX(b0_idx, b1_idx, start_partition_idx + seq_len, 0); #endif INPUT0_TYPE acc[TARGET_SEQ_LEN_BLOCK_SIZE] = {INPUT0_VAL_ZERO}; @@ -316,11 +332,7 @@ KERNEL(sdpa_opt)( barrier(CLK_LOCAL_MEM_FENCE); INPUT0_TYPE qk_val[TARGET_SEQ_LEN_BLOCK_SIZE]; -#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 - const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); -#else const uint seq_idx_end = 1; -#endif for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { // Iterate over all values QK values in SLM and apply scale and attention mask for (uint seq_len = sgid * SUBGROUP_SIZE + sglid; seq_len < partition_seq_len; seq_len += (HEAD_SIZE)) { @@ -349,11 +361,7 @@ KERNEL(sdpa_opt)( { // SoftMax calculation -#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 - const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); -#else const uint seq_idx_end = 1; -#endif // Find the maximum value of qk in the subgroup for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { qk_max[seq_idx] = sub_group_reduce_max(qk_max[seq_idx]); @@ -446,20 +454,26 @@ KERNEL(sdpa_opt)( { // Gemm2 calculation OUTPUT_TYPE acc[TARGET_SEQ_LEN_BLOCK_SIZE] = {OUTPUT_VAL_ZERO}; - +#ifndef BEAM_TABLE_TYPE #ifdef INPUT2_DIMS_ORDER uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 0, 0); uint value_offset_next_seq = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 1, 0); const uint value_pitch = value_offset_next_seq - value_offset; #else const uint value_pitch = HEAD_SIZE; +#endif #endif for (uint seq_len = 0; seq_len < partition_seq_len / SUBGROUP_SIZE; seq_len++) { +#ifdef 
BEAM_TABLE_TYPE + uint b_idx = beam_table[FUNC_CALL(get_bt_index_value)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + (seq_len * SUBGROUP_SIZE) + sglid, (head_size_idx / SUBGROUP_SIZE) * SUBGROUP_SIZE)]; + uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b_idx, b1_idx, 0, 0, start_partition_idx + (seq_len * SUBGROUP_SIZE) + sglid, (head_size_idx / SUBGROUP_SIZE) * SUBGROUP_SIZE); +#else #ifdef INPUT2_DIMS_ORDER uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + (seq_len * SUBGROUP_SIZE), head_size_idx); #else uint value_offset = INPUT2_GET_INDEX(b0_idx, b1_idx, start_partition_idx + (seq_len * SUBGROUP_SIZE), head_size_idx); +#endif #endif OUTPUT_TYPE qk_val[TARGET_SEQ_LEN_BLOCK_SIZE]; @@ -468,19 +482,30 @@ KERNEL(sdpa_opt)( } unroll_for (uint i = 0; i < SUBGROUP_SIZE; i++) { +#ifdef BEAM_TABLE_TYPE + INPUT2_TYPE value_val = VALUE_BLOCK_READ(value_input, sub_group_broadcast(value_offset, i)); +#else INPUT2_TYPE value_val = VALUE_BLOCK_READ(value_input, value_offset); +#endif unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { acc[seq_idx] = mad(sub_group_broadcast(qk_val[seq_idx], i), value_val, acc[seq_idx]); } +#ifndef BEAM_TABLE_TYPE value_offset += value_pitch; +#endif } } const uint seq_len_leftovers_start = (partition_seq_len / SUBGROUP_SIZE) * SUBGROUP_SIZE; for (uint seq_len = seq_len_leftovers_start; seq_len < partition_seq_len; seq_len++) { #ifdef INPUT2_DIMS_ORDER - const uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + seq_len, head_size_idx); +#ifdef BEAM_TABLE_TYPE + const uint b_idx = beam_table[FUNC_CALL(get_bt_index_value)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + seq_len, head_size_idx)]; +#else + const uint b_idx = b0_idx; +#endif + const uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b_idx, b1_idx, 0, 0, start_partition_idx + seq_len, head_size_idx); #else const uint value_offset = INPUT2_GET_INDEX(b0_idx, b1_idx, start_partition_idx + seq_len, head_size_idx); #endif @@ -500,11 +525,7 @@ KERNEL(sdpa_opt)( // If the number of partitions is greater than 1, save results to the temporary buffer; // otherwise, save results directly to the main output. 
if (num_of_partitions > 1) { -#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 - const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); -#else const uint seq_idx_end = 1; -#endif for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { // Data layout of tmp_output buf: [batch, heads_num, q_len, partition_idx, head_size] const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + @@ -515,15 +536,11 @@ KERNEL(sdpa_opt)( tmp_out[tmp_out_offset] = acc[seq_idx]; } } else { -#if TARGET_SEQ_LEN_BLOCK_SIZE > 1 - const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); -#else const uint seq_idx_end = 1; -#endif for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { - const uint output_offset = OUTPUT_GET_INDEX(b0_idx, b1_idx, target_seq_idx + seq_idx, head_size_idx); + const uint output_offset = OUTPUT_GET_INDEX(b0_idx, b1_idx, target_seq_idx + seq_idx, head_size_idx); - output[output_offset] = acc[seq_idx]; + output[output_offset] = acc[seq_idx]; } } } // Gemm2 calculation end @@ -545,6 +562,9 @@ KERNEL(sdpa_opt)( const __global INPUT4_TYPE* scale, #endif __global OUTPUT_TYPE* output, +#ifdef BEAM_TABLE_TYPE + const __global BEAM_TABLE_TYPE* beam_table, +#endif __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums, __global SOFTMAX_ACCUMULATOR_TYPE* max_logits, __global OUTPUT_TYPE* tmp_out @@ -637,6 +657,10 @@ KERNEL(sdpa_opt)( // Main Gemm1 calculation loop uint seq_len = sgid * TARGET_SEQ_LEN_BLOCK_SIZE; for (; seq_len < partition_seq_len; seq_len += SUBGROUPS_PER_WG * SUBGROUP_SIZE) { +#ifdef BEAM_TABLE_TYPE + const uint b_idx = beam_table[FUNC_CALL(get_bt_index_key)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + seq_len + sglid, 0)]; + const uint key_offset = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b_idx, b1_idx, 0, 0, start_partition_idx + seq_len + sglid, 0); +#else #ifdef INPUT1_DIMS_ORDER uint key_offset = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + seq_len, 0); uint key_offset_next_seq = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + seq_len + 1, 0); @@ -644,6 +668,7 @@ KERNEL(sdpa_opt)( #else uint key_offset = INPUT1_GET_INDEX(b0_idx, b1_idx, start_partition_idx + seq_len, 0); const uint key_pitch = HEAD_SIZE; +#endif #endif INPUT0_TYPE acc[TARGET_SEQ_LEN_BLOCK_SIZE] = {INPUT0_VAL_ZERO}; @@ -660,7 +685,11 @@ KERNEL(sdpa_opt)( } unroll_for (uint key_row_idx = 0; key_row_idx < TARGET_SEQ_LEN_BLOCK_SIZE; key_row_idx++) { +#ifdef BEAM_TABLE_TYPE + INPUT1_TYPE key_vals = KEY_BLOCK_READ(key_input, sub_group_broadcast(key_offset, key_row_idx) + head_idx_index); +#else INPUT1_TYPE key_vals = KEY_BLOCK_READ(key_input, key_offset + key_row_idx * key_pitch + head_idx_index); +#endif unroll_for (uint i = 0; i < SUBGROUP_SIZE; i++) { acc[key_row_idx] = mad(sub_group_broadcast(key_vals, i), queries_vec[i], acc[key_row_idx]); @@ -801,10 +830,15 @@ KERNEL(sdpa_opt)( #endif for (uint seq_len = 0; seq_len < partition_seq_len / SUBGROUP_SIZE; seq_len++) { +#ifdef BEAM_TABLE_TYPE + const uint b_idx = beam_table[FUNC_CALL(get_bt_index_value)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + (seq_len * SUBGROUP_SIZE) + sglid, sgid * SUBGROUP_SIZE)]; + const uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b_idx, b1_idx, 0, 0, start_partition_idx + (seq_len * SUBGROUP_SIZE) + sglid, sgid * SUBGROUP_SIZE); +#else #ifdef 
INPUT2_DIMS_ORDER uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + (seq_len * SUBGROUP_SIZE), head_size_idx); #else uint value_offset = INPUT2_GET_INDEX(b0_idx, b1_idx, start_partition_idx + (seq_len * SUBGROUP_SIZE), head_size_idx); +#endif #endif OUTPUT_TYPE qk_val[TARGET_SEQ_LEN_BLOCK_SIZE]; @@ -813,12 +847,18 @@ KERNEL(sdpa_opt)( } unroll_for (uint i = 0; i < SUBGROUP_SIZE; i++) { +#ifdef BEAM_TABLE_TYPE + INPUT2_TYPE value_val = VALUE_BLOCK_READ(value_input, sub_group_broadcast(value_offset, i)); +#else INPUT2_TYPE value_val = VALUE_BLOCK_READ(value_input, value_offset); +#endif unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { acc[seq_idx] = mad(sub_group_broadcast(qk_val[seq_idx], i), value_val, acc[seq_idx]); } +#ifndef BEAM_TABLE_TYPE value_offset += value_pitch; +#endif } } @@ -833,17 +873,31 @@ KERNEL(sdpa_opt)( qk_val[seq_idx] = qk_local[qk_offset]; qk_offset += SEQ_LEN_PARTITION_SIZE; } - +#ifdef BEAM_TABLE_TYPE + const uint b_idx = beam_table[FUNC_CALL(get_bt_index_value)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + seq_len_leftovers_start + sglid, sgid * SUBGROUP_SIZE)]; + const uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b_idx, b1_idx, 0, 0, start_partition_idx + seq_len_leftovers_start + sglid, sgid * SUBGROUP_SIZE); +#else +#ifdef INPUT2_DIMS_ORDER uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + seq_len_leftovers_start, head_size_idx); +#else + uint value_offset = INPUT2_GET_INDEX(b0_idx, b1_idx, start_partition_idx + seq_len_leftovers_start, head_size_idx); +#endif +#endif for (uint seq_len_idx = 0; seq_len_idx < partition_seq_len - seq_len_leftovers_start; seq_len_idx++) { +#ifdef BEAM_TABLE_TYPE + INPUT2_TYPE value_val = VALUE_BLOCK_READ(value_input, sub_group_broadcast(value_offset, seq_len_idx)); +#else INPUT2_TYPE value_val = VALUE_BLOCK_READ(value_input, value_offset); +#endif for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { acc[seq_idx] = mad(sub_group_broadcast(qk_val[seq_idx], seq_len_idx), value_val, acc[seq_idx]); } +#ifndef BEAM_TABLE_TYPE value_offset += value_pitch; +#endif } } @@ -890,10 +944,15 @@ KERNEL(sdpa_opt)( #endif for (uint seq_len = 0; seq_len < partition_seq_len / SUBGROUP_SIZE; seq_len++) { +#ifdef BEAM_TABLE_TYPE + const uint b_idx = beam_table[FUNC_CALL(get_bt_index_value)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + (seq_len * SUBGROUP_SIZE) + sglid, sgid * SUBGROUP_SIZE)]; + const uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b_idx, b1_idx, 0, 0, start_partition_idx + (seq_len * SUBGROUP_SIZE) + sglid, sgid * SUBGROUP_SIZE); +#else #ifdef INPUT2_DIMS_ORDER uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + (seq_len * SUBGROUP_SIZE), head_size_idx); #else uint value_offset = INPUT2_GET_INDEX(b0_idx, b1_idx, start_partition_idx + (seq_len * SUBGROUP_SIZE), head_size_idx); +#endif #endif OUTPUT_TYPE qk_val[TARGET_SEQ_LEN_BLOCK_SIZE]; @@ -902,12 +961,18 @@ KERNEL(sdpa_opt)( } unroll_for (uint i = 0; i < SUBGROUP_SIZE; i++) { +#ifdef BEAM_TABLE_TYPE + INPUT2_TYPE value_val = VALUE_BLOCK_READ(value_input, sub_group_broadcast(value_offset, i)); +#else INPUT2_TYPE value_val = VALUE_BLOCK_READ(value_input, value_offset); +#endif unroll_for (uint seq_idx = 0; seq_idx < 
TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { acc[seq_idx] = mad(sub_group_broadcast(qk_val[seq_idx], i), value_val, acc[seq_idx]); } +#ifndef BEAM_TABLE_TYPE value_offset += value_pitch; +#endif } } diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_ref.cl index cd289be026e7e3..83e3c7c7e9fef1 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_ref.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_ref.cl @@ -93,6 +93,24 @@ inline uint FUNC(get_input2_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint #endif } +#ifdef BEAM_TABLE_TYPE +inline uint FUNC(get_bt_index_nt)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { +#if BEAM_TABLE_SIMPLE + return GET_DATA_INDEX_6D_SAFE(BEAM_TABLE, b, f, w, z, y, x); +#else +# error sdpa_ref.cl : Unsupported beam table format +#endif +} + +inline uint FUNC(get_bt_index_key)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { + return FUNC_CALL(get_bt_index_nt)(OPTIONAL_SHAPE_INFO_TENSOR INPUT1_DIMS_ORDER); +} + +inline uint FUNC(get_bt_index_value)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { + return FUNC_CALL(get_bt_index_nt)(OPTIONAL_SHAPE_INFO_TENSOR INPUT2_DIMS_ORDER); +} +#endif + #define APPLY_SCALE_TO_QUERY 1 KERNEL(sdpa_ref)( @@ -107,6 +125,9 @@ KERNEL(sdpa_ref)( const __global INPUT4_TYPE* scale, #endif __global OUTPUT_TYPE* output, +#ifdef BEAM_TABLE_TYPE + const __global BEAM_TABLE_TYPE* beam_table, +#endif __global OUTPUT_TYPE* tmp_buf ) { @@ -129,7 +150,12 @@ KERNEL(sdpa_ref)( OUTPUT_TYPE acc = 0; for (uint h = 0; h < HEAD_SIZE /* head_size */; h++) { uint query_offset = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b0, b1, 0, 0, target_seq_idx, h); - uint key_offset = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b0, b1, 0, 0, s, h); +#ifdef BEAM_TABLE_TYPE + uint b_idx = beam_table[FUNC_CALL(get_bt_index_key)(OPTIONAL_SHAPE_INFO_TENSOR b0, b1, 0, 0, s, h)]; +#else + uint b_idx = b0; +#endif + uint key_offset = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b_idx, b1, 0, 0, s, h); #if APPLY_SCALE_TO_QUERY INPUT0_TYPE q_val = query_input[query_offset] * scale_val; @@ -202,7 +228,13 @@ KERNEL(sdpa_ref)( uint tmp_buf_offset = b0 * (NUM_HEADS * TARGET_SEQ_LEN * SOURCE_SEQ_LEN) + b1 * (TARGET_SEQ_LEN * SOURCE_SEQ_LEN) + target_seq_idx * (SOURCE_SEQ_LEN) + s; - uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0, b1, 0, 0, s, head_size_idx); + +#ifdef BEAM_TABLE_TYPE + uint b_idx = beam_table[FUNC_CALL(get_bt_index_value)(OPTIONAL_SHAPE_INFO_TENSOR b0, b1, 0, 0, s, head_size_idx)]; +#else + uint b_idx = b0; +#endif + uint value_offset = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b_idx, b1, 0, 0, s, head_size_idx); acc += tmp_buf[tmp_buf_offset] * value_input[value_offset]; } diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.cpp index 61028ef5348a1a..4871af9ed591fe 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.cpp @@ -85,15 +85,30 @@ JitConstants SDPAKernelBase::GetJitConstants(const sdpa_params& params) const { return true; }; - if ((!params.input0_order.empty() && !is_default_order(params.input0_order)) || params.conf.broadcast_axis != -1) { + auto 
use_index_calc_func = [&](const std::vector order, bool is_query = false) { + if (!params.input0_order.empty() && !is_default_order(params.input0_order)) + return true; + + if (params.conf.broadcast_axis != -1) + return true; + + if (params.indirect_axis != -1 && !is_query) + return true; + + return false; + }; + + if (params.indirect_axis != -1) + jit.AddConstant(MakeJitConstant("BEAM_TABLE", params.beam_table)); + + if (use_index_calc_func(params.input0_order, true)) jit.AddConstant(MakeJitConstant("INPUT0_DIMS_ORDER", GetDimsOrder(params.input0_order))); - } - if ((!params.input1_order.empty() && !is_default_order(params.input1_order)) || params.conf.broadcast_axis != -1) { + + if (use_index_calc_func(params.input1_order)) jit.AddConstant(MakeJitConstant("INPUT1_DIMS_ORDER", GetDimsOrder(params.input1_order))); - } - if ((!params.input2_order.empty() && !is_default_order(params.input2_order)) || params.conf.broadcast_axis != -1) { + + if (use_index_calc_func(params.input2_order)) jit.AddConstant(MakeJitConstant("INPUT2_DIMS_ORDER", GetDimsOrder(params.input2_order))); - } TransposedDimensionAccessHelperJit dims_q(params.inputs[0], params.input0_order); jit.AddConstant(MakeJitConstant("TARGET_SEQ_LEN", dims_q.y())); diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h index 1d4f30512df06b..215f19ecc881c9 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h @@ -99,6 +99,9 @@ struct sdpa_params : public base_params { std::vector input1_order; std::vector input2_order; std::vector output_order; + int64_t indirect_axis = -1; + + DataTensor beam_table; sdpa_configuration conf; }; diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp index 581565874f7fbb..359e8696cbef19 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp @@ -29,6 +29,22 @@ static size_t get_seq_len_partition_size() { return seq_len; } +static std::string GetKernelName(std::string base_name, KernelsTypes type, bool is_indirect) { + auto kernel_name = base_name; + if (is_indirect) + kernel_name += "_ind"; + + if (type == KernelsTypes::SINGLE_TOKEN) { + kernel_name += "_single_token"; + } else if (type == KernelsTypes::MULTI_TOKENS) { + kernel_name += "_multi_tokens"; + } else if (type == KernelsTypes::FINALIZATION) { + kernel_name += "_finalization"; + } + + return kernel_name; +} + ParamsKey SDPAKernelOpt::GetSupportedKey() const { ParamsKey k; k.EnableInputDataType(Datatype::F16); @@ -104,7 +120,7 @@ CommonDispatchData SDPAKernelOpt::SetDefault(const sdpa_params& params, size_t k CeilDiv(target_seq_len, target_seq_len_block_size), head_size * num_of_partitions }; dispatch_data.lws = { 1, 1, head_size }; - } else if (kernel_idx == 2) { + } else if (kernel_idx == KernelsTypes::FINALIZATION) { dispatch_data.gws = { batch_size * heads_num, target_seq_len, 16 }; @@ -134,8 +150,7 @@ KernelsData SDPAKernelOpt::GetKernelsData(const Params& params) const { const auto& prim_params = dynamic_cast(params); for (size_t kernel_idx = 0; kernel_idx < kernels_num; kernel_idx++) { auto dispatch_data = SetDefault(prim_params, kernel_idx); - auto kernel_name = kernel_idx == 0 ? 
kernelName + "_single_token" : - kernel_idx == 1 ? kernelName + "_multi_tokens" : kernelName + "_finalization"; + auto kernel_name = GetKernelName(kernelName, static_cast(kernel_idx), prim_params.indirect_axis != -1); auto entry_point = GetEntryPoint(kernel_name, prim_params.layerID, params); auto jit_constants = GetJitConstants(prim_params, kernel_idx); auto jit = CreateJit(kernel_name, jit_constants, entry_point); @@ -171,6 +186,9 @@ KernelsData SDPAKernelOpt::GetKernelsData(const Params& params) const { auto tmp_out_elements_count = (num_of_partitions == 1) ? 1 : output.LogicalSize() * num_of_partitions; auto tmp_out_size = tmp_out_elements_count * tmp_out_dt_size; + if (prim_params.indirect_axis != -1 && kernel_idx != KernelsTypes::FINALIZATION) + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, static_cast(prim_params.inputs.size())}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0}); kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1}); kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2}); diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.cpp index a80f3c31dfc8f3..579c4bc06c17e2 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.cpp @@ -13,6 +13,8 @@ ParamsKey SDPAKernelRef::GetSupportedKey() const { ParamsKey k; k.EnableInputDataType(Datatype::F16); k.EnableInputDataType(Datatype::F32); + // beam table input + k.EnableInputDataType(Datatype::INT32); k.EnableOutputDataType(Datatype::F16); k.EnableOutputDataType(Datatype::F32); @@ -72,6 +74,9 @@ KernelsData SDPAKernelRef::GetKernelsData(const Params& params) const { "", false, false, static_cast(prim_params.inputs.size()), GetFusedPrimitiveInputsCount(params), 1, prim_params.is_shape_agnostic); + if (prim_params.indirect_axis != -1) + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, static_cast(prim_params.inputs.size())}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0}); kd.internalBufferSizes.clear(); diff --git a/src/plugins/intel_gpu/src/plugin/ops/scaled_dot_product_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/scaled_dot_product_attention.cpp index c07c501a1f970b..d002c868ffd225 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/scaled_dot_product_attention.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/scaled_dot_product_attention.cpp @@ -6,6 +6,7 @@ #include "intel_gpu/plugin/common_utils.hpp" #include "intel_gpu/op/sdpa.hpp" +#include "intel_gpu/op/indirect_sdpa.hpp" #include "openvino/op/scaled_dot_product_attention.hpp" @@ -15,6 +16,7 @@ namespace ov { namespace op { namespace internal { using SDPA = ov::intel_gpu::op::SDPA; +using IndirectSDPA = ov::intel_gpu::op::IndirectSDPA; } // namespace internal } // namespace op } // namespace ov @@ -41,9 +43,30 @@ static void CreateSDPAOp(ProgramBuilder& p, const std::shared_ptrget_causal(); + int64_t indirect_axis = -1; auto sdpa_prim = cldnn::scaled_dot_product_attention(layerName, inputs, is_causal, + indirect_axis, + op->get_input0_transpose_order(), + op->get_input1_transpose_order(), + op->get_input2_transpose_order(), + op->get_output_transpose_order()); + + p.add_primitive(*op, sdpa_prim); +} + +static void CreateIndirectSDPAOp(ProgramBuilder& p, const std::shared_ptr& op) { + 
validate_inputs_count(op, {4, 5, 6}); + auto inputs = p.GetInputInfo(op); + auto layerName = layer_type_name_ID(op); + + bool is_causal = op->get_causal(); + int64_t indirect_axis = op->get_indirect_axis(); + auto sdpa_prim = cldnn::scaled_dot_product_attention(layerName, + inputs, + is_causal, + indirect_axis, op->get_input0_transpose_order(), op->get_input1_transpose_order(), op->get_input2_transpose_order(), @@ -53,6 +76,7 @@ static void CreateSDPAOp(ProgramBuilder& p, const std::shared_ptr Plugin::get_supported_properties() const { ov::PropertyName{ov::intel_gpu::hint::host_task_priority.name(), PropertyMutability::RW}, ov::PropertyName{ov::intel_gpu::hint::queue_priority.name(), PropertyMutability::RW}, ov::PropertyName{ov::intel_gpu::hint::queue_throttle.name(), PropertyMutability::RW}, + ov::PropertyName{ov::intel_gpu::hint::enable_sdpa_optimization.name(), PropertyMutability::RW}, ov::PropertyName{ov::intel_gpu::enable_loop_unrolling.name(), PropertyMutability::RW}, ov::PropertyName{ov::intel_gpu::disable_winograd_convolution.name(), PropertyMutability::RW}, ov::PropertyName{ov::cache_dir.name(), PropertyMutability::RW}, diff --git a/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.cpp b/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.cpp index d612ad03886f19..e6f58aaa25fb15 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.cpp @@ -6,7 +6,9 @@ #include #include "intel_gpu/op/gemm.hpp" +#include "intel_gpu/op/sdpa.hpp" #include "intel_gpu/op/indirect_gemm.hpp" +#include "intel_gpu/op/indirect_sdpa.hpp" #include "intel_gpu/op/kv_cache.hpp" #include "intel_gpu/op/read_value.hpp" #include "intel_gpu/plugin/common_utils.hpp" @@ -42,7 +44,7 @@ void replace_node_unsafe(const std::shared_ptr& target, const std::sha namespace ov { namespace intel_gpu { -IndirectKVCache::IndirectKVCache() { +IndirectGemmOpt::IndirectGemmOpt() { using namespace ov::pass::pattern; auto beam_idx = wrap_type(); @@ -108,9 +110,141 @@ IndirectKVCache::IndirectKVCache() { return true; }; - auto m = std::make_shared(matmul, "IndirectKVCache"); + auto m = std::make_shared(matmul, "IndirectGemmOpt"); this->register_matcher(m, callback); } +IndirectSDPAOpt::IndirectSDPAOpt() { + using namespace ov::pass::pattern; + using ov::pass::pattern::op::Or; + + auto beam_idx = wrap_type(); + auto gather_input_0 = wrap_type(); + auto gather_input_1 = wrap_type(); + auto axis_const = wrap_type( + ov::op::util::constant_predicate([](const std::vector& value) -> bool { + return value.size() == 1 && (value[0] == 0 || value[0] == 1); + })); + auto gather_past_0 = wrap_type({gather_input_0, beam_idx, axis_const}); + auto gather_past_1 = wrap_type({gather_input_1, beam_idx, axis_const}); + auto kv_cache_0 = wrap_type({gather_past_0, any_input()}); + auto kv_cache_1 = wrap_type({gather_past_1, any_input()}); + + auto input_attn_mask = any_input(); + auto input_scale = any_input(); + auto sdpa_without_attn_mask_m = wrap_type({ any_input(), kv_cache_0, kv_cache_1 }); + auto sdpa_with_attn_mask_m = wrap_type({ any_input(), kv_cache_0, kv_cache_1, input_attn_mask }); + auto sdpa_with_attn_mask_and_scale_m = + wrap_type({ any_input(), kv_cache_0, kv_cache_1, input_attn_mask, input_scale }); + + auto sdpa_m = std::make_shared(OutputVector{sdpa_without_attn_mask_m, sdpa_with_attn_mask_m, sdpa_with_attn_mask_and_scale_m}); + + ov::matcher_pass_callback callback = 
[OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) { + if (transformation_callback(m.get_match_root())) { + return false; + } + const auto& pattern_map = m.get_pattern_value_map(); + + auto kv_cache_node_0 = std::dynamic_pointer_cast(pattern_map.at(kv_cache_0).get_node_shared_ptr()); + auto kv_cache_node_1 = std::dynamic_pointer_cast(pattern_map.at(kv_cache_1).get_node_shared_ptr()); + + auto beam_idx_node = pattern_map.at(beam_idx).get_node_shared_ptr(); + auto gather_input_node_0 = pattern_map.at(gather_input_0).get_node_shared_ptr(); + auto gather_input_node_1 = pattern_map.at(gather_input_1).get_node_shared_ptr(); + auto gather_node_0 = std::dynamic_pointer_cast(pattern_map.at(gather_past_0).get_node_shared_ptr()); + auto gather_node_1 = std::dynamic_pointer_cast(pattern_map.at(gather_past_1).get_node_shared_ptr()); + auto gather_axis_0 = gather_node_0->get_axis(); + auto gather_axis_1 = gather_node_1->get_axis(); + OPENVINO_ASSERT(gather_axis_0 == gather_axis_1); + + ov::replace_node(gather_node_0, gather_input_node_0); + ov::replace_node(gather_node_1, gather_input_node_1); + + auto indirect_kv_cache_0 = std::make_shared(gather_input_node_0, + kv_cache_node_0->get_input_node_shared_ptr(1), + beam_idx_node, + kv_cache_node_0->get_variable(), + kv_cache_node_0->get_concat_axis(), + gather_axis_0, + kv_cache_node_0->get_output_element_type(0)); + + auto indirect_kv_cache_1 = std::make_shared(gather_input_node_1, + kv_cache_node_1->get_input_node_shared_ptr(1), + beam_idx_node, + kv_cache_node_1->get_variable(), + kv_cache_node_1->get_concat_axis(), + gather_axis_1, + kv_cache_node_1->get_output_element_type(0)); + + indirect_kv_cache_0->set_friendly_name(kv_cache_node_0->get_friendly_name()); + indirect_kv_cache_1->set_friendly_name(kv_cache_node_1->get_friendly_name()); + ov::copy_runtime_info(kv_cache_node_0, indirect_kv_cache_0); + ov::copy_runtime_info(kv_cache_node_1, indirect_kv_cache_1); + replace_node_unsafe(kv_cache_node_0, indirect_kv_cache_0); + replace_node_unsafe(kv_cache_node_1, indirect_kv_cache_1); + + auto sdpa = std::dynamic_pointer_cast(m.get_match_root()); + auto order_in0 = sdpa->get_input0_transpose_order(); + auto order_in1 = sdpa->get_input1_transpose_order(); + auto order_in2 = sdpa->get_input2_transpose_order(); + auto order_out = sdpa->get_output_transpose_order(); + auto is_causal = sdpa->get_causal(); + + std::shared_ptr indirect_sdpa; + if (pattern_map.find(sdpa_without_attn_mask_m) != pattern_map.end()) { + indirect_sdpa = std::make_shared(sdpa->get_input_node_shared_ptr(0), + sdpa->get_input_node_shared_ptr(1), + sdpa->get_input_node_shared_ptr(2), + indirect_kv_cache_0->output(1), // beam table + is_causal, + gather_axis_1, + order_in0, + order_in1, + order_in2, + order_out); + } else if (pattern_map.find(sdpa_with_attn_mask_m) != pattern_map.end()) { + indirect_sdpa = std::make_shared(sdpa->get_input_node_shared_ptr(0), + sdpa->get_input_node_shared_ptr(1), + sdpa->get_input_node_shared_ptr(2), + sdpa->get_input_node_shared_ptr(3), + indirect_kv_cache_0->output(1), // beam table + is_causal, + gather_axis_1, + order_in0, + order_in1, + order_in2, + order_out); + } else if (pattern_map.find(sdpa_with_attn_mask_and_scale_m) != pattern_map.end()) { + indirect_sdpa = std::make_shared(sdpa->get_input_node_shared_ptr(0), + sdpa->get_input_node_shared_ptr(1), + sdpa->get_input_node_shared_ptr(2), + sdpa->get_input_node_shared_ptr(3), + sdpa->get_input_node_shared_ptr(4), + indirect_kv_cache_0->output(1), // beam table + is_causal, + gather_axis_1, + 
order_in0, + order_in1, + order_in2, + order_out); + } + + OPENVINO_ASSERT(indirect_sdpa != nullptr); + + indirect_sdpa->set_friendly_name(sdpa->get_friendly_name()); + ov::copy_runtime_info(sdpa, indirect_sdpa); + ov::replace_node(sdpa, indirect_sdpa); + + return true; + }; + + auto m = std::make_shared(sdpa_m, "IndirectSDPAOpt"); + this->register_matcher(m, callback); +} + +IndirectKVCache::IndirectKVCache() { + add_matcher(); + add_matcher(); +} } // namespace intel_gpu } // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.hpp b/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.hpp index afea5da6ceb13c..2a6c4a347f9217 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.hpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.hpp @@ -36,11 +36,22 @@ namespace intel_gpu { /// ┌────┴──────┐ ┌────┴──────┴───┐ /// │ Gemm │ | IndirectGemm | /// └───────────┘ └───────────────┘ -class IndirectKVCache : public ov::pass::MatcherPass { +class IndirectKVCache : public ov::pass::GraphRewrite { public: OPENVINO_RTTI("IndirectKVCache", "0"); IndirectKVCache(); }; +class IndirectGemmOpt : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("IndirectGemmOpt", "0"); + IndirectGemmOpt(); +}; + +class IndirectSDPAOpt : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("IndirectSDPAOpt", "0"); + IndirectSDPAOpt(); +}; } // namespace intel_gpu } // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations/op/indirect_sdpa.cpp b/src/plugins/intel_gpu/src/plugin/transformations/op/indirect_sdpa.cpp new file mode 100644 index 00000000000000..9b36bfcb3d3d32 --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/transformations/op/indirect_sdpa.cpp @@ -0,0 +1,113 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "intel_gpu/op/indirect_sdpa.hpp" +#include "openvino/core/partial_shape.hpp" + +namespace ov { +namespace intel_gpu { +namespace op { + +IndirectSDPA::IndirectSDPA(const ov::Output& Q, + const ov::Output& K, + const ov::Output& V, + const ov::Output& beam_table, + const bool is_causal, + const int64_t indirect_axis, + const std::vector& order_q, + const std::vector& order_k, + const std::vector& order_v, + const std::vector& order_out, + const ov::element::Type output_type) + : ov::intel_gpu::op::SDPA(Q, K, V, order_q, order_k, order_v, order_out, is_causal, output_type) + , m_indirect_axis(indirect_axis) { + set_argument(3, beam_table); + validate_and_infer_types(); +} + +IndirectSDPA::IndirectSDPA(const ov::Output& Q, + const ov::Output& K, + const ov::Output& V, + const ov::Output& attn_mask, + const ov::Output& beam_table, + const bool is_causal, + const int64_t indirect_axis, + const std::vector& order_q, + const std::vector& order_k, + const std::vector& order_v, + const std::vector& order_out, + const ov::element::Type output_type) + : ov::intel_gpu::op::SDPA(Q, K, V, attn_mask, order_q, order_k, order_v, order_out, is_causal, output_type) + , m_indirect_axis(indirect_axis) { + set_argument(4, beam_table); + validate_and_infer_types(); +} + +IndirectSDPA::IndirectSDPA(const ov::Output& Q, + const ov::Output& K, + const ov::Output& V, + const ov::Output& attn_mask, + const ov::Output& scale, + const ov::Output& beam_table, + const bool is_causal, + const int64_t indirect_axis, + const std::vector& order_q, + const std::vector& order_k, + const std::vector& order_v, + const std::vector& order_out, + const 
ov::element::Type output_type) + : ov::intel_gpu::op::SDPA(Q, K, V, attn_mask, scale, order_q, order_k, order_v, order_out, is_causal, output_type) + , m_indirect_axis(indirect_axis) { + set_argument(5, beam_table); + validate_and_infer_types(); +} + +std::shared_ptr IndirectSDPA::clone_with_new_inputs(const ov::OutputVector& new_args) const { + check_new_args_count(this, new_args); + + if (new_args.size() == 4) { + return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), + m_is_causal, m_indirect_axis, m_order_q, m_order_k, m_order_v, m_order_out, m_output_type); + } else if (new_args.size() == 5) { + return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4), + m_is_causal, m_indirect_axis, m_order_q, m_order_k, m_order_v, m_order_out, m_output_type); + } else { + return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4), new_args.at(5), + m_is_causal, m_indirect_axis, m_order_q, m_order_k, m_order_v, m_order_out, m_output_type); + } +} + +void IndirectSDPA::validate_and_infer_types() { + const auto input_size = get_input_size(); + NODE_VALIDATION_CHECK(this, + input_size == 4 || input_size == 5 || input_size == 6, + "Number of inputs is incorrect. Current value is: ", + input_size, + ", expected 4, 5 or 6."); + + std::vector input_shapes; + for (size_t i = 0; i < input_size - 1; i++) { + input_shapes.push_back(get_input_partial_shape(i)); + } + + auto out_shapes = shape_infer(this, + input_shapes, + m_order_q, + m_order_k, + m_order_v, + m_order_out); + + auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type; + set_output_type(0, output_type, out_shapes[0]); +} + +bool IndirectSDPA::visit_attributes(ov::AttributeVisitor &visitor) { + SDPA::visit_attributes(visitor); + visitor.on_attribute("indirect_axis", m_indirect_axis); + return true; +} + +} // namespace op +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp b/src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp index 67e927abb43f97..5965cb5e9910ba 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp @@ -30,6 +30,7 @@ SDPA::SDPA(const ov::Output& Q, , m_is_causal(is_causal) , m_output_type(output_type) { set_arguments({Q, K, V}); + set_causal(is_causal); validate_and_infer_types(); } @@ -50,6 +51,7 @@ SDPA::SDPA(const ov::Output& Q, , m_is_causal(is_causal) , m_output_type(output_type) { set_arguments({Q, K, V, attn_mask}); + set_causal(is_causal); validate_and_infer_types(); } @@ -71,13 +73,23 @@ SDPA::SDPA(const ov::Output& Q, , m_is_causal(is_causal) , m_output_type(output_type) { set_arguments({Q, K, V, attn_mask, scale}); + set_causal(is_causal); validate_and_infer_types(); } std::shared_ptr SDPA::clone_with_new_inputs(const ov::OutputVector& new_args) const { check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), m_order_q, m_order_k, m_order_v, m_order_out, m_is_causal, m_output_type); + if (new_args.size() == 3) { + return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), + m_order_q, m_order_k, m_order_v, m_order_out, m_is_causal, m_output_type); + } else if (new_args.size() == 4) { + return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), + m_order_q, m_order_k, m_order_v, m_order_out, 
m_is_causal, m_output_type); + } else { + return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4), + m_order_q, m_order_k, m_order_v, m_order_out, m_is_causal, m_output_type); + } } void SDPA::validate_and_infer_types() { diff --git a/src/plugins/intel_gpu/src/plugin/transformations/transpose_fusion.cpp b/src/plugins/intel_gpu/src/plugin/transformations/transpose_fusion.cpp index 614a42845ec521..f418376d1453c0 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/transpose_fusion.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/transpose_fusion.cpp @@ -80,14 +80,7 @@ TransposeSDPAMatcher::TransposeSDPAMatcher() { ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); - std::shared_ptr sdpa; - if (pattern_map.find(sdpa_without_attn_mask_m) != pattern_map.end()) { - sdpa = std::dynamic_pointer_cast(pattern_map.at(sdpa_without_attn_mask_m).get_node_shared_ptr()); - } else if (pattern_map.find(sdpa_with_attn_mask_m) != pattern_map.end()) { - sdpa = std::dynamic_pointer_cast(pattern_map.at(sdpa_with_attn_mask_m).get_node_shared_ptr()); - } else if (pattern_map.find(sdpa_with_attn_mask_and_scale_m) != pattern_map.end()) { - sdpa = std::dynamic_pointer_cast(pattern_map.at(sdpa_with_attn_mask_and_scale_m).get_node_shared_ptr()); - } + auto sdpa = std::dynamic_pointer_cast(m.get_match_root()); if (!sdpa || transformation_callback(sdpa)) { return false; @@ -101,33 +94,41 @@ TransposeSDPAMatcher::TransposeSDPAMatcher() { size_t input_k_output_idx = sdpa->get_input_source_output(1).get_index(); size_t input_v_output_idx = sdpa->get_input_source_output(2).get_index(); - if (pattern_map.count(transpose_q_m) > 0) { - auto tranpose_a_order = std::dynamic_pointer_cast(pattern_map.at(transpose_q_order_m).get_node_shared_ptr()); - order_q = tranpose_a_order->cast_vector(); - if (order_q.back() != static_cast(order_q.size() - 1)) // Allow any transposes without head_size dim position change - return false; + auto process_transpose = [](const std::shared_ptr& transpose_node, + const std::shared_ptr& transpose_order_const_node, + std::vector& order, + size_t& output_idx) { + auto transpose_order_const = std::dynamic_pointer_cast(transpose_order_const_node); - auto tranpose_a = std::dynamic_pointer_cast(pattern_map.at(transpose_q_m).get_node_shared_ptr()); - input_q_output_idx = tranpose_a->get_input_source_output(0).get_index(); - } - if (pattern_map.count(transpose_k_m) > 0) { - auto tranpose_b_order = std::dynamic_pointer_cast(pattern_map.at(transpose_k_order_m).get_node_shared_ptr()); - order_k = tranpose_b_order->cast_vector(); - if (order_k.back() != static_cast(order_k.size() - 1)) // Allow any transposes without head_size dim position change + order = transpose_order_const->cast_vector(); + // Allow any transposes without head_size dim position change + if (order.back() != static_cast(order.size() - 1)) return false; - auto tranpose_b = std::dynamic_pointer_cast(pattern_map.at(transpose_k_m).get_node_shared_ptr()); - input_k_output_idx = tranpose_b->get_input_source_output(0).get_index(); - } - if (pattern_map.count(transpose_v_m) > 0) { - auto tranpose_c_order = std::dynamic_pointer_cast(pattern_map.at(transpose_v_order_m).get_node_shared_ptr()); - order_v = tranpose_c_order->cast_vector(); - if (order_v.back() != static_cast(order_v.size() - 1)) // Allow any transposes without head_size dim position change - return false; + auto transpose = 
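// Standalone sketch (illustration only): the fusion guard in process_transpose above accepts a
// transpose only if it keeps the head_size (innermost) dimension in place, i.e. the last element
// of the order constant must equal rank - 1; otherwise the Transpose is left as-is instead of
// being folded into the SDPA input orders.
#include <cstdint>
#include <iostream>
#include <vector>

static bool keeps_head_size_dim(const std::vector<int64_t>& order) {
    return !order.empty() && order.back() == static_cast<int64_t>(order.size() - 1);
}

int main() {
    std::cout << keeps_head_size_dim({0, 2, 1, 3}) << "\n";  // 1: e.g. [B, L, H, S] -> [B, H, L, S], fusable
    std::cout << keeps_head_size_dim({0, 3, 1, 2}) << "\n";  // 0: moves head_size, transpose is not fused
    return 0;
}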
std::dynamic_pointer_cast(transpose_node); + output_idx = transpose->get_input_source_output(0).get_index(); - auto tranpose_c = std::dynamic_pointer_cast(pattern_map.at(transpose_k_m).get_node_shared_ptr()); - input_v_output_idx = tranpose_c->get_input_source_output(0).get_index(); - } + return true; + }; + + bool can_fuse_transposes = true; + if (pattern_map.count(transpose_q_m) > 0) + can_fuse_transposes &= process_transpose(pattern_map.at(transpose_q_m).get_node_shared_ptr(), + pattern_map.at(transpose_q_order_m).get_node_shared_ptr(), + order_q, input_q_output_idx); + + if (pattern_map.count(transpose_k_m) > 0) + can_fuse_transposes &= process_transpose(pattern_map.at(transpose_k_m).get_node_shared_ptr(), + pattern_map.at(transpose_k_order_m).get_node_shared_ptr(), + order_k, input_k_output_idx); + + if (pattern_map.count(transpose_v_m) > 0) + can_fuse_transposes &= process_transpose(pattern_map.at(transpose_v_m).get_node_shared_ptr(), + pattern_map.at(transpose_v_order_m).get_node_shared_ptr(), + order_v, input_v_output_idx); + + if (!can_fuse_transposes) + return false; auto input_q = ov::Output(pattern_map.at(input_q_m).get_node_shared_ptr(), input_q_output_idx); auto input_k = ov::Output(pattern_map.at(input_k_m).get_node_shared_ptr(), input_k_output_idx); diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 4c38310fdb8085..0c1041b742c0fb 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -13,6 +13,7 @@ #include #include "intel_gpu/plugin/transformations_pipeline.hpp" +#include "intel_gpu/runtime/debug_configuration.hpp" #include "intel_gpu/runtime/itt.hpp" #include "low_precision/convolution.hpp" #include "low_precision/convolution_backprop_data.hpp" @@ -307,6 +308,10 @@ void TransformationsPipeline::apply(std::shared_ptr func) { manager.register_pass(); pass_config->set_callback([&](const std::shared_ptr node){ + GPU_DEBUG_IF(cldnn::debug_configuration::get_instance()->enable_sdpa != -1) { + GPU_DEBUG_CODE(return cldnn::debug_configuration::get_instance()->enable_sdpa == 1); + } + if (!config.get_property(ov::intel_gpu::hint::enable_sdpa_optimization)) return false; @@ -316,10 +321,6 @@ void TransformationsPipeline::apply(std::shared_ptr func) { const auto& value_ps = sdpa->get_input_partial_shape(2); // Known limitations: - // - SDPA impl could be slower in non-LLM scenarios than decomposed version - if (func->get_variables().size() == 0) - return false; - // - The data type of SDPA should be fp16 if (sdpa->get_output_element_type(0) != ov::element::f16) return false; @@ -342,7 +343,7 @@ void TransformationsPipeline::apply(std::shared_ptr func) { // - The head size should be divisible by 16 const auto optimal_subgroup_size = 16; if (query_ps[query_ps.size() - 1].is_dynamic() || - query_ps[query_ps.size() - 1].get_length() > 256 || + query_ps[query_ps.size() - 1].get_length() != 128 || query_ps[query_ps.size() - 1].get_length() % optimal_subgroup_size != 0) { return false; } diff --git a/src/plugins/intel_gpu/src/runtime/debug_configuration.cpp b/src/plugins/intel_gpu/src/runtime/debug_configuration.cpp index d1130bd2d75694..fc9c16014707f3 100644 --- a/src/plugins/intel_gpu/src/runtime/debug_configuration.cpp +++ b/src/plugins/intel_gpu/src/runtime/debug_configuration.cpp @@ -171,6 +171,8 @@ static void print_help_messages() { message_list.emplace_back("OV_GPU_DisableDynamicImpl", "Disable 
dynamic implementation"); message_list.emplace_back("OV_GPU_DisableRuntimeBufferFusing", "Disable runtime buffer fusing"); message_list.emplace_back("OV_GPU_DisableMemoryReuse", "Disable memory reuse"); + message_list.emplace_back("OV_GPU_EnableSDPA", "This allows the enforcement of SDPA decomposition logic: 0 completely disables SDPA kernel usage, " + "and 1 enables it for all the cases."); message_list.emplace_back("OV_GPU_DumpMemoryPool", "Dump memory pool contents of each iteration"); message_list.emplace_back("OV_GPU_DumpMemoryPoolIters", "List of iterations to dump memory pool status, separated by space."); message_list.emplace_back("OV_GPU_DumpMemoryPoolPath", "Enable dumping memory pool status to csv file and set the dest path"); @@ -232,6 +234,7 @@ debug_configuration::debug_configuration() , serialize_compile(0) , max_kernels_per_batch(0) , impls_cache_capacity(-1) + , enable_sdpa(-1) , disable_async_compilation(0) , disable_winograd_conv(0) , disable_dynamic_impl(0) @@ -280,6 +283,7 @@ debug_configuration::debug_configuration() get_gpu_debug_env_var("ForceImplTypes", forced_impl_types_str); get_gpu_debug_env_var("MaxKernelsPerBatch", max_kernels_per_batch); get_gpu_debug_env_var("ImplsCacheCapacity", impls_cache_capacity); + get_gpu_debug_env_var("EnableSDPA", enable_sdpa); get_gpu_debug_env_var("DisableAsyncCompilation", disable_async_compilation); get_gpu_debug_env_var("DisableWinogradConv", disable_winograd_conv); get_gpu_debug_env_var("DisableDynamicImpl", disable_dynamic_impl); diff --git a/src/plugins/intel_gpu/src/runtime/execution_config.cpp b/src/plugins/intel_gpu/src/runtime/execution_config.cpp index 66b8d3e70cab1f..b0edfe39c90181 100644 --- a/src/plugins/intel_gpu/src/runtime/execution_config.cpp +++ b/src/plugins/intel_gpu/src/runtime/execution_config.cpp @@ -50,7 +50,7 @@ void ExecutionConfig::set_default() { std::make_tuple(ov::intel_gpu::hint::host_task_priority, ov::hint::Priority::MEDIUM), std::make_tuple(ov::intel_gpu::hint::queue_throttle, ov::intel_gpu::hint::ThrottleLevel::MEDIUM), std::make_tuple(ov::intel_gpu::hint::queue_priority, ov::hint::Priority::MEDIUM), - std::make_tuple(ov::intel_gpu::hint::enable_sdpa_optimization, false), + std::make_tuple(ov::intel_gpu::hint::enable_sdpa_optimization, true), std::make_tuple(ov::intel_gpu::enable_loop_unrolling, true), std::make_tuple(ov::intel_gpu::disable_winograd_convolution, false), std::make_tuple(ov::internal::exclusive_async_requests, false), diff --git a/src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp b/src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp index 3b97cde5cfe636..15203e9c5f26bf 100644 --- a/src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp +++ b/src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp @@ -106,21 +106,9 @@ void ScaledAttnLayerGPUTest::SetUp() { } } - // Add artificial read/value operations to the model to trigger the enabling of the SDPA operation - auto read_key = std::make_shared(inputParams.at(1), "v0"); - auto assign_key = std::make_shared(read_key, "v0"); - - auto read_value = std::make_shared(inputParams.at(2), "v0"); - auto assign_value = std::make_shared(read_value, "v0"); - ov::OutputVector inputs; for (size_t i = 0; i < inputParams.size(); i++) { - if (i == 1) - inputs.push_back(read_key); - else if (i == 2) - inputs.push_back(read_value); - else - 
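// Hedged sketch (illustration only, not the plugin code): how the decomposition callback above
// resolves whether the fused SDPA kernel is kept. With GPU debug caps enabled, OV_GPU_EnableSDPA
// overrides everything (0 forces decomposition, 1 forces the fused kernel); at its default of -1,
// the ov::intel_gpu::hint::enable_sdpa_optimization property (now true by default) and the
// per-node checks (f16 output, static head size equal to 128 and divisible by 16) decide.
// The helper name and the model_checks_pass flag are hypothetical.
#include <iostream>

static bool keep_fused_sdpa(int enable_sdpa_env, bool enable_sdpa_property, bool model_checks_pass) {
    if (enable_sdpa_env != -1)      // OV_GPU_EnableSDPA=0/1 wins unconditionally
        return enable_sdpa_env == 1;
    if (!enable_sdpa_property)      // ov::intel_gpu::hint::enable_sdpa_optimization
        return false;
    return model_checks_pass;       // f16 output, head size == 128, divisible by 16, ...
}

int main() {
    std::cout << keep_fused_sdpa(-1, true,  true)  << "\n";  // 1: default path, all checks passed
    std::cout << keep_fused_sdpa( 0, true,  true)  << "\n";  // 0: forced decomposition via env var
    std::cout << keep_fused_sdpa( 1, false, false) << "\n";  // 1: fused kernel forced for all cases
    return 0;
}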
inputs.push_back(inputParams[i]); + inputs.push_back(inputParams[i]); } auto sdp = std::make_shared(inputs, is_causal); @@ -128,7 +116,7 @@ void ScaledAttnLayerGPUTest::SetUp() { auto output = std::make_shared(sdp->output(0)); - function = std::make_shared(ov::OutputVector{output}, ov::SinkVector{assign_key, assign_value}, inputParams, "sdpa_model"); + function = std::make_shared(ov::OutputVector{output}, inputParams, "sdpa_model"); functionRefs = function->clone(); ov::pass::Manager manager; @@ -137,11 +125,8 @@ void ScaledAttnLayerGPUTest::SetUp() { manager.register_pass(); manager.run_passes(functionRefs); - // Enable SDPA - configuration.insert(ov::intel_gpu::hint::enable_sdpa_optimization(true)); - auto it = std::find_if(inputShapes[1].second.begin(), inputShapes[1].second.end(), [&](const ov::Shape& shape){ - return shape[2] >= 384; + return shape[2] >= 384 || shape[3] >= 128; }); bool has_long_seq = it != inputShapes[1].second.end(); @@ -190,12 +175,12 @@ const std::vector> shapes{ // normal case, shapes of q,k,v are same { // q shape - {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64}, - {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}} + {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 128}, + {ov::Shape{1, 8, 100, 128}, ov::Shape{1, 8, 1, 128}, ov::Shape{2, 8, 10, 128}}} }, // kv shape - {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64}, - {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}} + {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 128}, + {ov::Shape{1, 8, 100, 128}, ov::Shape{1, 8, 1, 128}, ov::Shape{2, 8, 10, 128}}} }, // attn shape: [B, 1, -1, L0+L1] {ov::test::InputShape{ov::PartialShape{-1, 1, -1, -1}, @@ -204,12 +189,12 @@ const std::vector> shapes{ }, { // q shape - {ov::test::InputShape{ov::PartialShape{-1, 5, -1, 64}, - {ov::Shape{2, 5, 100, 64}, ov::Shape{2, 5, 1, 64}, ov::Shape{2, 5, 384, 64}}} + {ov::test::InputShape{ov::PartialShape{-1, 5, -1, 128}, + {ov::Shape{2, 5, 100, 128}, ov::Shape{2, 5, 1, 128}, ov::Shape{2, 5, 384, 128}}} }, // kv shape - {ov::test::InputShape{ov::PartialShape{-1, 5, -1, 64}, - {ov::Shape{2, 5, 100, 64}, ov::Shape{2, 5, 1, 64}, ov::Shape{2, 5, 384, 64}}} + {ov::test::InputShape{ov::PartialShape{-1, 5, -1, 128}, + {ov::Shape{2, 5, 100, 128}, ov::Shape{2, 5, 1, 128}, ov::Shape{2, 5, 384, 128}}} }, // attn shape: [B, 1, -1, L0+L1] {ov::test::InputShape{ov::PartialShape{-1, 1, -1, -1}, @@ -219,12 +204,12 @@ const std::vector> shapes{ // heads number of kv is 1, attn mask: [B, H, L1, L0+L1] { // q shape - {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64}, - {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}} + {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 128}, + {ov::Shape{1, 8, 100, 128}, ov::Shape{1, 8, 1, 128}, ov::Shape{2, 8, 10, 128}}} }, // kv shape - {ov::test::InputShape{ov::PartialShape{-1, 1, -1, 64}, - {ov::Shape{1, 1, 100, 64}, ov::Shape{1, 1, 1, 64}, ov::Shape{2, 1, 10, 64}}} + {ov::test::InputShape{ov::PartialShape{-1, 1, -1, 128}, + {ov::Shape{1, 1, 100, 128}, ov::Shape{1, 1, 1, 128}, ov::Shape{2, 1, 10, 128}}} }, // attn shape {ov::test::InputShape{ov::PartialShape{-1, 8, -1, -1},