diff --git a/CITATION.cff b/CITATION.cff
index 1598115c9a1..9a699952093 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -18,4 +18,4 @@ abstract: >-
   oneDNN has experimental support for the following architectures: NVIDIA GPU,
   AMD GPU, OpenPOWER Power ISA (PPC64), IBMz (s390x), and RISC-V.
 license: Apache-2.0
-version: v3.6
+version: v3.7
diff --git a/CMakeLists.txt b/CMakeLists.txt
index af2522a0721..60dd94ed348 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -72,7 +72,7 @@ endif()
 
 set(PROJECT_NAME "oneDNN")
 set(PROJECT_FULL_NAME "oneAPI Deep Neural Network Library (oneDNN)")
-set(PROJECT_VERSION "3.6.0")
+set(PROJECT_VERSION "3.7.0")
 
 if (CMAKE_VERSION VERSION_LESS 3.0)
     project(${PROJECT_NAME} C CXX)
diff --git a/cmake/TBB.cmake b/cmake/TBB.cmake
index d6bbe3e8017..7c82c428b41 100644
--- a/cmake/TBB.cmake
+++ b/cmake/TBB.cmake
@@ -1,5 +1,5 @@
 #===============================================================================
-# Copyright 2018-2022 Intel Corporation
+# Copyright 2018-2024 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,7 +27,7 @@ macro(handle_tbb_target)
     if(TBB_FOUND)
         set_property(TARGET TBB::tbb PROPERTY "MAP_IMPORTED_CONFIG_RELWITHMDD" "DEBUG")
         include_directories_with_host_compiler(${_tbb_include_dirs})
-        list(APPEND EXTRA_SHARED_LIBS ${TBB_IMPORTED_TARGETS})
+        list(APPEND EXTRA_SHARED_LIBS TBB::tbb)
 
         # Print TBB location
         get_filename_component(_tbb_root "${_tbb_include_dirs}" PATH)
diff --git a/doc/graph/images/sdpa.png b/doc/graph/images/sdpa.png
index bd961a5b582..87f4443bf49 100644
Binary files a/doc/graph/images/sdpa.png and b/doc/graph/images/sdpa.png differ
diff --git a/doc/graph/sdpa.md b/doc/graph/sdpa.md
index 3024c624916..1b0864a5c76 100644
--- a/doc/graph/sdpa.md
+++ b/doc/graph/sdpa.md
@@ -40,11 +40,11 @@ optional.
 1. The first MatMul calculates the dot products between Query and Key. See
    [MatMul](@ref dev_guide_op_matmul) operation in Graph API.
-2. The Scale node scales the output of the first MatMul with a scaling factor.
-   It can be constructed by [Multiply](@ref dev_guide_op_multiply) or
-   [Divide](@ref dev_guide_op_divide) operation in Graph API. The scaling factor
-   is given by users as an input of SDPA. \f$\sqrt{d_k}\f$ in the formula is not
-   considered as part of the SDPA pattern as it is constant.
+2. The Scale node is optional and is used to scale the output of the first
+   MatMul with a scaling factor. It can be constructed by [Multiply](@ref dev_guide_op_multiply)
+   or [Divide](@ref dev_guide_op_divide) operation in Graph API. The scaling
+   factor is given by users as an input of SDPA. \f$\sqrt{d_k}\f$ in the formula
+   is not considered as part of the SDPA pattern as it is constant.
 3. The Mask node is optional and is used to apply an attention mask to the
    output of the previous Scale node.
   It can be constructed by [Add](@ref dev_guide_op_add) or
   [Select](@ref dev_guide_op_select) operation in Graph API for different
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index f9b138c2800..2d848af454a 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -89,6 +89,7 @@ if(NOT ONEDNN_BUILD_GRAPH)
             ${CMAKE_CURRENT_SOURCE_DIR}/graph/sdpa.cpp
             ${CMAKE_CURRENT_SOURCE_DIR}/graph/mqa.cpp
             ${CMAKE_CURRENT_SOURCE_DIR}/graph/sdpa_stacked_qkv.cpp
+            ${CMAKE_CURRENT_SOURCE_DIR}/graph/gqa.cpp
             )
 endif()
diff --git a/examples/graph/gqa.cpp b/examples/graph/gqa.cpp
new file mode 100644
index 00000000000..3c60bfd1d9d
--- /dev/null
+++ b/examples/graph/gqa.cpp
@@ -0,0 +1,351 @@
+/*******************************************************************************
+* Copyright 2024 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <cstring>
+#include <iomanip>
+#include <iostream>
+#include <random>
+#include <vector>
+
+#include "oneapi/dnnl/dnnl.hpp"
+#include "oneapi/dnnl/dnnl_graph.hpp"
+
+#include "graph_example_utils.hpp"
+
+using namespace dnnl;
+using tag = memory::format_tag;
+
+using namespace dnnl::graph;
+using layout_type = logical_tensor::layout_type;
+using dim = logical_tensor::dim;
+using dims = logical_tensor::dims;
+
+struct gqa_dims_t {
+    dim mb;
+    dim seq_len;
+    dim q_head_num;
+    dim kv_head_num;
+    dim head_size;
+};
+
+static const int min_runs = 4;
+
+// this is changed from the fill_random() function in matmul_perf.cpp.
+void fill_random(std::vector<float> &out) {
+    static std::vector<float> random_data_f;
+    constexpr size_t nrand = 1037;
+
+    if (random_data_f.empty()) {
+        std::mt19937 generator;
+        std::uniform_real_distribution<float> dist_f(-1.0f, 1.0f);
+
+        random_data_f.resize(nrand);
+        for (auto &d : random_data_f)
+            d = dist_f(generator);
+    }
+
+    for (size_t i = 0; i < out.size(); i += nrand) {
+        size_t chunk = std::min(nrand, out.size() - i);
+        std::memcpy(&out[i], random_data_f.data(), chunk * sizeof(float));
+    }
+}
+
+// initialize the mask with first 3/4 elements with 0s and the last 1/4 elements
+// with -inf.
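+// Adding -inf to the masked positions before SoftMax drives their attention
+// probabilities to zero, mimicking a simple fixed padding mask.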
+void fill_mask(std::vector<float> &mask, size_t seq_len) {
+    const size_t pos = seq_len * 3 / 4;
+    for (size_t i = 0; i < mask.size(); ++i) {
+        if (i % seq_len < pos)
+            mask[i] = 0.f;
+        else
+            mask[i] = -1 * std::numeric_limits<float>::infinity();
+    }
+}
+
+const char *get_type_string(logical_tensor::data_type dt) {
+    const char *type_string = "unknown";
+
+#define TYPE_CASE(T) \
+    if (dt == logical_tensor::data_type::T) type_string = #T;
+    TYPE_CASE(f16);
+    TYPE_CASE(f32);
+    TYPE_CASE(bf16);
+#undef TYPE_CASE
+
+    return type_string;
+}
+
+void print_test_case(logical_tensor::data_type dt, const gqa_dims_t &p) {
+    std::cout << '[' << std::setw(4) << get_type_string(dt);
+    std::cout << " mb = " << p.mb << ", seq_len = " << p.seq_len
+              << ", q_head_num = " << p.q_head_num
+              << ", kv_head_num = " << p.kv_head_num
+              << ", head_size = " << p.head_size;
+    std::cout << "] " << std::flush;
+}
+
+void bench_gqa(engine::kind ekind, logical_tensor::data_type dt,
+        const gqa_dims_t &p, double time_limit = 0.) {
+    const bool quick_test = (time_limit == 0.);
+    print_test_case(dt, p);
+
+    allocator alloc = create_allocator(ekind);
+
+    // Create execution dnnl::engine.
+    dnnl::engine eng = make_engine_with_allocator(ekind, 0, alloc);
+    // Create dnnl::stream.
+    dnnl::stream strm(eng);
+    dnnl_dim_t head_rep = p.q_head_num / p.kv_head_num;
+    // Prepare input and output shapes to construct the gqa graph.
+    const dims q_sz = {p.mb, p.q_head_num, p.seq_len, p.head_size};
+    const dims q_sz_reshape
+            = {p.mb, p.kv_head_num, head_rep, p.seq_len, p.head_size};
+    const dims kv_sz = {p.mb, p.kv_head_num, p.seq_len, p.head_size};
+    const dims kv_sz_reshape = {p.mb, p.kv_head_num, 1, p.seq_len, p.head_size};
+    const dims score_sz = {p.mb, p.kv_head_num, head_rep, p.seq_len, p.seq_len};
+    const dims scale_sz = {1};
+    const dims mask_sz = {p.mb, 1, 1, p.seq_len};
+    const dims mask_sz_reshape = {p.mb, 1, 1, 1, p.seq_len};
+
+    // Incremental IDs used to create logical tensors and operations.
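+    // Every logical tensor and op added to a graph must carry a unique ID.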
+    size_t id = 0;
+
+    // score = query x key.T
+    auto query = logical_tensor(id++, dt, q_sz, layout_type::strided);
+    auto query_reshape
+            = logical_tensor(id++, dt, q_sz_reshape, layout_type::strided);
+    auto key = logical_tensor(id++, dt, kv_sz, layout_type::strided);
+    auto key_reshape
+            = logical_tensor(id++, dt, kv_sz_reshape, layout_type::strided);
+    auto score = logical_tensor(id++, dt, score_sz, layout_type::strided);
+
+    auto reshape1 = op(id++, op::kind::StaticReshape, "reshape1");
+    reshape1.set_attr<dims>(op::attr::shape, q_sz_reshape);
+    reshape1.set_attr<bool>(op::attr::special_zero, false);
+    reshape1.add_inputs({query});
+    reshape1.add_outputs({query_reshape});
+
+    auto reshape2 = op(id++, op::kind::StaticReshape, "reshape2");
+    reshape2.set_attr<dims>(op::attr::shape, kv_sz_reshape);
+    reshape2.set_attr<bool>(op::attr::special_zero, false);
+    reshape2.add_inputs({key});
+    reshape2.add_outputs({key_reshape});
+
+    auto bmm1 = op(id++, op::kind::MatMul, "bmm1");
+    bmm1.set_attr<bool>(op::attr::transpose_b, true);
+    bmm1.add_inputs({query_reshape, key_reshape});
+    bmm1.add_outputs({score});
+
+    // scaled_score = score / scale
+    auto scale = logical_tensor(id++, dt, scale_sz, layout_type::strided);
+    auto scaled_score
+            = logical_tensor(id++, dt, score_sz, layout_type::strided);
+    auto scale_div = op(id++, op::kind::Divide, "scale_div");
+    scale_div.add_inputs({score, scale});
+    scale_div.add_outputs({scaled_score});
+
+    // masked_score = scaled_score + mask
+    auto mask = logical_tensor(id++, dt, mask_sz, layout_type::strided);
+    auto mask_reshape
+            = logical_tensor(id++, dt, mask_sz_reshape, layout_type::strided);
+    auto reshape3 = op(id++, op::kind::StaticReshape, "reshape3");
+    reshape3.set_attr<dims>(op::attr::shape, mask_sz_reshape);
+    reshape3.set_attr<bool>(op::attr::special_zero, false);
+    reshape3.add_inputs({mask});
+    reshape3.add_outputs({mask_reshape});
+
+    auto masked_score
+            = logical_tensor(id++, dt, score_sz, layout_type::strided);
+    auto mask_add = op(id++, op::kind::Add, "mask_add");
+    mask_add.add_inputs({scaled_score, mask_reshape});
+    mask_add.add_outputs({masked_score});
+
+    // attention_probs = softmax(masked_score)
+    auto probs = logical_tensor(id++, dt, score_sz, layout_type::strided);
+    auto softmax = op(id++, op::kind::SoftMax, "softmax");
+    softmax.set_attr<int64_t>(op::attr::axis, -1);
+    softmax.add_inputs({masked_score});
+    softmax.add_outputs({probs});
+
+    // attention_output = attention_probs x value
+    auto value = logical_tensor(id++, dt, kv_sz, layout_type::strided);
+    auto value_reshape
+            = logical_tensor(id++, dt, kv_sz_reshape, layout_type::strided);
+
+    auto output_reshape
+            = logical_tensor(id++, dt, q_sz_reshape, layout_type::strided);
+
+    auto reshape4 = op(id++, op::kind::StaticReshape, "reshape4");
+    reshape4.set_attr<dims>(op::attr::shape, kv_sz_reshape);
+    reshape4.set_attr<bool>(op::attr::special_zero, false);
+    reshape4.add_inputs({value});
+    reshape4.add_outputs({value_reshape});
+
+    auto bmm2 = op(id++, op::kind::MatMul, "bmm2");
+    bmm2.add_inputs({probs, value_reshape});
+    bmm2.add_outputs({output_reshape});
+
+    auto output = logical_tensor(id++, dt, q_sz, layout_type::strided);
+    auto reshape5 = op(id++, op::kind::StaticReshape, "reshape5");
+    reshape5.set_attr<dims>(op::attr::shape, q_sz);
+    reshape5.set_attr<bool>(op::attr::special_zero, false);
+    reshape5.add_inputs({output_reshape});
+    reshape5.add_outputs({output});
+
+    // Construct a gqa graph with engine kind and operations.
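+    // finalize() marks the graph as complete: no more ops can be added, and
+    // the graph becomes ready for partitioning.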
+    dnnl::graph::graph gqa(ekind);
+    gqa.add_op(reshape1);
+    gqa.add_op(reshape2);
+    gqa.add_op(bmm1);
+    gqa.add_op(scale_div);
+    gqa.add_op(reshape3);
+    gqa.add_op(mask_add);
+    gqa.add_op(softmax);
+    gqa.add_op(reshape4);
+    gqa.add_op(bmm2);
+    gqa.add_op(reshape5);
+    gqa.finalize();
+
+    // Get partitions from the gqa graph.
+    std::vector<partition> partitions = gqa.get_partitions();
+    // This is just for oneDNN testing purpose.
+    if (partitions.size() != 1) {
+        std::cout << "unsupported gqa" << std::endl;
+        return;
+    }
+
+    // Compile the partition with inputs, outputs, and an engine.
+    compiled_partition cp = partitions[0].compile(
+            {query, key, scale, mask, value}, {output}, eng);
+
+    // Create tensor objects
+    auto ts_query = tensor(query, eng);
+    auto ts_key = tensor(key, eng);
+    auto ts_scale = tensor(scale, eng);
+    auto ts_mask = tensor(mask, eng);
+    auto ts_value = tensor(value, eng);
+    auto ts_output = tensor(output, eng);
+
+    // Allocate user data.
+    std::vector<float> query_data(product(q_sz));
+    std::vector<float> key_data(product(kv_sz));
+    std::vector<float> scale_data(product(scale_sz), std::sqrt(p.head_size));
+    std::vector<float> mask_data(product(mask_sz));
+    std::vector<float> value_data(product(kv_sz));
+    std::vector<float> output_data(product(q_sz));
+
+    fill_random(query_data);
+    fill_random(key_data);
+    fill_random(value_data);
+    fill_mask(mask_data, static_cast<size_t>(p.seq_len));
+
+    // Write data to tensor object's handle.
+    write_to_dnnl_tensor(query_data.data(), ts_query);
+    write_to_dnnl_tensor(key_data.data(), ts_key);
+    write_to_dnnl_tensor(scale_data.data(), ts_scale);
+    write_to_dnnl_tensor(mask_data.data(), ts_mask);
+    write_to_dnnl_tensor(value_data.data(), ts_value);
+
+    // Warmup run.
+    // Execute the compiled partition of gqa.
+    cp.execute(
+            strm, {ts_query, ts_key, ts_scale, ts_mask, ts_value}, {ts_output});
+
+    // Wait for the computation to finish.
+    strm.wait();
+
+    // First run.
+    auto start_first = std::chrono::steady_clock::now();
+    cp.execute(
+            strm, {ts_query, ts_key, ts_scale, ts_mask, ts_value}, {ts_output});
+    strm.wait();
+    auto end_first = std::chrono::steady_clock::now();
+    std::chrono::duration<double, std::milli> dur_first
+            = end_first - start_first;
+
+    if (quick_test) return;
+
+    // Timing runs.
+    const int runs = std::max(min_runs, int(time_limit / dur_first.count()));
+    auto start = std::chrono::steady_clock::now();
+    for (int i = 0; i <= runs; i++)
+        cp.execute(strm, {ts_query, ts_key, ts_scale, ts_mask, ts_value},
+                {ts_output});
+    strm.wait();
+    auto end = std::chrono::steady_clock::now();
+    std::chrono::duration<double, std::milli> duration = end - start;
+
+    // Display the results.
+    double avg_time = (duration.count() - dur_first.count()) / runs;
+    std::cout << "graph runs: " << runs + 1 << "; ";
+    std::cout << "avg_time: " << avg_time << " ms" << std::endl;
+}
+
+void bad_args() {
+    std::cerr << "Usage: graph-gqa-cpp [cpu|gpu]\n"
+                 "       graph-gqa-cpp [cpu|gpu] <mb> <seq_len> <q_head_num> "
+                 "<kv_head_num> <head_size>\n\n";
+    throw std::invalid_argument("Incorrect input arguments.");
+}
+
+void bench(engine::kind ekind, dnnl_data_type_t dt, const gqa_dims_t &p,
+        double time_limit = 0.) {
+    try {
+        bench_gqa(ekind, static_cast<logical_tensor::data_type>(dt), p,
+                time_limit);
+        get_mem_pool().clear();
+    } catch (dnnl::error &e) {
+        // Catch and report unimplemented cases.
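+        // Only dnnl_unimplemented is treated as a benign skip; any other
+        // error signals a real failure and is rethrown.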
+        if (e.status == dnnl_unimplemented) {
+            std::cout << "unsupported gqa" << std::endl;
+        } else
+            throw;
+    }
+}
+
+void gqa_perf(engine::kind ekind, int argc, char **argv) {
+    // default testing parameters
+    gqa_dims_t params = {32, 384, 16, 2, 64};
+
+    if (argc > 2) {
+        if (argc == 7) {
+            params.mb = std::atoi(argv[2]);
+            params.seq_len = std::atoi(argv[3]);
+            params.q_head_num = std::atoi(argv[4]);
+            params.kv_head_num = std::atoi(argv[5]);
+            params.head_size = std::atoi(argv[6]);
+        } else {
+            bad_args();
+        }
+
+        if (params.mb <= 0 || params.seq_len <= 0 || params.kv_head_num <= 0
+                || params.q_head_num <= 0 || params.head_size <= 0) {
+            bad_args();
+        }
+    }
+
+    bench(ekind, dnnl_f32, params, 2000.0 /*ms*/);
+    bench(ekind, dnnl_bf16, params, 2000.0 /*ms*/);
+    bench(ekind, dnnl_f16, params, 2000.0 /*ms*/);
+}
+
+int main(int argc, char **argv) {
+    return handle_example_errors(
+            gqa_perf, parse_engine_kind(argc, argv, 5), argc, argv);
+}
diff --git a/src/common/math_utils.hpp b/src/common/math_utils.hpp
index 80a05279556..0c156dff8db 100644
--- a/src/common/math_utils.hpp
+++ b/src/common/math_utils.hpp
@@ -539,9 +539,13 @@ inline float stochastic_round_fwd(
 
     // TODO: NaN handling when dst_dt has no NaN
     if (std::isnan(s)) return s;
+    if (dst_dt == data_type::undef) return NAN;
 
     using namespace dnnl::impl::types;
-    assert(digits(data_type::f32) >= digits(dst_dt));
+    if (digits(data_type::f32) < digits(dst_dt)) {
+        assert(!"dst_dt is a bad data type");
+        return NAN;
+    }
 
     uint32_t truncation_mask = 0xffffffff
             << (digits(data_type::f32) - digits(dst_dt));
diff --git a/src/common/primitive_attr.hpp b/src/common/primitive_attr.hpp
index d3f3edba48e..5e1496978ed 100644
--- a/src/common/primitive_attr.hpp
+++ b/src/common/primitive_attr.hpp
@@ -509,7 +509,7 @@ struct rnd_mode_t : public c_compatible {
     bool operator==(const rnd_mode_t &rhs) const {
         bool res = rounding_modes_map_.size() == rhs.rounding_modes_map_.size();
         if (!res) return false;
-        for (auto e : rounding_modes_map_)
+        for (const auto &e : rounding_modes_map_)
             if (e.second != rhs.get(e.first)) return false;
         return true;
     }
diff --git a/src/common/sdpa_types.hpp b/src/common/sdpa_types.hpp
index 84ee79bcef5..03fc9f67aaa 100644
--- a/src/common/sdpa_types.hpp
+++ b/src/common/sdpa_types.hpp
@@ -38,6 +38,7 @@ struct sdpa_desc_t {
     // invert_scale = false: multiply by scale
     // invert_scale = true: divide by scale
     bool invert_scale;
+    dim_t kv_head_number;
 
     // Number of queries.
dnnl_dim_t queries() const { return q_desc.dims[q_desc.ndims - 2]; } diff --git a/src/common/sdpa_utils.hpp b/src/common/sdpa_utils.hpp index 8dd166d6ed7..ccba17a2081 100644 --- a/src/common/sdpa_utils.hpp +++ b/src/common/sdpa_utils.hpp @@ -31,7 +31,7 @@ namespace impl { static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md, const memory_desc_t *k_md, const memory_desc_t *v_md, const memory_desc_t *dst_md, const memory_desc_t *attn_mask_md, - data_type_t scale_dt, bool invert_scale = false) { + data_type_t scale_dt, dim_t kv_head_number, bool invert_scale = false) { auto sdpa_desc = sdpa_desc_t(); sdpa_desc.primitive_kind = primitive_kind::sdpa; sdpa_desc.q_desc = *q_md; @@ -41,6 +41,7 @@ static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md, if (attn_mask_md) sdpa_desc.attn_mask_desc = *attn_mask_md; sdpa_desc.scale_dt = scale_dt; sdpa_desc.invert_scale = invert_scale; + sdpa_desc.kv_head_number = kv_head_number; return sdpa_desc; } @@ -49,9 +50,9 @@ static inline status_t create_sdpa_pd( const memory_desc_t *q_md, const memory_desc_t *k_md, const memory_desc_t *v_md, const memory_desc_t *dst_md, const memory_desc_t *attn_mask_md, data_type_t scale_dt, - bool invert_scale, const primitive_attr_t *attr) { - auto sdpa_desc = create_sdpa_desc( - q_md, k_md, v_md, dst_md, attn_mask_md, scale_dt, invert_scale); + bool invert_scale, const primitive_attr_t *attr, dim_t kv_head_number) { + auto sdpa_desc = create_sdpa_desc(q_md, k_md, v_md, dst_md, attn_mask_md, + scale_dt, kv_head_number, invert_scale); int ndims = dst_md->ndims; int r = ndims - 2, c = ndims - 1; diff --git a/src/cpu/aarch64/jit_sve_conv_kernel.cpp b/src/cpu/aarch64/jit_sve_conv_kernel.cpp index ae6762d416c..3acec02fe29 100644 --- a/src/cpu/aarch64/jit_sve_conv_kernel.cpp +++ b/src/cpu/aarch64/jit_sve_conv_kernel.cpp @@ -1252,19 +1252,18 @@ void jit_sve_conv_bwd_data_kernel_f32::store_output(int ur_w) { int ofs = aux_output_offset; if ((VL_OFS(ofs, isa) < LDRMAX) && (VL_OFS(ofs, isa) >= (-1 * LDRMAX)) && ((ofs & 0x3f) == 0)) { - ldr(zreg_tmp(idx), - ptr(reg_src, static_cast(VL_OFS(ofs, isa)))); + add_imm(X_DEFAULT_ADDR, reg_src, ofs, X_TMP_0); + ld1w(zreg_tmp(idx).s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); } else { int tmp_ofs = aux_output_offset - prev_ofs; if (((tmp_ofs & 0x3f) == 0) && (VL_OFS(tmp_ofs, isa) < LDRWMAX) && (tmp_ofs >= 0)) { - ldr(zreg_tmp(idx), - ptr(reg_tmp_addr, - static_cast(VL_OFS(tmp_ofs, isa)))); + add_imm(X_DEFAULT_ADDR, reg_tmp_addr, tmp_ofs, X_TMP_0); + ld1w(zreg_tmp(idx).s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); } else { add_imm(reg_tmp_addr, reg_src, ofs, reg_tmp_imm); - ldr(zreg_tmp(idx), ptr(reg_tmp_addr)); + ld1w(zreg_tmp(idx).s, P_ALL_ONE / T_z, ptr(reg_tmp_addr)); prev_ofs = ofs; } } @@ -1276,19 +1275,20 @@ void jit_sve_conv_bwd_data_kernel_f32::store_output(int ur_w) { if ((VL_OFS(ofs, isa) < LDRMAX) && (VL_OFS(ofs, isa) >= (-1 * LDRMAX)) && ((ofs & 0x3f) == 0)) { - str(zreg_out(j, k), - ptr(reg_src, static_cast(VL_OFS(ofs, isa)))); + add_imm(X_DEFAULT_ADDR, reg_src, ofs, X_TMP_0); + st1w(zreg_out(j, k).s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); + } else { int tmp_ofs = aux_output_offset - prev_ofs; if (((tmp_ofs & 0x3f) == 0) && (VL_OFS(tmp_ofs, isa) < LDRWMAX) && (tmp_ofs >= 0)) { - str(zreg_out(j, k), - ptr(reg_tmp_addr, - static_cast(VL_OFS(tmp_ofs, isa)))); + add_imm(X_DEFAULT_ADDR, reg_tmp_addr, tmp_ofs, X_TMP_0); + st1w(zreg_out(j, k).s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); + } else { add_imm(reg_tmp_addr, reg_src, ofs, reg_tmp_imm); - str(zreg_out(j, k), 
ptr(reg_tmp_addr)); + st1w(zreg_out(j, k).s, P_ALL_ONE / T_z, ptr(reg_tmp_addr)); prev_ofs = ofs; } } @@ -1417,11 +1417,12 @@ void jit_sve_conv_bwd_data_kernel_f32::compute_loop_fma( if ((VL_OFS(ofs, isa) < LDRMAX) && (VL_OFS(ofs, isa) >= (-1 * LDRMAX)) && ((ofs & 0x3f) == 0)) { - ldr(zreg_ker(i), - ptr(aux_reg_ker, static_cast(VL_OFS(ofs, isa)))); + add_imm(X_DEFAULT_ADDR, aux_reg_ker, ofs, X_TMP_0); + ld1w(zreg_ker(i).s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); + } else { add_imm(reg_tmp_addr, aux_reg_ker, ofs, reg_tmp_imm); - ldr(zreg_ker(i), ptr(reg_tmp_addr)); + ld1w(zreg_ker(i).s, P_ALL_ONE / T_z, ptr(reg_tmp_addr)); } }; @@ -1694,11 +1695,12 @@ void jit_sve_conv_bwd_data_kernel_f32::compute_loop_fma_core( if ((VL_OFS(ofs, isa) < LDRMAX) && (VL_OFS(ofs, isa) >= (-1 * LDRMAX))) { - ldr(zreg_wei(idx), - ptr(aux_reg_ker, static_cast(VL_OFS(ofs, isa)))); + add_imm(X_DEFAULT_ADDR, aux_reg_ker, ofs, X_TMP_0); + ld1w(zreg_wei(idx).s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); + } else { add_imm(reg_tmp_addr, aux_reg_ker, ofs, reg_tmp_imm); - ldr(zreg_wei(idx), ptr(reg_tmp_addr)); + ld1w(zreg_wei(idx).s, P_ALL_ONE / T_z, ptr(reg_tmp_addr)); } }; @@ -1887,9 +1889,12 @@ void jit_sve_conv_bwd_data_kernel_f32::generate() { * (is_ddst_layout_nxc() ? jcp.ngroups * jcp.oc : oc_block); int src_shift = jcp.typesize_out * ur_w * (is_dsrc_layout_nxc() ? jcp.ngroups * jcp.ic : ic_block); + const int simd_w_ = cpu_isa_traits::vlen / sizeof(float); preamble(); + if (simd_w_ != cpu_sveLen / sizeof(float)) + set_preg(P_ALL_ONE.s, simd_w_, X_TMP_0, X_TMP_1); ldr(reg_src, ptr(param, GET_OFF(src))); ldr(reg_dst, ptr(param, GET_OFF(dst))); ldr(reg_ker, ptr(param, GET_OFF(filt))); @@ -2064,7 +2069,7 @@ template status_t jit_sve_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp, const convolution_desc_t &cd, memory_desc_t &diff_src_md, memory_desc_t &weights_md, memory_desc_t &diff_dst_md, int nthreads) { - if (!mayiuse(sve_512)) return status::unimplemented; + if (!mayiuse(isa)) return status::unimplemented; const memory_desc_wrapper diff_src_d(&diff_src_md); const memory_desc_wrapper weights_d(&weights_md); @@ -2138,14 +2143,14 @@ status_t jit_sve_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp, dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c); bool is_data_layout_nxc = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag); - if (mayiuse(sve_512) && is_data_layout_nxc) return status::unimplemented; + if (mayiuse(isa) && is_data_layout_nxc) return status::unimplemented; jcp.is_1stconv = false; bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1 && diff_src_d.data_type() == data_type::f32; - const int full_simd_w = cpu_isa_traits::vlen / typesize; + const int full_simd_w = cpu_isa_traits::vlen / typesize; jcp.simd_w = full_simd_w; jcp.oc_block = jcp.simd_w; @@ -2166,10 +2171,10 @@ status_t jit_sve_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp, const auto nxc_tag = pick(ndims - 3, nwc, nhwc, ndhwc); if (jcp.simd_w == 8) { - assert(with_groups); - dat_tag = is_data_layout_nxc ? nxc_tag + dat_tag = is_data_layout_nxc ? pick(ndims - 3, nwc, nhwc, ndhwc) : pick(ndims - 3, nCw8c, nChw8c, nCdhw8c); - wei_tag = pick(ndims - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i); + wei_tag = pick(2 * ndims - 6 + with_groups, OIw8o8i, gOIw8o8i, OIhw8o8i, + gOIhw8o8i, OIdhw8o8i, gOIdhw8o8i); } else if (jcp.simd_w == 4) { assert(with_groups); dat_tag = is_data_layout_nxc ? 
nxc_tag @@ -2223,7 +2228,7 @@ status_t jit_sve_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp, int n_oi = jcp.iw / jcp.ur_w; if (r_overflow_no_tail > 0) n_oi--; - if (mayiuse(sve_512) && diff_dst_d.data_type() == data_type::f32 + if (mayiuse(isa) && diff_dst_d.data_type() == data_type::f32 && weights_d.data_type() == data_type::f32 && diff_src_d.data_type() == data_type::f32) { jcp.ver = ver_fma; @@ -2270,7 +2275,7 @@ status_t jit_sve_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp, && jcp.stride_w == 1 && utils::everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w); - if (jcp.ver == ver_fma && mayiuse(sve_512)) { + if (jcp.ver == ver_fma && mayiuse(isa)) { int try_nb_ic_blocking = 2; bool use_expl_bcast = !(jcp.kw == 1 || (jcp.kw == 5 && jcp.iw < 8) @@ -2500,18 +2505,19 @@ void jit_sve_conv_bwd_weights_kernel_f32::compute_ic_block_step(int ur_w, int oc_block = jcp.oc_block; auto load_ker = [=](int zreg_idx, int ofs, int pre_offset_ker) { - if (str_imm_check(ofs)) { - ldr(ZReg(zreg_idx), - ptr(reg_kernel, static_cast(VL_OFS(ofs, isa)))); + if (str_imm_check(ofs)) { + add_imm(X_DEFAULT_ADDR, reg_kernel, ofs, X_TMP_0); + ld1w(ZReg(zreg_idx).s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); } else { - if (pre_offset_ker >= 0 && str_imm_check(ofs - pre_offset_ker)) { - ldr(ZReg(zreg_idx), - ptr(reg_pre_addr_ker, - static_cast( - VL_OFS((ofs - pre_offset_ker), isa)))); + if (pre_offset_ker >= 0 + && str_imm_check(ofs - pre_offset_ker)) { + add_imm(X_DEFAULT_ADDR, reg_pre_addr_ker, + (ofs - pre_offset_ker), X_TMP_0); + ld1w(ZReg(zreg_idx).s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); + } else { add_imm(reg_pre_addr_ker, reg_kernel, ofs, reg_tmp_imm); - ldr(ZReg(zreg_idx), ptr(reg_pre_addr_ker)); + ld1w(ZReg(zreg_idx).s, P_ALL_ONE / T_z, ptr(reg_pre_addr_ker)); pre_offset_ker = ofs; } } @@ -2619,18 +2625,19 @@ void jit_sve_conv_bwd_weights_kernel_f32::compute_ic_block_step(int ur_w, int pre_offset_out = -1; auto load_out = [&](int zreg_idx, int ofs) { - if (ldr_imm_check(ofs)) { - ldr(ZReg(zreg_idx), - ptr(reg_output, static_cast(VL_OFS(ofs, isa)))); + if (ldr_imm_check(ofs)) { + add_imm(X_DEFAULT_ADDR, reg_output, ofs, X_TMP_0); + ld1w(ZReg(zreg_idx).s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); } else { - if (pre_offset_out >= 0 && ldr_imm_check(ofs - pre_offset_out)) { - ldr(ZReg(zreg_idx), - ptr(reg_pre_addr_out, - static_cast( - VL_OFS((ofs - pre_offset_out), isa)))); + if (pre_offset_out >= 0 + && ldr_imm_check(ofs - pre_offset_out)) { + add_imm(X_DEFAULT_ADDR, reg_pre_addr_out, + (ofs - pre_offset_out), X_TMP_0); + ld1w(ZReg(zreg_idx).s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); + } else { add_imm(reg_pre_addr_out, reg_output, ofs, reg_tmp_imm); - ldr(ZReg(zreg_idx), ptr(reg_pre_addr_out)); + ld1w(ZReg(zreg_idx).s, P_ALL_ONE / T_z, ptr(reg_pre_addr_out)); pre_offset_out = ofs; } } @@ -2751,18 +2758,18 @@ void jit_sve_conv_bwd_weights_kernel_f32::compute_ic_block_step(int ur_w, } auto store_ker = [=](int zreg_idx, int ofs, int pre_offset_ker) { - if (str_imm_check(ofs)) { - str(ZReg(zreg_idx), - ptr(reg_kernel, static_cast(VL_OFS(ofs, isa)))); + if (str_imm_check(ofs)) { + add_imm(X_DEFAULT_ADDR, reg_kernel, ofs, X_TMP_0); + st1w(ZReg(zreg_idx).s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); } else { - if (pre_offset_ker >= 0 && str_imm_check(ofs - pre_offset_ker)) { - str(ZReg(zreg_idx), - ptr(reg_pre_addr_ker, - static_cast( - VL_OFS((ofs - pre_offset_ker), isa)))); + if (pre_offset_ker >= 0 + && str_imm_check(ofs - pre_offset_ker)) { + add_imm(X_DEFAULT_ADDR, reg_pre_addr_ker, + (ofs - 
pre_offset_ker), X_TMP_0); + st1w(ZReg(zreg_idx).s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); } else { add_imm(reg_pre_addr_ker, reg_kernel, ofs, reg_tmp_imm); - str(ZReg(zreg_idx), ptr(reg_pre_addr_ker)); + st1w(ZReg(zreg_idx).s, P_ALL_ONE / T_z, ptr(reg_pre_addr_ker)); pre_offset_ker = ofs; } } @@ -3339,20 +3346,18 @@ void jit_sve_conv_bwd_weights_kernel_f32::maybe_zero_kernel() { mov(reg_tmp, 0); L(zeroing_loop); { - assert(jcp.oc_block * jcp.typesize_out - == cpu_isa_traits::vlen); + assert(jcp.oc_block * jcp.typesize_out == cpu_isa_traits::vlen); add(reg_ker_start_addr, reg_kernel, reg_tmp); for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { - if (str_imm_check(ic1 * jcp.oc_block * jcp.typesize_out)) { - str(ZReg(0), - ptr(reg_ker_start_addr, - static_cast(VL_OFS( - (ic1 * jcp.oc_block * jcp.typesize_out), - isa)))); + if (str_imm_check( + ic1 * jcp.oc_block * jcp.typesize_out)) { + add_imm(X_DEFAULT_ADDR, reg_ker_start_addr, + (ic1 * jcp.oc_block * jcp.typesize_out), X_TMP_0); + st1w(ZReg(0).s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); } else { add_imm(reg_add_tmp, reg_ker_start_addr, ic1 * jcp.oc_block * jcp.typesize_out, reg_tmp_imm); - str(ZReg(0), ptr(reg_add_tmp)); + st1w(ZReg(0).s, P_ALL_ONE / T_z, ptr(reg_add_tmp)); } } @@ -3383,14 +3388,14 @@ void jit_sve_conv_bwd_weights_kernel_f32::bias_kernel_2d() { tst(reg_tmp, reg_tmp); b(NE, skip_bias); - ldr(ZReg(0), ptr(reg_bias)); + ld1w(ZReg(0).s, P_ALL_ONE / T_z, ptr(reg_bias)); mov_imm(reg_oi, jcp.ow); mov(reg_tmp, 0); L(bias_loop); { add(reg_add_tmp, reg_output, reg_tmp); - ldr(ZReg(1), ptr(reg_add_tmp)); + ld1w(ZReg(1).s, P_ALL_ONE / T_z, ptr(reg_add_tmp)); fadd(ZRegS(0), ZRegS(0), ZRegS(1)); const int oc_stride = is_ddst_layout_nxc() ? jcp.ngroups * jcp.oc : jcp.oc_block; @@ -3398,7 +3403,7 @@ void jit_sve_conv_bwd_weights_kernel_f32::bias_kernel_2d() { subs(reg_oi, reg_oi, 1); b(GT, bias_loop); } - str(ZReg(0), ptr(reg_bias)); + st1w(ZReg(0).s, P_ALL_ONE / T_z, ptr(reg_bias)); L(skip_bias); } @@ -3419,7 +3424,7 @@ void jit_sve_conv_bwd_weights_kernel_f32::bias_kernel_3d() { ldr(reg_tmp, ptr(param, GET_OFF(channel))); cmp(reg_tmp, 0); b(NE, skip_load_bias); - ldr(ZReg(1), ptr(reg_bias)); + ld1w(ZReg(1).s, P_ALL_ONE / T_z, ptr(reg_bias)); L(skip_load_bias); @@ -3437,13 +3442,14 @@ void jit_sve_conv_bwd_weights_kernel_f32::bias_kernel_3d() { L(bias_loop); { add(reg_add_tmp, reg_output, reg_tmp); - ldr(ZReg(0), ptr(reg_add_tmp)); + ld1w(ZReg(0).s, P_ALL_ONE / T_z, ptr(reg_add_tmp)); + fadd(ZRegS(1), ZRegS(1), ZRegS(0)); add_imm(reg_tmp, reg_tmp, oc_mult * jcp.typesize_out, reg_tmp_imm); cmp(reg_tmp, reg_oi); b(LT, bias_loop); } - str(ZReg(1), ptr(reg_bias)); + st1w(ZReg(1).s, P_ALL_ONE / T_z, ptr(reg_bias)); L(skip_bias); } @@ -3451,6 +3457,7 @@ void jit_sve_conv_bwd_weights_kernel_f32::bias_kernel_3d() { template void jit_sve_conv_bwd_weights_kernel_f32::compute_oh_loop_common() { assert(one_of(jcp.harness, harness_mb_reduction, harness_3d_reduction)); + int b_pad = jcp.b_pad; int t_pad = jcp.t_pad; bool is_dilated = jcp.dilate_h != 0; @@ -3686,8 +3693,8 @@ void jit_sve_conv_bwd_weights_kernel_f32::compute_oh_loop_partial() { tst(reg_tmp, reg_tmp); b(NE, skip_zero_bias); eor(ZRegS(1), P_ALL_ONE.b, ZRegS(1)); - str(ZReg(1), - ptr(reg_bias)); //vmovups(ptr[reg_bias], Zmm(1)); + st1w(ZReg(1).s, P_ALL_ONE / T_z, ptr(reg_bias)); + L(skip_zero_bias); } @@ -3945,8 +3952,13 @@ void jit_sve_conv_bwd_weights_kernel_f32::compute_loop() { template void jit_sve_conv_bwd_weights_kernel_f32::generate_kernel() { + const int simd_w_ = cpu_isa_traits::vlen 
/ sizeof(float); + preamble(); + if (simd_w_ != cpu_sveLen / sizeof(float)) + set_preg(P_ALL_ONE.s, simd_w_, X_TMP_0, X_TMP_1); + ldr(reg_input, ptr(param, GET_OFF(src))); ldr(reg_output, ptr(param, GET_OFF(dst))); ldr(reg_kernel, ptr(param, GET_OFF(filt))); @@ -3961,7 +3973,7 @@ status_t jit_sve_conv_bwd_weights_kernel_f32::init_conf( jit_conv_conf_t &jcp, const convolution_desc_t &cd, memory_desc_t &src_md, memory_desc_t &diff_weights_md, memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md, int nthreads) { - if (!mayiuse(sve_512)) return status::unimplemented; + if (!mayiuse(isa)) return status::unimplemented; const memory_desc_wrapper src_d(&src_md); const memory_desc_wrapper diff_weights_d(&diff_weights_md); @@ -3973,7 +3985,7 @@ status_t jit_sve_conv_bwd_weights_kernel_f32::init_conf( jcp = zero(); - jcp.simd_w = cpu_isa_traits::vlen / typesize; + jcp.simd_w = cpu_isa_traits::vlen / typesize; jcp.nthr = jcp.aligned_threads = nthreads; jcp.ndims = ndims; jcp.prop_kind = cd.prop_kind; @@ -4041,14 +4053,15 @@ status_t jit_sve_conv_bwd_weights_kernel_f32::init_conf( const int max_filter_size = 20; const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc); const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw); + const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c); const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); auto curr_src_tag = src_d.matches_one_of_tag( - dat_tag_nxc, dat_tag_nCx16c, dat_tag_ncx); - auto curr_dst_tag - = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); + dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_ncx); + auto curr_dst_tag = diff_dst_d.matches_one_of_tag( + dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c); bool is_data_layout_nxc = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag); - if (mayiuse(sve_512) && is_data_layout_nxc) return status::unimplemented; + if (mayiuse(isa) && is_data_layout_nxc) return status::unimplemented; /* Optimization: when `output-width == 1' deploy a special case of the * JIT-Kernel by unrolling with regards to height instead of width for @@ -4090,10 +4103,22 @@ status_t jit_sve_conv_bwd_weights_kernel_f32::init_conf( jcp.ic_tail = is_data_layout_nxc ? jcp.ic % jcp.simd_w : 0; jcp.oc_tail = is_data_layout_nxc ? jcp.oc % jcp.simd_w : 0; - auto dst_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c; - auto wei_tag = with_groups - ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o) - : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o); + format_tag_t src_tag, dst_tag, wei_tag; + switch (isa) { + case sve_512: + dst_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c; + wei_tag = with_groups + ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o) + : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o); + break; + case sve_256: + dst_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c; + wei_tag = with_groups + ? pick(ndims - 3, gOIw8i8o, gOIhw8i8o, gOIdhw8i8o) + : pick(ndims - 3, OIw8i8o, OIhw8i8o, OIdhw8i8o); + break; + default: return status::unimplemented; + } if (diff_dst_md.format_kind == format_kind::any) { CHECK(memory_desc_init_by_tag(diff_dst_md, dst_tag)); @@ -4134,7 +4159,7 @@ status_t jit_sve_conv_bwd_weights_kernel_f32::init_conf( } if (jcp.is_1stconv) { - auto src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_ncx; + src_tag = is_data_layout_nxc ? 
dat_tag_nxc : dat_tag_ncx; if (src_d.format_kind() == format_kind::any) { CHECK(memory_desc_init_by_tag(src_md, src_tag)); } else { @@ -4156,8 +4181,19 @@ status_t jit_sve_conv_bwd_weights_kernel_f32::init_conf( jcp.ver = ver_fma; jcp.ic_block = jcp.ic; - wei_tag = with_groups ? pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o) - : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o); + switch (isa) { + case sve_512: + wei_tag = with_groups + ? pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o) + : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o); + break; + case sve_256: + wei_tag = with_groups + ? pick(ndims - 3, gOwi8o, gOhwi8o, gOdhwi8o) + : pick(ndims - 3, Owi8o, Ohwi8o, Odhwi8o); + break; + default: return status::unimplemented; + } if (init_tag(jcp.wei_tag, diff_weights_md, diff_weights_d, wei_tag) != status::success) @@ -4165,7 +4201,16 @@ status_t jit_sve_conv_bwd_weights_kernel_f32::init_conf( jcp.nb_ic = div_up(jcp.ic, jcp.ic_block); } else { - auto src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c; + switch (isa) { + case sve_512: + src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c; + break; + case sve_256: + src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c; + break; + default: return status::unimplemented; + } + if (src_md.format_kind == format_kind::any) { CHECK(memory_desc_init_by_tag(src_md, src_tag)); } else if (curr_src_tag != src_tag) @@ -4179,7 +4224,7 @@ status_t jit_sve_conv_bwd_weights_kernel_f32::init_conf( jcp.ic_block = jcp.simd_w; if (ok_to_pad_channels) jcp.ic = rnd_up(jcp.ic, jcp.ic_block); jcp.nb_ic = div_up(jcp.ic, jcp.ic_block); - if (mayiuse(sve_512) + if (mayiuse(isa) && utils::everyone_is(data_type::f32, src_d.data_type(), diff_weights_d.data_type(), diff_dst_d.data_type())) { jcp.ver = ver_fma; @@ -4343,7 +4388,6 @@ void jit_sve_conv_bwd_weights_kernel_f32::balance(const jit_conv_conf_t &j, nthr_ = nthr_g_ = nthreads; return; } - nthr_g_ = j.ngroups; const int nthr = nthreads / nthr_g_; @@ -4416,9 +4460,11 @@ void jit_sve_conv_bwd_weights_kernel_f32::balance(const jit_conv_conf_t &j, template struct jit_sve_conv_fwd_kernel; template struct jit_sve_conv_fwd_kernel; template struct jit_sve_conv_bwd_data_kernel_f32; +template struct jit_sve_conv_bwd_data_kernel_f32; template struct jit_sve_conv_bwd_weights_kernel_f32; +template struct jit_sve_conv_bwd_weights_kernel_f32; } // namespace aarch64 } // namespace cpu } // namespace impl -} // namespace dnnl +} // namespace dnnl \ No newline at end of file diff --git a/src/cpu/aarch64/jit_sve_conv_kernel.hpp b/src/cpu/aarch64/jit_sve_conv_kernel.hpp index ecffaadab4c..fc817d2a1d4 100644 --- a/src/cpu/aarch64/jit_sve_conv_kernel.hpp +++ b/src/cpu/aarch64/jit_sve_conv_kernel.hpp @@ -487,7 +487,7 @@ struct jit_sve_conv_bwd_weights_kernel_f32 : public jit_generator { reg64_t reg_pre_addr_out = x26; reg64_t reg_pre_addr_ker = x26; reg64_t reg_ker_start_addr = x27; - reg64_t reg_addr_diff_input = x28; + reg64_t reg_addr_diff_input = x18; void prefetch( const std::string prfop, int level, reg64_t in, long long int ofs) { @@ -601,4 +601,4 @@ struct jit_sve_conv_bwd_weights_kernel_f32 : public jit_generator { } // namespace impl } // namespace dnnl -#endif +#endif \ No newline at end of file diff --git a/src/cpu/aarch64/jit_sve_convolution.cpp b/src/cpu/aarch64/jit_sve_convolution.cpp index 8c7870e2200..379e42889c2 100644 --- a/src/cpu/aarch64/jit_sve_convolution.cpp +++ b/src/cpu/aarch64/jit_sve_convolution.cpp @@ -1132,13 +1132,14 @@ void jit_sve_convolution_bwd_data_t; +template struct 
jit_sve_convolution_bwd_data_t; template status_t jit_sve_convolution_bwd_weights_t::init(engine_t *engine) { const auto &j = pd()->jcp_; - nthr_ = j.nthr; nthr_mb_ = j.nthr_mb; nthr_g_ = j.nthr_g; @@ -1151,12 +1152,13 @@ status_t jit_sve_convolution_bwd_weights_t 1) { CHECK(safe_ptr_assign( - acc_ker_, new cpu_accumulator_1d_t())); + acc_ker_, new cpu_accumulator_1d_t())); CHECK(acc_ker_->create_kernel()); } CHECK(safe_ptr_assign(reducer_bias_, - new cpu_reducer_t(pd()->reducer_bia_conf_))); + new cpu_reducer_t( + pd()->reducer_bia_conf_))); CHECK(reducer_bias_->create_kernel()); return status::success; } @@ -1514,8 +1516,8 @@ void jit_sve_convolution_bwd_weights_t; +template struct jit_sve_convolution_bwd_weights_t; } // namespace aarch64 } // namespace cpu } // namespace impl } // namespace dnnl -// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s +// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s \ No newline at end of file diff --git a/src/cpu/aarch64/jit_sve_convolution.hpp b/src/cpu/aarch64/jit_sve_convolution.hpp index de441e21e2a..16b397406d7 100644 --- a/src/cpu/aarch64/jit_sve_convolution.hpp +++ b/src/cpu/aarch64/jit_sve_convolution.hpp @@ -118,7 +118,7 @@ struct jit_sve_convolution_bwd_data_t : public primitive_t { const convolution_fwd_pd_t *hint_fwd_pd) : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd), jcp_() {} - DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", sve_512, ""), + DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", isa, ""), jit_sve_convolution_bwd_data_t); status_t init(engine_t *engine) { @@ -188,7 +188,7 @@ struct jit_sve_convolution_bwd_weights_t : public primitive_t { : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) , jcp_() {} - DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", sve_512, ""), + DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", isa, ""), jit_sve_convolution_bwd_weights_t); status_t init(engine_t *engine) { @@ -219,7 +219,8 @@ struct jit_sve_convolution_bwd_weights_t : public primitive_t { } jit_conv_conf_t jcp_; - typename cpu_reducer_t::conf_t reducer_bia_conf_; + typename cpu_reducer_t::conf_t + reducer_bia_conf_; private: void init_balancers() { @@ -262,8 +263,8 @@ struct jit_sve_convolution_bwd_weights_t : public primitive_t { int nthr_, nthr_mb_, nthr_g_, nthr_oc_b_, nthr_ic_b_; jit_sve_conv_bwd_weights_kernel_f32 *kernel_; - cpu_accumulator_1d_t *acc_ker_; - cpu_reducer_t *reducer_bias_; + cpu_accumulator_1d_t *acc_ker_; + cpu_reducer_t *reducer_bias_; }; } // namespace aarch64 @@ -273,4 +274,4 @@ struct jit_sve_convolution_bwd_weights_t : public primitive_t { #endif -// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s +// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s \ No newline at end of file diff --git a/src/cpu/aarch64/jit_uni_reorder.cpp b/src/cpu/aarch64/jit_uni_reorder.cpp index 83149904d7a..6d08a7f55a6 100644 --- a/src/cpu/aarch64/jit_uni_reorder.cpp +++ b/src/cpu/aarch64/jit_uni_reorder.cpp @@ -2180,6 +2180,7 @@ struct jit_single_blk_kernel_t : public jit_generator { // Register allocation xmm0~11 void gen_transpose_8x8() { + const uint64_t sveLen = get_sve_length(); constexpr int lane = 8; #if 0 @@ -2192,12 +2193,12 @@ struct jit_single_blk_kernel_t : public jit_generator { ptrue(P_ALL_ONE.b); ptrue(P_TMP.s, VL8); not_(P_TMP.b, P_ALL_ONE/T_z, P_TMP.b); - index(z0.s, 0, 1); - mov(z0.s, P_TMP/T_m, 0); - mov(z_tmp_vec[0].s, 8); - mov(z_tmp_vec[0].s, P_TMP/T_m, 0); - for(uint32_t i=1; i> &impl_list_map() CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_bwd_data_f32_t) CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_data_t) 
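+        // Note: entries in this list are tried in registration order, so the
+        // new sve_256 kernel below is reached only after the sve_512 and
+        // depthwise implementations have declined the problem.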
CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_data_t) + CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_data_t) CPU_INSTANCE(gemm_convolution_bwd_data_t) CPU_INSTANCE(ref_convolution_bwd_data_t) nullptr, @@ -348,6 +349,7 @@ const std::map> &impl_list_map() CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_bwd_weights_t) CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_weights_t) CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_bwd_weights_t) + CPU_INSTANCE_AARCH64(jit_sve_convolution_bwd_weights_t) CPU_INSTANCE(gemm_convolution_bwd_weights_t) CPU_INSTANCE(ref_convolution_bwd_weights_t) nullptr, @@ -781,4 +783,4 @@ const impl_list_item_t *get_convolution_impl_list( } // namespace cpu } // namespace impl -} // namespace dnnl +} // namespace dnnl \ No newline at end of file diff --git a/src/cpu/reorder/cpu_reorder_regular_s4.cpp b/src/cpu/reorder/cpu_reorder_regular_s4.cpp index 59d7e6edbbd..17bfdba758e 100644 --- a/src/cpu/reorder/cpu_reorder_regular_s4.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_s4.cpp @@ -29,6 +29,7 @@ const impl_list_map_t ®ular_s4_impl_list_map() { nullptr, }}, {{s4, f32, 0}, { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_matrix_B_reorder_t)) REG_SR(s4, any, f32, any, fmt_order::any, spec::reference) nullptr, }}, diff --git a/src/cpu/reorder/cpu_reorder_regular_u4.cpp b/src/cpu/reorder/cpu_reorder_regular_u4.cpp index 08e4784f449..60a85da4a30 100644 --- a/src/cpu/reorder/cpu_reorder_regular_u4.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_u4.cpp @@ -29,6 +29,7 @@ const impl_list_map_t ®ular_u4_impl_list_map() { nullptr, }}, {{u4, f32, 0}, { + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_matrix_B_reorder_t)) REG_SR(u4, any, f32, any, fmt_order::any, spec::reference) nullptr, }}, diff --git a/src/cpu/rnn/ref_rnn.cpp b/src/cpu/rnn/ref_rnn.cpp index 95bb5dd3592..2df2652daea 100644 --- a/src/cpu/rnn/ref_rnn.cpp +++ b/src/cpu/rnn/ref_rnn.cpp @@ -191,19 +191,19 @@ status_t dnnl::impl::cpu::_ref_rnn_common_t(rnn_.n_gates) * rnn_.dhc; - const dim_t N = rnn_.mb; + const dim_t M = rnn_.mb; + const dim_t N = static_cast(rnn_.n_gates) * rnn_.dhc; const dim_t K = rnn_.slc; - const dim_t LDA = rnn_.weights_layer_ld; - const dim_t LDB1 = rnn_.src_layer_ld_; - const dim_t LDB2 = rnn_.ws_states_layer_ld; - const dim_t LDB3 = rnn_.dst_iter_ld_; + const dim_t LDA1 = rnn_.src_layer_ld_; + const dim_t LDA2 = rnn_.ws_states_layer_ld; + const dim_t LDA3 = rnn_.dst_iter_ld_; + const dim_t LDB = rnn_.weights_layer_ld; const dim_t LDC = rnn_.scratch_gates_ld; const bool do_sum = false; - if (LDB1 >= K) + if (LDA1 >= K) CHECK(init_matmul_pd( - matmul_layer_1_pd_, M, N, K, LDA, LDB1, LDC, do_sum)); - if (LDB2 >= K && LDB2 != LDB1) + matmul_layer_1_pd_, M, N, K, LDA1, LDB, LDC, do_sum)); + if (LDA2 >= K && LDA2 != LDA1) CHECK(init_matmul_pd( - matmul_layer_2_pd_, M, N, K, LDA, LDB2, LDC, do_sum)); - if (LDB3 >= K && !utils::one_of(LDB3, LDB1, LDB2)) + matmul_layer_2_pd_, M, N, K, LDA2, LDB, LDC, do_sum)); + if (LDA3 >= K && !utils::one_of(LDA3, LDA1, LDA2)) CHECK(init_matmul_pd( - matmul_layer_3_pd_, M, N, K, LDA, LDB3, LDC, do_sum)); + matmul_layer_3_pd_, M, N, K, LDA3, LDB, LDC, do_sum)); } { // init iter matmuls - const dim_t M = static_cast(rnn_.dhc) + const dim_t M = rnn_.mb; + const dim_t N = static_cast(rnn_.dhc) * (rnn_.n_gates - rnn_.is_orig_gru); - const dim_t N = rnn_.mb; const dim_t K = rnn_.sic; - const dim_t LDA = rnn_.weights_iter_ld; - const dim_t LDB1 = rnn_.src_iter_ld_; - const dim_t LDB2 = rnn_.ws_states_iter_ld; - const dim_t LDB3 = rnn_.dst_layer_ld_; + const dim_t 
LDA1 = rnn_.src_iter_ld_; + const dim_t LDA2 = rnn_.ws_states_iter_ld; + const dim_t LDA3 = rnn_.dst_layer_ld_; + const dim_t LDB = rnn_.weights_iter_ld; const dim_t LDC = rnn_.is_lbr ? rnn_.ws_gates_ld : rnn_.scratch_gates_ld; const bool do_sum = !rnn_.is_lbr; - if (LDB1 >= K) + if (LDA1 >= K) CHECK(init_matmul_pd( - matmul_iter_1_pd_, M, N, K, LDA, LDB1, LDC, do_sum)); - if (LDB2 >= K && LDB2 != LDB1) + matmul_iter_1_pd_, M, N, K, LDA1, LDB, LDC, do_sum)); + if (LDA2 >= K && LDA2 != LDA1) CHECK(init_matmul_pd( - matmul_iter_2_pd_, M, N, K, LDA, LDB2, LDC, do_sum)); - if (LDB3 >= K && !utils::one_of(LDB3, LDB1, LDB2)) + matmul_iter_2_pd_, M, N, K, LDA2, LDB, LDC, do_sum)); + if (LDA3 >= K && !utils::one_of(LDA3, LDA1, LDA2)) CHECK(init_matmul_pd( - matmul_iter_3_pd_, M, N, K, LDA, LDB3, LDC, do_sum)); + matmul_iter_3_pd_, M, N, K, LDA3, LDB, LDC, do_sum)); if (rnn_.is_orig_gru) { - const dim_t M_part2 = rnn_.dhc; - const dim_t LDB1 = rnn_.ws_states_layer_ld; - const dim_t LDB2 = rnn_.ws_states_iter_ld; - const dim_t LDB3 = rnn_.dst_layer_ld_; - const dim_t LDB4 = rnn_.dst_iter_ld_; - if (LDB1 >= K) - CHECK(init_matmul_pd(matmul_part2_1_pd_, M_part2, N, K, LDA, - LDB1, LDC, do_sum)); - if (LDB2 >= K && LDB2 != LDB1) - CHECK(init_matmul_pd(matmul_part2_2_pd_, M_part2, N, K, LDA, - LDB2, LDC, do_sum)); - if (LDB3 >= K && !utils::one_of(LDB3, LDB1, LDB2)) - CHECK(init_matmul_pd(matmul_part2_3_pd_, M_part2, N, K, LDA, - LDB3, LDC, do_sum)); - if (LDB4 >= K && !utils::one_of(LDB4, LDB1, LDB2, LDB3)) - CHECK(init_matmul_pd(matmul_part2_4_pd_, M_part2, N, K, LDA, - LDB4, LDC, do_sum)); + const dim_t N_part2 = rnn_.dhc; + const dim_t LDA1 = rnn_.ws_states_layer_ld; + const dim_t LDA2 = rnn_.ws_states_iter_ld; + const dim_t LDA3 = rnn_.dst_layer_ld_; + const dim_t LDA4 = rnn_.dst_iter_ld_; + if (LDA1 >= K) + CHECK(init_matmul_pd(matmul_part2_1_pd_, M, N_part2, K, + LDA1, LDB, LDC, do_sum)); + if (LDA2 >= K && LDA2 != LDA1) + CHECK(init_matmul_pd(matmul_part2_2_pd_, M, N_part2, K, + LDA2, LDB, LDC, do_sum)); + if (LDA3 >= K && !utils::one_of(LDA3, LDA1, LDA2)) + CHECK(init_matmul_pd(matmul_part2_3_pd_, M, N_part2, K, + LDA3, LDB, LDC, do_sum)); + if (LDA4 >= K && !utils::one_of(LDA4, LDA1, LDA2, LDA3)) + CHECK(init_matmul_pd(matmul_part2_4_pd_, M, N_part2, K, + LDA4, LDB, LDC, do_sum)); } } } @@ -312,7 +312,6 @@ _ref_rnn_common_t::pd_t::init_brgemm( = this->desc()->weights_iter_desc.data_type; const data_type_t weights_layer_dt = this->desc()->weights_layer_desc.data_type; - bool is_f32 = everyone_is( data_type::f32, src_layer_dt, weights_iter_dt, weights_layer_dt); bool is_impl_bf16 = everyone_is(data_type::bf16, src_type, weights_type); @@ -840,8 +839,9 @@ rnn_matmul_sig((_ref_rnn_common_tpd()->dst_md(), mem_flag, (void *)(c_)); exec_args_t matmul_args; - matmul_args[DNNL_ARG_SRC] = {&src_mem, true}; - matmul_args[DNNL_ARG_WEIGHTS] = {&wei_mem, true}; + // Note Matmul src and wei may not directly map to RNN primitive src and wei + matmul_args[DNNL_ARG_SRC] = {&wei_mem, true}; + matmul_args[DNNL_ARG_WEIGHTS] = {&src_mem, true}; matmul_args[DNNL_ARG_DST] = {&dst_mem, false}; exec_ctx_t matmul_ctx(ctx, std::move(matmul_args)); diff --git a/src/cpu/rnn/ref_rnn.hpp b/src/cpu/rnn/ref_rnn.hpp index bb24f5ad1a5..a479867bd26 100644 --- a/src/cpu/rnn/ref_rnn.hpp +++ b/src/cpu/rnn/ref_rnn.hpp @@ -136,7 +136,8 @@ struct _ref_rnn_common_t : public primitive_t { using namespace dnnl::impl::cpu::x64; return rnn_.is_brgemm ? JIT_IMPL_NAME_HELPER("brgemm:", rnn_.brgemm_isa, "") - : "ref"; + : rnn_.use_matmul ? 
"ref+matmul" + : "ref"; #else return "ref"; #endif diff --git a/src/cpu/rnn/rnn_utils.hpp b/src/cpu/rnn/rnn_utils.hpp index 1cf87f9e3a2..0bd61ba9365 100644 --- a/src/cpu/rnn/rnn_utils.hpp +++ b/src/cpu/rnn/rnn_utils.hpp @@ -861,8 +861,14 @@ bool init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, #if DNNL_X64 && IMPLICATION( rnn.is_cell_dt_bf16(), !x64::mayiuse(x64::avx512_core)) -#endif + && IMPLICATION(rnn.is_cell_dt_f32() || rnn.is_cell_dt_int8(), + x64::mayiuse(x64::avx2) + && utils::one_of(rd.cell_kind, + alg_kind::vanilla_gru, + alg_kind::vanilla_augru)); +#else && !rnn.is_cell_dt_f32() && !rnn.is_cell_dt_int8(); +#endif /* Decide which gemm implementation to use: packed/nonpacked jit/cblas * and if to merge gemm across iterations */ diff --git a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp index 91e9ac98c06..1d4422a81fb 100644 --- a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp @@ -4106,7 +4106,7 @@ void jit_avx512_core_amx_bwd_weights_kernel_t::compute_full_spat_loop( assert(full_spat_opt_working_set_size < full_spat_max_working_set_size); while (working_set_size > full_spat_opt_working_set_size - && h_block_size >= min_h_block_size) { + && h_block_size >= min_h_block_size && h_block_size >= 2) { for (int i = 2; i <= h_block_size; i++) if (i == h_block_size) h_block_size = h_block_size / 2; @@ -4253,7 +4253,6 @@ void jit_avx512_core_amx_bwd_weights_kernel_t::compute_full_spat_loop( // restore the zeroing flag (it will be cleared after the end of // emit_kh_kw_loop, but we may need it until then) or_(reg_ker, 1); - jmp(kh_loop_end, T_NEAR); L(skip_ker_zeroing); add(reg_ker, get_kernel_offset(0, jcp.kw)); diff --git a/src/cpu/x64/jit_brgemm_conv_utils.cpp b/src/cpu/x64/jit_brgemm_conv_utils.cpp index 54c68fcb6d9..3682e6409e1 100644 --- a/src/cpu/x64/jit_brgemm_conv_utils.cpp +++ b/src/cpu/x64/jit_brgemm_conv_utils.cpp @@ -1831,7 +1831,7 @@ status_t init_jcp(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf_default; jcp.brgemm_bd_loop_innermost = false; - if (!jcp.wei_plain && jcp.prop_kind != prop_kind::backward_weights) { + if (!jcp.wei_plain) { // fast check data layout before spending time for blocking selection format_tag_t src_tag = pick(jcp.ndims - 3, nwc, nhwc, ndhwc); const bool any_eligible = is_any_eligible(jcp); diff --git a/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp index 99f9c1969ed..56ee8aff8ca 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp @@ -2868,7 +2868,7 @@ struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t, opmask_t kFFFF = k6; opmask_t kTail_int4 = k5; opmask_t kAAAA = k4; - opmask_t kSign = k3; + opmask_t k5555 = k3; reg64_t reg_src = rax; reg64_t reg_tr_src = rbx; @@ -2893,9 +2893,6 @@ struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t, Vmm vmm_tmp = Vmm(1); // used only for avx2_vnni_2 Vmm vmm_zp_b_shift = Vmm(2); Vmm vmm_permd = Vmm(3); - Vmm vmm_int4_mask = Vmm(4); - Vmm vmm_sign_bit = Vmm(5); - Vmm vmm_sign_mask = Vmm(6); void kmovx(Opmask k, unsigned w) { if (!isa_has_masks(conf_->isa)) return; @@ -2933,8 +2930,7 @@ struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t, return vmm; } } - void load_int( - const Vmm vmm_in, const Xbyak::Operand &op, bool is_tail = false); + void load_data(const Vmm vmm_in, 
const Xbyak::Operand &op, bool is_tail); void copy_block(int nrows, int ncolumns, bool n_tail); void copy_2x32(int nrows, int ncolumns); void init_masks(); @@ -2942,15 +2938,16 @@ struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t, }; template -void jit_brgemm_matmul_copy_b_bf16_t::load_int( +void jit_brgemm_matmul_copy_b_bf16_t::load_data( const Vmm vmm_in, const Xbyak::Operand &op, bool is_tail) { const auto vmm = maybe_mask(vmm_in, is_tail); const auto vmm_lower = Vmm_lower_t(vmm.getIdx()); - const auto is_s4 = conf_->orig_wei_dt == data_type::s4; MAYBE_UNUSED(vmm_lower); - MAYBE_UNUSED(is_s4); switch (conf_->orig_wei_dt) { + case data_type::f32: uni_vmovups(vmm, op); break; + case data_type::f16: + case data_type::bf16: vmovdqu16(vmm, op); break; case data_type::s8: uni_vpmovsxbd(vmm, op); break; case data_type::u8: uni_vpmovzxbd(vmm, op); break; // For int4, we see two int4 as one int8 and extend them int32 @@ -2958,17 +2955,20 @@ void jit_brgemm_matmul_copy_b_bf16_t::load_int( // bytes of vmm, then permute them into correct order // Finally, we process the extend bytes for s4/u4 accordingly case data_type::s4: + uni_vpmovsxbd(maybe_mask(vmm_lower, is_tail), op); + copy_half_int4(vmm_in, vmm_lower); + vpermd(vmm_in, vmm_permd, vmm_in); + uni_vpslld(vmm_in | k5555, vmm_in, 28); + vpsrad(vmm_in | k5555, vmm_in, 28); + vpsrad(vmm_in | kAAAA, vmm_in, 4); + break; case data_type::u4: - if (is_s4) - uni_vpmovsxbd(maybe_mask(vmm_lower, is_tail), op); - else - uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op); + uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op); copy_half_int4(vmm_in, vmm_lower); vpermd(vmm_in, vmm_permd, vmm_in); + uni_vpslld(vmm_in | k5555, vmm_in, 28); + vpsrld(vmm_in | k5555, vmm_in, 28); vpsrld(vmm_in | kAAAA, vmm_in, 4); - if (is_s4) vptestmd(kSign, vmm_in, vmm_sign_bit); - vpandd(vmm_in, vmm_in, vmm_int4_mask); - if (is_s4) vpord(vmm_in | kSign, vmm_in, vmm_sign_mask); break; default: assert(!"unsupported data type"); } @@ -2988,9 +2988,7 @@ void jit_brgemm_matmul_copy_b_bf16_t::copy_2x32(int nrows, int ncolumns) { } static constexpr int blk_sz = k_blk_step; - const int reserved_regs = !is_src_int4 - ? (req_zp_b_shift ? 3 : 2) - : (conf_->orig_wei_dt == data_type::s4 ? 7 : 5); + const int reserved_regs = is_src_int4 ? 4 : req_zp_b_shift ? 3 : 2; const int max_isa_regs = isa_num_vregs(conf_->isa); const int max_regs_available = max_isa_regs - reserved_regs; const int max_unroll = max_regs_available / blk_sz; @@ -3007,9 +3005,9 @@ void jit_brgemm_matmul_copy_b_bf16_t::copy_2x32(int nrows, int ncolumns) { auto src_reg = get_vmm(blk, k % k_blk_step); const bool is_tail = ncolumns - n < n_blk_step; auto src_load = maybe_mask(src_reg, is_tail); - const auto factor = is_src_int4 ? 2 : 1; + const auto typesize_scale = is_src_int4 ? 2 : 1; const auto offset = (is_dynamic_stride ? 0 : k * src_stride) - + ((n * typesize) / factor); + + ((n * typesize) / typesize_scale); const auto reg_src_load = is_dynamic_stride && k % 2 != 0 ? 
reg_src_load_1 : reg_src; auto load_addr = maybe_EVEX_compress_addr(reg_src_load, offset); @@ -3019,23 +3017,24 @@ void jit_brgemm_matmul_copy_b_bf16_t::copy_2x32(int nrows, int ncolumns) { else uni_vmovups(src_load, load_addr); } else { - if (conf_->is_bf32) - uni_vmovups(src_load, load_addr); - else if (conf_->is_bf16_with_int_wei) { - load_int(src_reg, load_addr, is_tail); - if (req_zp_b_shift) - uni_vpsubd(src_load, src_load, vmm_zp_b_shift); - uni_vcvtdq2ps(src_load, src_load); - if (req_apply_scales) { - const auto scales_offset - = (is_dynamic_stride ? 0 : k * scales_N_stride) - + n * scales_typesize; - const auto scales_addr = maybe_EVEX_compress_addr( - reg_scales, scales_offset); - uni_vmulps(src_load, src_load, scales_addr); - } - } else - vmovdqu16(src_load, load_addr); + load_data(src_reg, load_addr, is_tail); + } + + if (utils::one_of(conf_->orig_wei_dt, data_type::s8, data_type::u8, + data_type::s4, data_type::u4)) { + if (req_zp_b_shift) uni_vpsubd(src_load, src_load, vmm_zp_b_shift); + uni_vcvtdq2ps(src_load, src_load); + if (req_apply_scales) { + const auto scales_offset + = (is_dynamic_stride ? 0 : k * scales_N_stride) + + n * scales_typesize; + const auto scales_addr + = maybe_EVEX_compress_addr(reg_scales, scales_offset); + uni_vmulps(src_load, src_load, scales_addr); + } + + if (conf_->wei_dt == data_type::f16) + vcvtps2phx(Vmm_lower_t(src_reg.getIdx()), src_reg); } }; @@ -3126,18 +3125,7 @@ void jit_brgemm_matmul_copy_b_bf16_t::init_masks() { vmovdqa32(vmm_permd, ptr[reg_tmp]); kmovx(kAAAA, 0xaaaa); - - const auto reg32_scratch = reg_tmp.cvt32(); - mov(reg32_scratch, 0xf); - vpbroadcastd(vmm_int4_mask, reg32_scratch); - - if (conf_->orig_wei_dt == data_type::s4) { - mov(reg32_scratch, 0x8); - vpbroadcastd(vmm_sign_bit, reg32_scratch); - - mov(reg32_scratch, 0xfffffff8); - vpbroadcastd(vmm_sign_mask, reg32_scratch); - } + kmovx(k5555, 0x5555); } } } @@ -3303,7 +3291,9 @@ struct jit_brgemm_matmul_copy_b_f32_t : public jit_brgemm_matmul_copy_b_t, , jit_generator(jit_name()) , dt_in_(conf->orig_wei_dt) , simd_w_(vreg_traits::vlen / sizeof(float)) + , is_src_int4_(one_of(conf->orig_wei_dt, data_type::s4, data_type::u4)) , typesize_in_(types::data_type_size(dt_in_)) + , typesize_scale_(is_src_int4_ ? 
2 : 1) , src_stride_(conf_->copy_B_wei_stride) , tr_src_stride_(conf_->LDB * typesize_out_) {} @@ -3314,15 +3304,19 @@ struct jit_brgemm_matmul_copy_b_f32_t : public jit_brgemm_matmul_copy_b_t, using reg64_t = const Xbyak::Reg64; using reg32_t = const Xbyak::Reg32; using opmask_t = const Xbyak::Opmask; + using Vmm_lower_t = typename vreg_traits::Vmm_lower_t; const data_type_t dt_in_; const int simd_w_; - const size_t typesize_in_; + const bool is_src_int4_; + const size_t typesize_in_, typesize_scale_; const size_t typesize_out_ = sizeof(float); dim_t src_stride_, tr_src_stride_; opmask_t kTail = k7; opmask_t kFFFF = k6; + opmask_t k5555 = k5; + opmask_t kAAAA = k4; reg64_t reg_src = rax; reg64_t reg_tr_src = rbx; @@ -3335,6 +3329,7 @@ struct jit_brgemm_matmul_copy_b_f32_t : public jit_brgemm_matmul_copy_b_t, Vmm vmm_zero = Vmm(0); Vmm vmm_permw = Vmm(1); + Vmm vmm_permd = Vmm(2); Ymm ymm_tail_mask = ymm1; inline void kmovw(Opmask k, unsigned w) { @@ -3342,16 +3337,71 @@ struct jit_brgemm_matmul_copy_b_f32_t : public jit_brgemm_matmul_copy_b_t, mov(regw_tmp, w); jit_generator::kmovd(k, regw_tmp); } + void copy_half_int4(const Zmm &zmm, const Ymm &ymm_half) { + vinserti64x4(zmm, zmm, ymm_half, 1); + } + void copy_half_int4(const Ymm &ymm, const Xmm &xmm_half) { + vinserti128(ymm, ymm, xmm_half, 1); + } + Vmm_lower_t maybe_mask(Vmm_lower_t vmm_lower, bool is_tail) { + assert(is_src_int4_); + return is_tail && isa_has_masks(conf_->isa) ? vmm_lower | kTail | T_z + : vmm_lower; + } + Vmm maybe_mask(Vmm vmm, bool is_tail) { + return is_tail && isa_has_masks(conf_->isa) ? vmm | kTail | T_z : vmm; + } + void load_data(const Vmm vmm_in, const Xbyak::Operand &op, bool is_tail); void copy_16_x_n_block(int nrows, int ncolumns); void compute_k_loop(int ncolumns); void generate() override; }; +template +void jit_brgemm_matmul_copy_b_f32_t::load_data( + const Vmm vmm_in, const Xbyak::Operand &op, bool is_tail) { + const auto vmm = maybe_mask(vmm_in, is_tail); + const auto vmm_lower = Vmm_lower_t(vmm.getIdx()); + MAYBE_UNUSED(vmm_lower); + + switch (dt_in_) { + case data_type::f32: uni_vmovups(vmm, op); break; + case data_type::f16: vcvtph2psx(vmm, op); break; + case data_type::s8: uni_vpmovsxbd(vmm, op); break; + case data_type::u8: uni_vpmovzxbd(vmm, op); break; + // For int4, we see two int4 as one int8 and extend them int32 + // low half stores in lower bytes of vmm and high half in higher + // bytes of vmm, then permute them into correct order + // Finally, we process the extend bytes for s4/u4 accordingly + case data_type::s4: + uni_vpmovsxbd(maybe_mask(vmm_lower, is_tail), op); + copy_half_int4(vmm_in, vmm_lower); + vpermd(vmm_in, vmm_permd, vmm_in); + uni_vpslld(vmm_in | k5555, vmm_in, 28); + vpsrad(vmm_in | k5555, vmm_in, 28); + vpsrad(vmm_in | kAAAA, vmm_in, 4); + break; + case data_type::u4: + uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op); + copy_half_int4(vmm_in, vmm_lower); + vpermd(vmm_in, vmm_permd, vmm_in); + uni_vpslld(vmm_in | k5555, vmm_in, 28); + vpsrld(vmm_in | k5555, vmm_in, 28); + vpsrld(vmm_in | kAAAA, vmm_in, 4); + break; + default: assert(!"unsupported data type"); + } + + if (one_of(dt_in_, data_type::s8, data_type::u8, data_type::s4, + data_type::u4)) + uni_vcvtdq2ps(vmm_in, vmm_in); +} + template void jit_brgemm_matmul_copy_b_f32_t::copy_16_x_n_block( int nrows, int ncolumns) { const int max_isa_regs = isa_num_vregs(conf_->isa); - constexpr int reserved_regs = 2; + const int reserved_regs = is_src_int4_ ? 
3 : 2; const int max_regs_available = max_isa_regs - reserved_regs; auto get_vmm = [max_regs_available, reserved_regs](int reg_idx) { @@ -3364,24 +3414,18 @@ void jit_brgemm_matmul_copy_b_f32_t::copy_16_x_n_block( auto load = [this, get_vmm, ncolumns](int blk, int k, int n) { auto src_vmm = get_vmm(blk); const bool is_tail = ncolumns - n < simd_w_; - const opmask_t current_mask = is_tail ? kTail : kFFFF; - auto src_vmm_m = isa_has_masks(conf_->isa) - ? src_vmm | current_mask | T_z - : src_vmm; - auto addr = maybe_EVEX_compress_addr( - reg_src, k * src_stride_ + n * typesize_in_); + auto addr = maybe_EVEX_compress_addr(reg_src, + k * src_stride_ + ((n * typesize_in_) / typesize_scale_)); if (is_tail && !isa_has_masks(conf_->isa)) vmaskmovps(src_vmm, ymm_tail_mask, addr); - else if (dt_in_ == data_type::f16) - vcvtph2psx(src_vmm_m, addr); else - uni_vmovups(src_vmm_m, addr); + load_data(src_vmm, addr, is_tail); }; const int columns_tail = ncolumns % simd_w_; if (columns_tail < simd_w_) { if (isa_has_masks(conf_->isa)) { - const auto tail_mask = (1 << columns_tail) - 1; + const auto tail_mask = (1 << (columns_tail / typesize_scale_)) - 1; kmovw(kTail, tail_mask); } else { init_f32_avx2_mask_ymm(ymm_tail_mask, reg_tmp, columns_tail); @@ -3445,6 +3489,15 @@ void jit_brgemm_matmul_copy_b_f32_t::generate() { mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]); mov(reg_N_blk, ptr[param1 + GET_OFF(current_N_blk)]); kmovw(kFFFF, 0xffff); // 1111111111111111 + if (is_src_int4_) { + alignas(64) static constexpr const uint32_t int4_permute[16] + = {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15}; + mov(reg_tmp, reinterpret_cast(int4_permute)); + vmovdqa32(vmm_permd, ptr[reg_tmp]); + + kmovw(kAAAA, 0xaaaa); + kmovw(k5555, 0x5555); + } Label done; if (conf_->N_tail > 0) { @@ -3484,6 +3537,7 @@ struct jit_brgemm_matmul_copy_b_transposed_t conf_->has_zero_point_a || conf_->s8s8_compensation_required) , is_bf32_(conf->is_bf32) , is_bf16_with_int_wei_(conf->is_bf16_with_int_wei) + , is_src_int4_(one_of(conf->orig_wei_dt, data_type::s4, data_type::u4)) , req_cvtps2bf16_(conf->is_bf32 || conf->is_bf16_with_int_wei) , req_zp_comp_(conf_->has_zero_point_a) , req_s8s8_comp_(conf_->s8s8_compensation_required) @@ -3496,11 +3550,13 @@ struct jit_brgemm_matmul_copy_b_transposed_t - (avx512_core_dot_product_ ? 8 : (do_compute_compensation_ ? 6 + : is_src_int4_ ? 2 : req_zp_b_shift_ ? 1 : 0))) , src_stride_(conf_->copy_B_wei_stride) , tr_src_stride_(conf_->LDB * vnni_granularity_ * tr_typesize_) , scales_K_stride_(conf_->K * scales_typesize_) + , typesize_scale_(is_src_int4_ ? 2 : 1) , is_dynamic_N_(conf->is_runtime_N) {} void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } @@ -3510,6 +3566,7 @@ struct jit_brgemm_matmul_copy_b_transposed_t using reg64_t = const Xbyak::Reg64; using reg32_t = const Xbyak::Reg32; using opmask_t = const Xbyak::Opmask; + using Vmm_lower_t = typename vreg_traits::Vmm_lower_t; static constexpr bool is_ymm_ = std::is_same::value; static constexpr cpu_isa_t isa_ = is_ymm_ ? 
avx2 : avx512_core; @@ -3527,6 +3584,7 @@ struct jit_brgemm_matmul_copy_b_transposed_t const bool do_compute_compensation_; const bool is_bf32_; const bool is_bf16_with_int_wei_; + const bool is_src_int4_; const bool req_cvtps2bf16_; const bool req_zp_comp_; const bool req_s8s8_comp_; @@ -3535,7 +3593,7 @@ struct jit_brgemm_matmul_copy_b_transposed_t const bool avx512_core_dot_product_; const int max_tmp_idx; - const dim_t src_stride_, tr_src_stride_, scales_K_stride_; + const dim_t src_stride_, tr_src_stride_, scales_K_stride_, typesize_scale_; const bool is_dynamic_N_; opmask_t k3333 = k1; @@ -3545,6 +3603,8 @@ struct jit_brgemm_matmul_copy_b_transposed_t opmask_t k0F0F = k5; opmask_t kF0F0 = k6; opmask_t kTail = k7; + // reuse k7 for int4 and restore the value after use + opmask_t kTail_int4 = k7; reg64_t reg_src_base = rax; reg64_t reg_tr_src_base = rbx; @@ -3578,6 +3638,7 @@ struct jit_brgemm_matmul_copy_b_transposed_t Vmm vmm_dot_product_temp = Vmm(max_vmm_regs_ - 8); Vmm vmm_zp_b_val = Vmm(max_vmm_regs_ - 1); + Vmm vmm_permd = Vmm(max_vmm_regs_ - 2); void kmovw(Opmask k, unsigned w) { mov(regw_tmp, w); @@ -3601,6 +3662,28 @@ struct jit_brgemm_matmul_copy_b_transposed_t return Vmm(n_blk_step_ + i); } + void copy_half_int4(const Zmm &zmm, const Ymm &ymm_half) { + vinserti64x4(zmm, zmm, ymm_half, 1); + } + + void copy_half_int4(const Ymm &ymm, const Xmm &xmm_half) { + vinserti128(ymm, ymm, xmm_half, 1); + } + + Vmm_lower_t maybe_mask(Vmm_lower_t vmm_lower, bool is_tail) { + assert(is_src_int4_); + return isa_has_masks(conf_->isa) && is_tail + ? vmm_lower | kTail_int4 | T_z + : vmm_lower; + } + + Vmm maybe_mask(Vmm vmm, bool is_tail) { + return isa_has_masks(conf_->isa) && is_tail ? vmm | kTail | T_z : vmm; + } + + void init_tail_mask(const int columns_tail, const bool use_int4_mask); + void load_int(const Vmm vmm_in, const Xbyak::Operand &op, + const int columns_tail, bool is_tail); void copy_row_x_col(int nrows, int ncolumns); void compute_K_loop(bool is_N_tail, int curr_K_tail, bool is_first_K_iter, bool is_last_K_iter); @@ -3627,25 +3710,70 @@ struct jit_brgemm_matmul_copy_b_transposed_t }; template -void jit_brgemm_matmul_copy_b_transposed_t::copy_row_x_col( - int nrows, int ncolumns) { - assert(nrows >= 0 && nrows <= n_blk_step_ && ncolumns >= 0 - && ncolumns <= k_blk_step_); - if (!nrows) return; - - const int columns_tail = ncolumns - % (req_cvtps2bf16_ ? req_cvt_bf16_k_blk_step_ : k_blk_step_); +void jit_brgemm_matmul_copy_b_transposed_t::init_tail_mask( + const int columns_tail, const bool use_int4_mask) { + assert(IMPLICATION(is_src_int4_, use_int4_mask)); if (columns_tail > 0) { const int dt_step = req_cvtps2bf16_ || conf_->isa == avx512_core_fp16 ? 1 : typesize_; - const auto tail_mask - = size_t(((size_t)1 << dt_step * columns_tail) - 1); + const auto tail_mask = use_int4_mask + ? 
size_t(((size_t)1 << (dt_step * columns_tail) / 2) - 1) + : size_t(((size_t)1 << dt_step * columns_tail) - 1); if (req_cvtps2bf16_) kmovw(kTail, tail_mask); else kmovq(kTail, tail_mask); } +} + +template +void jit_brgemm_matmul_copy_b_transposed_t::load_int(const Vmm vmm_in, + const Xbyak::Operand &op, int columns_tail, bool is_tail) { + const auto vmm = maybe_mask(vmm_in, is_tail); + const auto vmm_lower = Vmm_lower_t(vmm.getIdx()); + MAYBE_UNUSED(vmm_lower); + if (is_src_int4_) init_tail_mask(columns_tail, true); + + switch (conf_->orig_wei_dt) { + case data_type::s8: uni_vpmovsxbd(vmm, op); break; + case data_type::u8: uni_vpmovzxbd(vmm, op); break; + // For int4, we see two int4 as one int8 and extend them int32 + // low half stores in lower bytes of vmm and high half in higher + // bytes of vmm, then permute them into correct order + // Finally, we process the extend bytes for s4/u4 accordingly + case data_type::s4: + uni_vpmovsxbd(maybe_mask(vmm_lower, is_tail), op); + copy_half_int4(vmm_in, vmm_lower); + vpermd(vmm_in, vmm_permd, vmm_in); + uni_vpslld(vmm_in | k5555, vmm_in, 28); + vpsrad(vmm_in | k5555, vmm_in, 28); + vpsrad(vmm_in | kAAAA, vmm_in, 4); + break; + case data_type::u4: + uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op); + copy_half_int4(vmm_in, vmm_lower); + vpermd(vmm_in, vmm_permd, vmm_in); + uni_vpslld(vmm_in | k5555, vmm_in, 28); + vpsrld(vmm_in | k5555, vmm_in, 28); + vpsrld(vmm_in | kAAAA, vmm_in, 4); + break; + default: assert(!"unsupported data type"); + } + // restore the tail_mask + if (is_src_int4_) init_tail_mask(columns_tail, false); +} + +template +void jit_brgemm_matmul_copy_b_transposed_t::copy_row_x_col( + int nrows, int ncolumns) { + assert(nrows >= 0 && nrows <= n_blk_step_ && ncolumns >= 0 + && ncolumns <= k_blk_step_); + if (!nrows) return; + + const int columns_tail = ncolumns + % (req_cvtps2bf16_ ? req_cvt_bf16_k_blk_step_ : k_blk_step_); + init_tail_mask(columns_tail, false); auto load2bf16 = [this, nrows, columns_tail, ncolumns]( int i, int base_idx) { @@ -3676,10 +3804,8 @@ void jit_brgemm_matmul_copy_b_transposed_t::copy_row_x_col( if (is_bf32_) vmovups(zmm_src, addr); else if (is_bf16_with_int_wei_) { - if (conf_->orig_wei_dt == data_type::s8) - vpmovsxbd(zmm_src, addr); - else - vpmovzxbd(zmm_src, addr); + load_int(src_reg, addr, columns_tail, + columns_tail > 0 && ncolumns < req_cvt_bf16_k_blk_step_); if (req_zp_b_shift_) vpsubd(zmm_src, zmm_src, vmm_zp_b_val); vcvtdq2ps(zmm_src, zmm_src); if (req_apply_scales_) { @@ -3696,16 +3822,17 @@ void jit_brgemm_matmul_copy_b_transposed_t::copy_row_x_col( auto zmm_src_next = columns_tail > 0 ? src_reg_next | kTail | T_z : src_reg_next; const auto next_addr = EVEX_compress_addr(reg_src, - i * src_stride_ + req_cvt_bf16_k_blk_step_ * typesize_); + i * src_stride_ + + (req_cvt_bf16_k_blk_step_ * typesize_) + / typesize_scale_); if (is_bf32_) vmovups(zmm_src_next, next_addr); else if (is_bf16_with_int_wei_) { - if (conf_->orig_wei_dt == data_type::s8) - vpmovsxbd(zmm_src_next, next_addr); - else - vpmovzxbd(zmm_src_next, next_addr); + load_int(src_reg_next, next_addr, columns_tail, + columns_tail > 0); if (req_zp_b_shift_) vpsubd(zmm_src_next, zmm_src_next, vmm_zp_b_val); + vcvtdq2ps(zmm_src_next, zmm_src_next); if (req_apply_scales_) { const auto scales_next_addr = EVEX_compress_addr(reg_scales, @@ -3761,7 +3888,8 @@ void jit_brgemm_matmul_copy_b_transposed_t::copy_row_x_col( // If compensation compute is required - use tmp(0) ... 
tmp(7) // to not spoil reserved registers' values const int tmp_corr_idx - = (do_compute_compensation_ || req_zp_b_shift_) * base_idx; + = (is_src_int4_ || do_compute_compensation_ || req_zp_b_shift_) + * base_idx; // swap 1 if (req_cvtps2bf16_) { @@ -4022,7 +4150,7 @@ void jit_brgemm_matmul_copy_b_transposed_t::compute_K_loop(bool is_N_tail, L(K_loop); copy_row_x_col(nrows, k_blk_step_); - add(reg_src, k_blk_step_ * typesize_); + add(reg_src, (k_blk_step_ * typesize_) / typesize_scale_); add(reg_tr_src, k_blk_step_ / vnni_granularity_ * tr_src_stride_); if (req_apply_scales_) add(reg_scales, k_blk_step_ * scales_typesize_); @@ -4122,6 +4250,12 @@ void jit_brgemm_matmul_copy_b_transposed_t::generate() { kmovw(k0F0F, 0x0f0f); kmovw(kF0F0, 0xf0f0); } + if (is_src_int4_ && is_superset(conf_->isa, avx512_core)) { + alignas(64) static constexpr const uint32_t int4_permute[16] + = {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15}; + mov(regq_tmp, reinterpret_cast(int4_permute)); + vmovdqa32(vmm_permd, ptr[regq_tmp]); + } const dim_t N_chunk_elems = conf_->N_chunk_elems; assert(N_chunk_elems % n_blk_step_ == 0 || N_chunk_elems == conf_->N); @@ -4221,13 +4355,16 @@ struct jit_brgemm_matmul_copy_b_cvt_bf16_t : public jit_brgemm_matmul_copy_b_t, , typesize_(conf->b_dt_sz) , tr_typesize_(conf->tr_b_dt_sz) , scales_typesize_(sizeof(float)) - , src_stride_(conf->LDB * k_blk_step * typesize_) + , is_src_int4_(one_of(conf->orig_wei_dt, data_type::s4, data_type::u4)) + , typesize_scale_(is_src_int4_ ? 2 : 1) + , src_stride_((conf->LDB * k_blk_step * typesize_) / typesize_scale_) , tr_src_stride_(conf_->LDB * k_blk_step * tr_typesize_) , scales_N_stride_(conf->N * scales_typesize_) , req_zp_b_shift_( conf_->has_zero_point_b && conf_->with_wei_decompression) , req_apply_scales_(conf_->apply_scales_in_buffer_b) , reserved_regs_(req_apply_scales_ ? 5 + : is_src_int4_ ? 2 : req_zp_b_shift_ ? 1 : 0) {} @@ -4238,18 +4375,22 @@ struct jit_brgemm_matmul_copy_b_cvt_bf16_t : public jit_brgemm_matmul_copy_b_t, using reg64_t = const Xbyak::Reg64; using reg32_t = const Xbyak::Reg32; using opmask_t = const Xbyak::Opmask; + using Vmm_lower_t = typename vreg_traits::Vmm_lower_t; using zmm = const Xbyak::Zmm; using ymm = const Xbyak::Ymm; enum { k_blk_step = 2, n_blk_step = 16 }; const int typesize_, tr_typesize_, scales_typesize_; - const dim_t src_stride_, tr_src_stride_, scales_N_stride_; + const bool is_src_int4_; + const dim_t typesize_scale_, src_stride_, tr_src_stride_, scales_N_stride_; const bool req_zp_b_shift_; const bool req_apply_scales_; const int reserved_regs_; opmask_t kTail = k7; opmask_t kFFFF = k6; + opmask_t kAAAA = k5; + opmask_t k5555 = k4; reg64_t reg_src = rax; reg64_t reg_tr_src = rbx; @@ -4258,13 +4399,17 @@ struct jit_brgemm_matmul_copy_b_cvt_bf16_t : public jit_brgemm_matmul_copy_b_t, reg64_t reg_N_blk = r9; reg64_t reg_scales = r10; reg64_t reg_tmp = r11; + reg32_t regw_tmp = r11d; Vmm vmm_zp_b_val = Vmm(0); - Vmm vmm_scales0 = Vmm(1); - Vmm vmm_scales1 = Vmm(2); - Vmm vmm_permd = Vmm(3); + Vmm vmm_permd = Vmm(1); + Vmm vmm_scales0 = Vmm(2); + Vmm vmm_scales1 = Vmm(3); Vmm vmm_tmp = Vmm(4); + void copy_half_int4(const Zmm &zmm, const Ymm &ymm_half) { + vinserti64x4(zmm, zmm, ymm_half, 1); + } Vmm maybe_mask(Vmm vmm, bool is_tail) { if (isa_has_masks(conf_->isa)) { return is_tail ? 
vmm | kTail | T_z : vmm | kFFFF | T_z; @@ -4284,6 +4429,7 @@ struct jit_brgemm_matmul_copy_b_cvt_bf16_t : public jit_brgemm_matmul_copy_b_t, } void init_masks(); + void load_int(const Vmm vmm_in, const Xbyak::Operand &op); void get_scales(const int blk, const int k, const int n, const bool is_n_tail, const bool is_k_tail); void copy_block(const int nrows, const int ncolumns); @@ -4300,6 +4446,44 @@ void jit_brgemm_matmul_copy_b_cvt_bf16_t::init_masks() { mov(reg_tmp, reinterpret_cast(bf16_vnni_permute)); vmovdqa32(vmm_permd, ptr[reg_tmp]); + + mov(regw_tmp, 0x5555); + kmovw(k5555, regw_tmp); + mov(regw_tmp, 0xaaaa); + kmovw(kAAAA, regw_tmp); + } +} + +template +void jit_brgemm_matmul_copy_b_cvt_bf16_t::load_int( + const Vmm vmm_in, const Xbyak::Operand &op) { + const auto vmm_lower = Vmm_lower_t(vmm_in.getIdx()); + MAYBE_UNUSED(vmm_lower); + + switch (conf_->orig_wei_dt) { + case data_type::s8: uni_vpmovsxbd(vmm_in, op); break; + case data_type::u8: uni_vpmovzxbd(vmm_in, op); break; + // For int4, we see two int4 as one int8 and extend them int32 + // low half stores in lower bytes of vmm and high half in higher + // bytes of vmm, then permute them into correct order + // Finally, we process the extend bytes for s4/u4 accordingly + case data_type::s4: + uni_vpmovsxbd(vmm_lower, op); + copy_half_int4(vmm_in, vmm_lower); + vpermd(vmm_in, vmm_permd, vmm_in); + uni_vpslld(vmm_in | k5555, vmm_in, 28); + vpsrad(vmm_in | k5555, vmm_in, 28); + vpsrad(vmm_in | kAAAA, vmm_in, 4); + break; + case data_type::u4: + uni_vpmovzxbd(vmm_lower, op); + copy_half_int4(vmm_in, vmm_lower); + vpermd(vmm_in, vmm_permd, vmm_in); + uni_vpslld(vmm_in | k5555, vmm_in, 28); + vpsrld(vmm_in | k5555, vmm_in, 28); + vpsrld(vmm_in | kAAAA, vmm_in, 4); + break; + default: assert(!"unsupported data type"); } } @@ -4345,17 +4529,13 @@ void jit_brgemm_matmul_copy_b_cvt_bf16_t::copy_block( const int k_blk = k / k_blk_step; const auto src_vmm0 = get_vmm(blk, 0); const auto src_vmm1 = get_vmm(blk, 1); - const dim_t offset = k_blk * src_stride_ + n * k_blk_step * typesize_; - const auto stride = n_blk_step * typesize_; + const dim_t offset = k_blk * src_stride_ + + (n * k_blk_step * typesize_) / typesize_scale_; + const auto stride = (n_blk_step * typesize_) / typesize_scale_; auto load_addr0 = maybe_EVEX_compress_addr(reg_src, offset); auto load_addr1 = maybe_EVEX_compress_addr(reg_src, offset + stride); - if (conf_->orig_wei_dt == data_type::s8) { - vpmovsxbd(src_vmm0, load_addr0); - vpmovsxbd(src_vmm1, load_addr1); - } else { - vpmovzxbd(src_vmm0, load_addr0); - vpmovzxbd(src_vmm1, load_addr1); - } + load_int(src_vmm0, load_addr0); + load_int(src_vmm1, load_addr1); if (req_zp_b_shift_) { vpsubd(src_vmm0, src_vmm0, vmm_zp_b_val); vpsubd(src_vmm1, src_vmm1, vmm_zp_b_val); diff --git a/src/cpu/x64/matmul/brgemm_matmul_reorders.cpp b/src/cpu/x64/matmul/brgemm_matmul_reorders.cpp index 52c70326945..15c73874591 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_reorders.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul_reorders.cpp @@ -68,7 +68,8 @@ status_t init_conf(matmul::brgemm_matmul_conf_t &conf, const auto type_o = od.data_type(); const bool is_bf16_with_int_wei = type_o == data_type::bf16 - && utils::one_of(type_i, data_type::s8, data_type::u8); + && utils::one_of(type_i, data_type::s8, data_type::u8, + data_type::s4, data_type::u4); format_tag_t otag = get_otag(dst_md); // TODO: enable for itag = {ba, acb} @@ -80,6 +81,8 @@ status_t init_conf(matmul::brgemm_matmul_conf_t &conf, dim_t batch = ndims > 2 ? 
dims[ndims - 3] : 1; dim_t K = dims[ndims - 2]; dim_t N = dims[ndims - 1]; + if (utils::one_of(type_i, data_type::s4, data_type::u4) && N % 2 != 0) + return status::invalid_arguments; dim_t in_ld = ndims >= 2 ? memory_desc_wrapper(src_md).strides()[ndims - 2] : 1; @@ -112,18 +115,20 @@ status_t brgemm_matmul_matrix_B_reorder_t::pd_t::init( const auto type_o = od.data_type(); // TODO: enable support for type_i != type_o cases + const bool is_int_weights = utils::one_of( + type_i, data_type::s8, data_type::u8, data_type::s4, data_type::u4); const bool dt_ok = true && IMPLICATION(type_i == type_o, utils::one_of(type_o, data_type::s8, data_type::bf16, data_type::f16, data_type::f32)) && IMPLICATION(type_i != type_o, - type_o == data_type::bf16 - && utils::one_of( - type_i, data_type::s8, data_type::u8)); + utils::one_of(type_o, data_type::f32, data_type::f16, + data_type::bf16) + && is_int_weights); const bool is_f16 = utils::one_of(data_type::f16, type_i, type_o); const bool is_s8s8 = type_i == data_type::s8 && type_o == data_type::s8; - const bool is_bf16_with_int_wei = type_o == data_type::bf16 - && utils::one_of(type_i, data_type::s8, data_type::u8); + const bool is_bf16_with_int_wei + = type_o == data_type::bf16 && is_int_weights; const bool has_adj_scale = od.extra().flags & memory_extra_flags::scale_adjust; const bool args_ok = true && dt_ok && id.is_dense() @@ -184,6 +189,10 @@ status_t brgemm_matmul_matrix_B_reorder_t::execute_body( const auto sdt_sz = types::data_type_size(src_d.data_type()); const auto type_o = dst_d.data_type(); const auto ddt_sz = types::data_type_size(type_o); + const auto src_typesz_scale + = utils::one_of(src_d.data_type(), data_type::s4, data_type::u4) + ? 2 + : 1; const auto &kernel_conf = pd()->matmul_conf_for_reorder_; const size_t comp_offset_bytes @@ -238,7 +247,8 @@ status_t brgemm_matmul_matrix_B_reorder_t::execute_body( ? get_blk_off(src_d, sdt_sz, batch, k, n) : get_blk_off( src_d, sdt_sz, batch, k_blk_idx, n_blk_idx); - ker_exec_ctx.src = (void *)&src[src_offset]; + ker_exec_ctx.src + = (void *)&src[src_offset / src_typesz_scale]; ker_exec_ctx.tr_src = (void *)&dst[get_blk_off( dst_d, ddt_sz, batch, k_blk_idx, n_blk_idx)]; ker_exec_ctx.current_K_start = k; @@ -251,7 +261,8 @@ status_t brgemm_matmul_matrix_B_reorder_t::execute_body( ? 
get_blk_off(src_d, sdt_sz, batch, k, n) : get_blk_off( src_d, sdt_sz, batch, k_blk_idx, n_blk_idx); - ker_exec_ctx.src = (void *)&src[src_offset]; + ker_exec_ctx.src + = (void *)&src[src_offset / src_typesz_scale]; const auto dst_offset = get_blk_off( dst_d, ddt_sz, batch, k_blk_idx, n_blk_idx); ker_exec_ctx.tr_src = (void *)&dst[dst_offset]; diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp index 712a52cd152..98ecef7fb6d 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp @@ -357,7 +357,9 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_tags(memory_desc_t &A_md, if (bgmmc.src_tag == format_tag::undef || (memory_desc_matches_tag(A_md, transposed_tensor_layout_tag) && memory_desc_matches_tag( - A_md, plain_tensor_layout_tag))) { + A_md, plain_tensor_layout_tag) + && IMPLICATION( + !is_adbc_allowed, is_int8_avx512_core))) { if (gemm_based::check_gemm_input_format(A_md)) { // Note: here the batch layout may not be accurately represented // by the wei_tag string, due to all the permutations of the @@ -1393,10 +1395,14 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, && bgmmc.is_oscale_per_k && bgmmc.is_oscale_per_n && bgmmc.transposed_B; - // int4 weights decompression only supports plain layout for now - // TODO: enable int4 reorder and extend support to other weight layouts + // int4 weights decompression only supports the plain layout, + // and the transposed layout when K % 2 == 0 + // TODO: enable int4 reorder and extend support to blocked weight + // layouts when needed if (bgmmc.with_wei_decompression && bgmmc.is_int4_weights) - VCONDCHECK_BG(bm_conf_utils.check_is_plain(bgmmc.wei_tag), + VCONDCHECK_BG(bm_conf_utils.check_is_plain(bgmmc.wei_tag) + || (bm_conf_utils.check_is_transposed(bgmmc.wei_tag) + && bgmmc.K % 2 == 0), VERBOSE_UNSUPPORTED_TAG); const bool transposed_A = bm_conf_utils.check_is_transposed(bgmmc.src_tag); @@ -1619,9 +1625,13 @@ status_t init_conf(brgemm_matmul_conf_t &conf, dim_t batch, dim_t K, dim_t N, if (vnni_granularity <= 0) return status::invalid_arguments; const bool is_bf16_with_int_wei = out_type == data_type::bf16 - && utils::one_of(in_type, data_type::s8, data_type::u8); + && utils::one_of(in_type, data_type::s8, data_type::u8, + data_type::s4, data_type::u4); const bool with_wei_decompression = in_type != out_type - && utils::one_of(in_type, data_type::s8, data_type::u8); + && utils::one_of(in_type, data_type::s8, data_type::u8, + data_type::s4, data_type::u4); + const dim_t typesize_scale + = utils::one_of(in_type, data_type::s4, data_type::u4) ? 2 : 1; conf.blocked_B = !utils::one_of(in_tag, ab, ba, abc, acb); conf.transposed_B = utils::one_of(in_tag, ba, acb); @@ -1641,7 +1651,7 @@ status_t init_conf(brgemm_matmul_conf_t &conf, dim_t batch, dim_t K, dim_t N, conf.a_dt_sz = conf.tr_a_dt_sz = types::data_type_size(conf.src_dt); conf.b_dt_sz = types::data_type_size(in_type); conf.tr_b_dt_sz = types::data_type_size(conf.wei_dt); - conf.copy_B_wei_stride = in_ld * conf.b_dt_sz; + conf.copy_B_wei_stride = (in_ld * conf.b_dt_sz) / typesize_scale; conf.N_chunk_elems = conf.N; // To satisfy a seemingly unneeded assert. 
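// All of the "/ typesize_scale" and "/ src_typesz_scale" corrections above
// follow from one fact: s4/u4 weights are packed two elements per byte, so
// an element index must be halved to become a byte offset (and tail masks
// must cover half as many bytes). Below is a scalar model of that addressing
// plus the nibble decode that the k5555/kAAAA shift sequences in the copy
// kernels perform per dword lane -- an illustrative sketch only; this helper
// is not part of the patch:
#include <cstddef>
#include <cstdint>

static inline int32_t load_int4_elem(
        const uint8_t *packed, size_t elem_idx, bool is_signed) {
    const uint8_t byte = packed[elem_idx / 2]; // two int4 values per byte
    // Even elements occupy the low nibble, odd elements the high one; the
    // vpermd step in the kernels restores this interleaving across lanes.
    const uint8_t nib = (elem_idx % 2 == 0) ? (byte & 0x0f) : (byte >> 4);
    // s4: shift left then arithmetic shift right to sign-extend 4 -> 32
    // bits, mirroring uni_vpslld(..., 28) + vpsrad(..., 28); u4 needs only
    // the zero-extension, mirroring vpsrld.
    return is_signed ? (int32_t)((uint32_t)nib << 28) >> 28 : (int32_t)nib;
}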
conf.s8s8_comp_b_str = utils::rnd_up(conf.N, conf.wei_n_blk); conf.s8s8_comp_n_str = conf.wei_n_blk; @@ -1757,7 +1767,7 @@ void init_aux_values(brgemm_matmul_conf_t &bgmmc, * factor; } else if (bgmmc.transposed_B) { bgmmc.copy_B_wei_stride - = wei_d.strides()[bgmmc.ndims - 1] * bgmmc.b_dt_sz; + = (wei_d.strides()[bgmmc.ndims - 1] * bgmmc.b_dt_sz) / int4_fac; } else if (bgmmc.is_runtime_N) { bgmmc.copy_B_wei_stride = bgmmc.N; } else if (bgmmc.blocked_B) { diff --git a/src/cpu/x64/rnn/rnn_brgemm_utils.cpp b/src/cpu/x64/rnn/rnn_brgemm_utils.cpp index dd7979870d3..0cc7d3bf74a 100644 --- a/src/cpu/x64/rnn/rnn_brgemm_utils.cpp +++ b/src/cpu/x64/rnn/rnn_brgemm_utils.cpp @@ -84,7 +84,8 @@ x64::cpu_isa_t brgemm_calc_isa( } if (rnn.is_cell_dt_int8()) { - return x64::avx512_core_vnni; + return utils::map(true, x64::isa_undef, mayiuse(avx512_core_vnni), + avx512_core, mayiuse(avx512_core), avx512_core); } else if (rnn.is_cell_dt_bf16()) { return x64::avx512_core_bf16; } else if (rnn.is_cell_dt_f16()) { @@ -1339,7 +1340,7 @@ static status_t init_kernels_diff_wei(rnn_diff_wei_brgemm_t &diff_wei, tmp_matmul_conf_for_reorder.N_tail = diff_wei_conf.n_tail; tmp_matmul_conf_for_reorder.LDB = diff_wei_conf.LDB; tmp_matmul_conf_for_reorder.src_dt = tmp_matmul_conf_for_reorder.wei_dt - = rnn.cell_dt; + = tmp_matmul_conf_for_reorder.orig_wei_dt = rnn.cell_dt; tmp_matmul_conf_for_reorder.a_dt_sz = tmp_matmul_conf_for_reorder.tr_a_dt_sz = types::data_type_size(tmp_matmul_conf_for_reorder.src_dt); tmp_matmul_conf_for_reorder.b_dt_sz = tmp_matmul_conf_for_reorder.tr_b_dt_sz diff --git a/src/gpu/generic/sycl/matmul_kernels.hpp b/src/gpu/generic/sycl/matmul_kernels.hpp index 38e001f08dc..c0d54b32b7e 100644 --- a/src/gpu/generic/sycl/matmul_kernels.hpp +++ b/src/gpu/generic/sycl/matmul_kernels.hpp @@ -313,13 +313,12 @@ struct matmul_kernel_fwd_t { for (int v_el = 0; v_el < vec_len; v_el++) { off_po[dim1] += row; off_po[dim1 + 1] += col * vec_len + v_el; - ::sycl::vec binary_src_vals - = kernel->post_op_src_val(off_po); + data[row][col][v_el] + = post_ops.apply(data[row][col][v_el], + prev_dst.data[row][col][v_el], + kernel->po_args_, off_po); off_po[dim1] -= row; off_po[dim1 + 1] -= col * vec_len + v_el; - data[row][col][v_el] = post_ops.apply( - data[row][col][v_el], - prev_dst.data[row][col][v_el], binary_src_vals); } } } @@ -378,16 +377,7 @@ struct matmul_kernel_fwd_t { , dropout_seed_(CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_ATTR_DROPOUT_SEED)) , dropout_probability_( CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_ATTR_DROPOUT_PROBABILITY)) - , po1_src_(CTX_IN_SYCL_KERNEL_MEMORY( - (DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1))) - , po2_src_(CTX_IN_SYCL_KERNEL_MEMORY( - (DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1))) - , po3_src_(CTX_IN_SYCL_KERNEL_MEMORY( - (DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | DNNL_ARG_SRC_1))) - , po4_src_(CTX_IN_SYCL_KERNEL_MEMORY( - (DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1))) - , po5_src_(CTX_IN_SYCL_KERNEL_MEMORY( - (DNNL_ARG_ATTR_MULTIPLE_POST_OP(4) | DNNL_ARG_SRC_1))) {} + , po_args_(cgh, ctx) {} void operator()(::sycl::nd_item<1> item) const { using data_block_t = register_block; @@ -623,77 +613,6 @@ struct matmul_kernel_fwd_t { } private: - inline ::sycl::vec post_op_src_val(dims_t data_off) const { - ::sycl::vec post_po_sr; - const auto maxPostPo = conf_.post_ops.get_post_op(); - - for (dim_t po_idx = 0; po_idx < maxPostPo; po_idx++) { - float res = 0.0f; - if (po_idx == 0) - res = get_post_op_val(po1_src_, po_idx, data_off); - else if (po_idx == 1) - res = 
get_post_op_val(po2_src_, po_idx, data_off); - else if (po_idx == 2) - res = get_post_op_val(po3_src_, po_idx, data_off); - else if (po_idx == 3) - res = get_post_op_val(po4_src_, po_idx, data_off); - else if (po_idx == 4) - res = get_post_op_val(po5_src_, po_idx, data_off); - - post_po_sr[po_idx] = res; - } - return post_po_sr; - } - - float get_post_op_val(const xpu::sycl::in_memory_arg_t &bin_src_op, - dim_t &idx, dims_t offset) const { - auto src1_desc = conf_.binary_src_arr[idx]; - - xpu::sycl::md_t::dim32_t ndims = conf_.dst_md.ndims(); - xpu::sycl::md_t::dims32_t dst_dims; - for (int i = 0; i < ndims; i++) { - dst_dims[i] = conf_.dst_md.dims()[i]; - } - if (conf_.transpose_dst) { - std::swap(dst_dims[ndims - 1], dst_dims[ndims - 2]); - } - const auto off - = get_matmul_src1_off(src1_desc, offset, dst_dims, ndims); - - auto dst = load_float_value( - src1_desc.data_type(), bin_src_op.get_pointer(), off); - return dst; - } - - dim_t get_matmul_src1_off(const xpu::sycl::md_t &src1_md, dims_t offset, - const xpu::sycl::md_t::dims32_t &dst_dims, - const xpu::sycl::md_t::dim32_t &dst_ndims) const { - const dim_t mask_matmul_po - = get_dims_mask(dst_dims, src1_md.dims(), dst_ndims); - return get_po_tensor_off( - src1_md, offset, dst_dims, dst_ndims, mask_matmul_po); - } - - inline dim_t get_dims_mask(const xpu::sycl::md_t::dims32_t &dims1, - const xpu::sycl::md_t::dims32_t &dims2, const dim_t &ndims, - bool skip_dim_of_one = false) const { - dim_t mask = 0; - for (dim_t d = 0; d < ndims; ++d) { - // Disable mask_bit for dimensions of `1` by request. - dim_t mask_bit = skip_dim_of_one && dims1[d] == 1 ? 0 : (1 << d); - mask += dims1[d] == dims2[d] ? mask_bit : 0; - } - return mask; - } - - inline dim_t get_po_tensor_off(const xpu::sycl::md_t &tensor_md, - dims_t offset, const xpu::sycl::md_t::dims32_t &dst_dims, - const dim_t &dst_ndims, const dim_t &mask) const { - dims_t offset_po {}; - utils::copy_dims_with_mask(offset_po, offset, dst_ndims, mask); - return tensor_md.off_v(offset_po); - } - sycl_matmul_conf_t conf_; xpu::sycl::in_memory_arg_t data_; @@ -715,11 +634,7 @@ struct matmul_kernel_fwd_t { xpu::sycl::out_memory_arg_t dropout_mask_; xpu::sycl::in_memory_arg_t dropout_seed_; xpu::sycl::in_memory_arg_t dropout_probability_; - xpu::sycl::in_memory_arg_t po1_src_; - xpu::sycl::in_memory_arg_t po2_src_; - xpu::sycl::in_memory_arg_t po3_src_; - xpu::sycl::in_memory_arg_t po4_src_; - xpu::sycl::in_memory_arg_t po5_src_; + post_op_input_args po_args_; }; } // namespace sycl diff --git a/src/gpu/generic/sycl/pooling_kernels.hpp b/src/gpu/generic/sycl/pooling_kernels.hpp index 4d6ec67b075..6b515aa6537 100644 --- a/src/gpu/generic/sycl/pooling_kernels.hpp +++ b/src/gpu/generic/sycl/pooling_kernels.hpp @@ -114,7 +114,7 @@ struct pooling_fwd_kernel_vec_t { return 0; } float data_conv() const { - switch (src_md().data_type()) { + switch (dst_md().data_type()) { case data_type::bf16: return (float) std::numeric_limits::lowest(); @@ -255,49 +255,71 @@ struct pooling_bwd_kernel_vec_t { memory_tensor_t diff_dst_mem(diff_dst_, conf_.diff_dst_md); size_t ithr = item.get_group(0) * conf_.wg_size + item.get_local_id(); - - dim_t ow_start = max(dim_t(0), - math::div_up( - conf_.padL - ((conf_.KW - 1) * conf_.DW + conf_.KW) + 1, - conf_.SW)); - dim_t ow_end - = min(conf_.OW, 1 + (conf_.padL + conf_.IW - 1) / conf_.SW); - - dim_t oh_start = max(dim_t(0), - math::div_up( - conf_.padT - ((conf_.KH - 1) * conf_.DH + conf_.KH) + 1, - conf_.SH)); - dim_t oh_end - = min(conf_.OH, 1 + (conf_.padT + conf_.IH - 1) 
/ conf_.SH); - - dim_t od_start = max(dim_t(0), - math::div_up( - conf_.padF - ((conf_.KD - 1) * conf_.DD + conf_.KD) + 1, - conf_.SD)); - dim_t od_end - = min(conf_.OD, 1 + (conf_.padF + conf_.ID - 1) / conf_.SD); + int denom = 1; const bool is_max_pool = conf_.alg == alg_kind::pooling_max; - dim_t MB = conf_.MB; - dim_t OC = conf_.OC; - const dim_t work_amount = MB * OC; + const bool is_avg_incl_pad + = conf_.alg == alg_kind::pooling_avg_include_padding; + const bool is_avg_excl_pad + = conf_.alg == alg_kind::pooling_avg_exclude_padding; + if (is_avg_incl_pad) denom = conf_.KW * conf_.KH * conf_.KD; + + const dim_t work_amount + = conf_.MB * conf_.OC * conf_.ID * conf_.IH * conf_.IW; if (work_amount == 0) return; dim_t start {0}, end {0}; balance211(work_amount, conf_.n_thr, ithr, start, end); - dim_t mb {0}, oc {0}; - utils::nd_iterator_init(start, mb, MB, oc, OC); + dim_t mb {0}, oc {0}, id {0}, ih {0}, iw {0}; + utils::nd_iterator_init(start, mb, conf_.MB, oc, conf_.OC, id, conf_.ID, + ih, conf_.IH, iw, conf_.IW); for (dim_t iwork = start; iwork < end; ++iwork) { - ker_zero(diff_src_mem, mb, oc); - for_(dim_t od = od_start; od < od_end; ++od) - for_(dim_t oh = oh_start; oh < oh_end; ++oh) - for (dim_t ow = ow_start; ow < ow_end; ++ow) { - if (is_max_pool) { - ker_max(diff_src_mem, diff_dst_mem, mb, oc, od, oh, ow); - } else { - ker_avg(diff_src_mem, diff_dst_mem, mb, oc, od, oh, ow); + float s = 0; + for (dim_t kd = 0; kd < conf_.KD; ++kd) { + dim_t _od = id + conf_.padF - kd * (conf_.DD + 1); + if (_od % conf_.SD != 0) continue; + dim_t od = _od / conf_.SD; + if (od < 0 || od >= conf_.OD) continue; + + for (dim_t kh = 0; kh < conf_.KH; ++kh) { + dim_t _oh = ih + conf_.padT - kh * (conf_.DH + 1); + if (_oh % conf_.SH != 0) continue; + dim_t oh = _oh / conf_.SH; + if (oh < 0 || oh >= conf_.OH) continue; + + for (dim_t kw = 0; kw < conf_.KW; ++kw) { + dim_t _ow = iw + conf_.padL - kw * (conf_.DW + 1); + if (_ow % conf_.SW != 0) continue; + dim_t ow = _ow / conf_.SW; + if (ow < 0 || ow >= conf_.OW) continue; + + const auto dst_off + = get_offset(diff_dst_md(), mb, oc, od, oh, ow); + if (is_max_pool) { + memory_tensor_t ws_mem(ws_, conf_.ws_md); + const int index = ws_mem.load(dst_off); + + const dim_t hw = index % (conf_.KW * conf_.KH); + const dim_t w_kd = index / (conf_.KW * conf_.KH); + const dim_t w_kw = hw % conf_.KW; + const dim_t w_kh = hw / conf_.KW; + if (w_kd != kd || w_kh != kh || w_kw != kw) + continue; + } + if (is_avg_excl_pad) { + ker_avg_excl_pad(od, oh, ow, diff_dst_mem, dst_off, + denom, s); + } + if (is_max_pool || is_avg_incl_pad) + s = s + diff_dst_mem.load(dst_off); + } } } - utils::nd_iterator_step(mb, MB, oc, OC); + const auto diff_src_offset + = get_offset(diff_src_md(), mb, oc, id, ih, iw); + if (is_max_pool || is_avg_incl_pad) s = s / denom; + diff_src_mem.store(s, diff_src_offset); + utils::nd_iterator_step(mb, conf_.MB, oc, conf_.OC, id, conf_.ID, + ih, conf_.IH, iw, conf_.IW); } } @@ -317,104 +339,39 @@ struct pooling_bwd_kernel_vec_t { return 0; } - void ker_zero(out_memory_tensor_t &diff_src_mem, dim_t mb, dim_t oc) const { - for_(dim_t id = 0; id < conf_.ID; ++id) - for_(dim_t ih = 0; ih < conf_.IH; ++ih) - for (dim_t iw = 0; iw < conf_.IW; ++iw) { - const auto off = get_offset(diff_src_md(), mb, oc, id, ih, iw); - diff_src_mem.store(0, off); - } - } - - void ker_max(out_memory_tensor_t &diff_src_mem, - const in_memory_tensor_t &diff_dst_mem, dim_t mb, dim_t oc, - dim_t od, dim_t oh, dim_t ow) const { - memory_tensor_t ws_mem(ws_, conf_.ws_md); - const 
auto ws_off = get_offset(ws_md(), mb, oc, od, oh, ow); - - const int index = ws_mem.load(ws_off); - const dim_t kd = (index / conf_.KW) / conf_.KH; - const dim_t kh = (index / conf_.KW) % conf_.KH; - const dim_t kw = index % conf_.KW; - const dim_t id = od * conf_.SD - conf_.padF + kd * (conf_.DD + 1); - const dim_t ih = oh * conf_.SH - conf_.padT + kh * (conf_.DH + 1); - const dim_t iw = ow * conf_.SW - conf_.padL + kw * (conf_.DW + 1); - if (id < 0 || id >= conf_.ID) return; - if (ih < 0 || ih >= conf_.IH) return; - if (iw < 0 || iw >= conf_.IW) return; - - const auto d_src_off = get_offset(diff_src_md(), mb, oc, id, ih, iw); - const auto d_dst_off = get_offset(diff_dst_md(), mb, oc, od, oh, ow); - float v_src = diff_src_mem.load(d_src_off); - float v_dst = diff_dst_mem.load(d_dst_off); - v_src += v_dst; - diff_src_mem.store(v_src, d_src_off); - } - - void ker_avg(out_memory_tensor_t &diff_src_mem, - const in_memory_tensor_t &diff_dst_mem, dim_t mb, dim_t oc, - dim_t od, dim_t oh, dim_t ow) const { - int num_summands; - if (conf_.alg == alg_kind::pooling_avg_include_padding) - num_summands = conf_.KW * conf_.KH * conf_.KD; - else { - auto id_start = od * conf_.SD - conf_.padF; - auto ih_start = oh * conf_.SH - conf_.padT; - auto iw_start = ow * conf_.SW - conf_.padL; - auto id_end = od * conf_.SD - conf_.padF + (conf_.KD - 1) * conf_.DD - + conf_.KD; - auto ih_end = oh * conf_.SH - conf_.padT + (conf_.KH - 1) * conf_.DH - + conf_.KH; - auto iw_end = ow * conf_.SW - conf_.padL + (conf_.KW - 1) * conf_.DW - + conf_.KW; - - auto id_start_excluded = id_start < 0 - ? (0 - id_start - 1) / (conf_.DD + 1) + 1 - : 0; - auto ih_start_excluded = ih_start < 0 - ? (0 - ih_start - 1) / (conf_.DH + 1) + 1 - : 0; - auto iw_start_excluded = iw_start < 0 - ? (0 - iw_start - 1) / (conf_.DW + 1) + 1 - : 0; - auto id_end_excluded = id_end > conf_.ID - ? (id_end - conf_.ID - 1) / (conf_.DD + 1) + 1 - : 0; - auto ih_end_excluded = ih_end > conf_.IH - ? (ih_end - conf_.IH - 1) / (conf_.DH + 1) + 1 - : 0; - auto iw_end_excluded = iw_end > conf_.IW - ? 
(iw_end - conf_.IW - 1) / (conf_.DW + 1) + 1 - : 0; - - num_summands = (conf_.KD - id_start_excluded - id_end_excluded) - * (conf_.KH - ih_start_excluded - ih_end_excluded) - * (conf_.KW - iw_start_excluded - iw_end_excluded); - } - for (dim_t kd = 0; kd < conf_.KD; ++kd) { - const dim_t id = od * conf_.SD - conf_.padF + kd * (conf_.DD + 1); - if (id < 0 || id >= conf_.ID) continue; - for (dim_t kh = 0; kh < conf_.KH; ++kh) { - const dim_t ih - = oh * conf_.SH - conf_.padT + kh * (conf_.DH + 1); - if (ih < 0 || ih >= conf_.IH) continue; - for (dim_t kw = 0; kw < conf_.KW; ++kw) { - const dim_t iw - = ow * conf_.SW - conf_.padL + kw * (conf_.DW + 1); - if (iw < 0 || iw >= conf_.IW) continue; - - const auto d_src_off - = get_offset(diff_src_md(), mb, oc, id, ih, iw); - const auto d_dst_off - = get_offset(diff_dst_md(), mb, oc, od, oh, ow); - float v_src = diff_src_mem.load(d_src_off); - ; - float v_dst = diff_dst_mem.load(d_dst_off); - v_src += v_dst / num_summands; - diff_src_mem.store(v_src, d_src_off); - } - } - } + void ker_avg_excl_pad(dim_t od, dim_t oh, dim_t ow, + const in_memory_tensor_t &diff_dst_mem, dim_t dst_off, int &denom, + float &s) const { + const auto id_start = od * conf_.SD - conf_.padF; + const auto ih_start = oh * conf_.SH - conf_.padT; + const auto iw_start = ow * conf_.SW - conf_.padL; + const auto id_end = od * conf_.SD - conf_.padF + + (conf_.KD - 1) * conf_.DD + conf_.KD; + const auto ih_end = oh * conf_.SH - conf_.padT + + (conf_.KH - 1) * conf_.DH + conf_.KH; + const auto iw_end = ow * conf_.SW - conf_.padL + + (conf_.KW - 1) * conf_.DW + conf_.KW; + + const auto id_start_excluded + = id_start < 0 ? (0 - id_start - 1) / (conf_.DD + 1) + 1 : 0; + const auto ih_start_excluded + = ih_start < 0 ? (0 - ih_start - 1) / (conf_.DH + 1) + 1 : 0; + const auto iw_start_excluded + = iw_start < 0 ? (0 - iw_start - 1) / (conf_.DW + 1) + 1 : 0; + const auto id_end_excluded = id_end > conf_.ID + ? (id_end - conf_.ID - 1) / (conf_.DD + 1) + 1 + : 0; + const auto ih_end_excluded = ih_end > conf_.IH + ? (ih_end - conf_.IH - 1) / (conf_.DH + 1) + 1 + : 0; + const auto iw_end_excluded = iw_end > conf_.IW + ? (iw_end - conf_.IW - 1) / (conf_.DW + 1) + 1 + : 0; + + denom = (conf_.KD - id_start_excluded - id_end_excluded) + * (conf_.KH - ih_start_excluded - ih_end_excluded) + * (conf_.KW - iw_start_excluded - iw_end_excluded); + s = s + (diff_dst_mem.load(dst_off) / denom); } sycl_pooling_bwd_conf_t conf_; diff --git a/src/gpu/generic/sycl/ref_binary.hpp b/src/gpu/generic/sycl/ref_binary.hpp index f83b75a8543..3a7bac750f8 100644 --- a/src/gpu/generic/sycl/ref_binary.hpp +++ b/src/gpu/generic/sycl/ref_binary.hpp @@ -56,7 +56,8 @@ struct ref_binary_t : public gpu::generic::sycl::primitive_t { sm::scales_runtime | sm::post_ops) && IMPLICATION(!attr()->scales_.has_default_values(), check_scales_mask()) - && post_ops_ok() && md_dims_in_range(src_md(0)) + && sycl_post_ops_t::post_ops_ok(attr()) + && md_dims_in_range(src_md(0)) && md_dims_in_range(src_md(1)) && md_dims_in_range(dst_md()); if (!ok) return status::unimplemented; @@ -75,15 +76,6 @@ struct ref_binary_t : public gpu::generic::sycl::primitive_t { return attr_scales_ok(supported_args); } - bool post_ops_ok() const { - // Dw conv post-ops are not supported. 
- return attr()->post_ops_.len() <= sycl_post_ops_t::max_post_ops - && attr()->post_ops_.has_default_values( - {primitive_kind::eltwise, primitive_kind::binary, - primitive_kind::prelu, - primitive_kind::sum}); - } - static bool check_data_types(const memory_desc_wrapper &src0, const memory_desc_wrapper &src1, const memory_desc_wrapper &dst) { diff --git a/src/gpu/generic/sycl/ref_convolution.hpp b/src/gpu/generic/sycl/ref_convolution.hpp index 9a003258646..b70f05c331e 100644 --- a/src/gpu/generic/sycl/ref_convolution.hpp +++ b/src/gpu/generic/sycl/ref_convolution.hpp @@ -93,7 +93,7 @@ struct ref_convolution_fwd_t : public gpu::generic::sycl::primitive_t { | sm::sum_dt) && IMPLICATION(!attr()->scales_.has_default_values(), attr_scales_ok()) - && post_ops_ok(); + && sycl_post_ops_t::post_ops_ok(attr(), false); if (!ok) return status::unimplemented; return init_conf(); @@ -112,24 +112,6 @@ struct ref_convolution_fwd_t : public gpu::generic::sycl::primitive_t { : utils::pick(ndims() - 3, oiw, oihw, oidhw); return set_default_formats_common(dat_tag, wei_tag, dat_tag); } - - bool post_ops_ok() const { - for (int i = 0; i < attr()->post_ops_.len(); i++) { - const auto &e = attr()->post_ops_.entry_[i]; - if (!IMPLICATION(e.is_eltwise(), - utils::one_of(e.eltwise.alg, alg_kind::eltwise_relu, - alg_kind::eltwise_linear, - alg_kind::eltwise_clip, - alg_kind::eltwise_clip_v2, - alg_kind::eltwise_hardswish))) { - return false; - } - } - return attr()->post_ops_.len() <= sycl_post_ops_t::max_post_ops - && attr()->post_ops_.has_default_values( - {primitive_kind::eltwise, primitive_kind::prelu, - primitive_kind::sum}); - } }; status_t init(impl::engine_t *engine) override; diff --git a/src/gpu/generic/sycl/ref_eltwise.hpp b/src/gpu/generic/sycl/ref_eltwise.hpp index 52bcec6e86e..3b854697c96 100644 --- a/src/gpu/generic/sycl/ref_eltwise.hpp +++ b/src/gpu/generic/sycl/ref_eltwise.hpp @@ -51,7 +51,8 @@ struct ref_sycl_eltwise_fwd_t : public gpu::generic::sycl::primitive_t { && attr()->has_default_values(sm::post_ops) && set_default_formats_common() && src_d == dst_d && attr_.set_default_formats(dst_md(0)) == status::success - && post_ops_ok() && md_dims_in_range(src_md()); + && sycl_post_ops_t::post_ops_ok(attr()) + && md_dims_in_range(src_md()); if (!ok) return status::unimplemented; return init_conf(); @@ -74,27 +75,6 @@ struct ref_sycl_eltwise_fwd_t : public gpu::generic::sycl::primitive_t { return true; } - - bool post_ops_ok() const { - for (int i = 0; i < attr()->post_ops_.len(); i++) { - const auto &e = attr()->post_ops_.entry_[i]; - if (!IMPLICATION(e.is_binary(), - utils::one_of(e.binary.alg, alg_kind::binary_add, - alg_kind::binary_div, alg_kind::binary_mul, - alg_kind::binary_sub, alg_kind::binary_max, - alg_kind::binary_min, alg_kind::binary_ge, - alg_kind::binary_gt, alg_kind::binary_le, - alg_kind::binary_lt, alg_kind::binary_eq, - alg_kind::binary_ne))) { - - return false; - } - } - return attr()->post_ops_.len() <= sycl_post_ops_t::max_post_ops - && attr()->post_ops_.has_default_values( - {primitive_kind::eltwise, primitive_kind::binary, - primitive_kind::sum}); - } }; status_t init(impl::engine_t *engine) override; diff --git a/src/gpu/generic/sycl/ref_reorder.hpp b/src/gpu/generic/sycl/ref_reorder.hpp index 4b5c4d858cb..620faf5fe96 100644 --- a/src/gpu/generic/sycl/ref_reorder.hpp +++ b/src/gpu/generic/sycl/ref_reorder.hpp @@ -54,7 +54,8 @@ struct ref_reorder_t : public gpu::generic::sycl::primitive_t { && check_formats(src_d, dst_d) && attr()->has_default_values( sm::scales_runtime | 
sm::post_ops) - && post_ops_ok() && md_dims_in_range(dst_md()); + && sycl_post_ops_t::post_ops_ok(attr()) + && md_dims_in_range(dst_md()); if (!ok) return status::unimplemented; return init_conf(); @@ -67,15 +68,6 @@ struct ref_reorder_t : public gpu::generic::sycl::primitive_t { status_t init_conf(); - bool post_ops_ok() const { - for (int i = 0; i < attr()->post_ops_.len(); i++) { - if (!attr()->post_ops_.entry_[i].is_sum()) { return false; } - } - return attr()->post_ops_.len() <= sycl_post_ops_t::max_post_ops - && attr()->post_ops_.has_default_values( - {primitive_kind::sum}); - } - static bool check_data_types(const memory_desc_wrapper &src, const memory_desc_wrapper &dst) { using namespace data_type; diff --git a/src/gpu/generic/sycl/ref_softmax.hpp b/src/gpu/generic/sycl/ref_softmax.hpp index 36b02aed48d..17464aba5d7 100644 --- a/src/gpu/generic/sycl/ref_softmax.hpp +++ b/src/gpu/generic/sycl/ref_softmax.hpp @@ -44,7 +44,8 @@ struct ref_sycl_softmax_fwd_t : public gpu::generic::sycl::primitive_t { && (src_md(0)->format_desc.blocking.inner_nblks == 0) && attr()->has_default_values( sm::scales_runtime | sm::post_ops) - && attr_oscale_ok() && post_ops_ok() + && attr_oscale_ok() + && sycl_post_ops_t::post_ops_ok(attr(), true, false) && set_default_formats() == status::success && attr_.set_default_formats(dst_md()) == status::success && md_dims_in_range(src_md()); @@ -65,12 +66,6 @@ struct ref_sycl_softmax_fwd_t : public gpu::generic::sycl::primitive_t { return ok; } - bool post_ops_ok() const { - return attr()->post_ops_.len() <= sycl_post_ops_t::max_post_ops - && attr()->post_ops_.has_default_values( - {primitive_kind::eltwise, primitive_kind::binary}); - } - bool check_data_types(data_type_t src) { return utils::one_of(src, data_type::f32, data_type::bf16, data_type::f16, data_type::s8, data_type::u8); diff --git a/src/gpu/generic/sycl/ref_sum.hpp b/src/gpu/generic/sycl/ref_sum.hpp index f25d87a11f2..18ee5afd5ee 100644 --- a/src/gpu/generic/sycl/ref_sum.hpp +++ b/src/gpu/generic/sycl/ref_sum.hpp @@ -21,7 +21,6 @@ #include "common/stream.hpp" #include "gpu/generic/sycl/sycl_gpu_primitive.hpp" #include "gpu/generic/sycl/sycl_io_helper.hpp" -#include "gpu/generic/sycl/sycl_post_ops.hpp" #include "gpu/generic/sycl/sycl_primitive_conf.hpp" #include "gpu/generic/sycl/sycl_q10n.hpp" #include "gpu/gpu_sum_pd.hpp" diff --git a/src/gpu/generic/sycl/sycl_post_ops.hpp b/src/gpu/generic/sycl/sycl_post_ops.hpp index ca32f775cf1..2de1d1d1588 100644 --- a/src/gpu/generic/sycl/sycl_post_ops.hpp +++ b/src/gpu/generic/sycl/sycl_post_ops.hpp @@ -31,14 +31,21 @@ namespace generic { namespace sycl { struct ref_eltwise_fwd_t { - ref_eltwise_fwd_t() = default; - ref_eltwise_fwd_t(alg_kind_t alg, float alpha, float beta, float scale) - : alg_(alg), alpha_(alpha), beta_(beta), scale_(scale) { + static bool eltwise_ok(alg_kind_t alg) { using namespace alg_kind; - assert(utils::one_of(alg_, eltwise_relu, eltwise_linear, eltwise_clip, + return utils::one_of(alg, eltwise_relu, eltwise_linear, eltwise_clip, eltwise_clip_v2, eltwise_hardswish, eltwise_gelu_tanh, eltwise_gelu_erf, eltwise_tanh, eltwise_logistic, eltwise_swish, - eltwise_elu)); + eltwise_elu); + } + static bool eltwise_ok(const post_ops_t::entry_t::eltwise_t &eltwise) { + return eltwise_ok(eltwise.alg); + } + + ref_eltwise_fwd_t() = default; + ref_eltwise_fwd_t(alg_kind_t alg, float alpha, float beta, float scale) + : alg_(alg), alpha_(alpha), beta_(beta), scale_(scale) { + assert(eltwise_ok(alg)); } ref_eltwise_fwd_t(const 
post_ops_t::entry_t::eltwise_t &eltwise) @@ -106,13 +113,20 @@ struct ref_eltwise_fwd_t { }; struct ref_binary_op_t { + static bool binary_ok(alg_kind_t alg) { + using namespace alg_kind; + return utils::one_of(alg, binary_add, binary_div, binary_max, + binary_min, binary_mul, binary_sub, binary_ge, binary_gt, + binary_le, binary_lt, binary_eq, binary_ne); + } + static bool binary_ok(const post_ops_t::entry_t::binary_t &binary) { + return binary_ok(binary.alg); + } + ref_binary_op_t() = default; ref_binary_op_t(alg_kind_t alg, xpu::sycl::md_t src_md) : alg_(alg), src_md_(src_md) { - using namespace alg_kind; - assert(utils::one_of(alg_, binary_add, binary_div, binary_max, - binary_min, binary_mul, binary_sub, binary_ge, binary_gt, - binary_le, binary_lt, binary_eq, binary_ne)); + assert(binary_ok(alg)); } ref_binary_op_t(const post_ops_t::entry_t::binary_t &binary) @@ -202,6 +216,28 @@ struct sycl_post_ops_t { // the number of post ops. static constexpr int max_post_ops = 5; + static bool post_ops_ok(const primitive_attr_t *attr, + bool allow_inputs = true, bool allow_sum = true) { + using namespace primitive_kind; + const auto &attr_po = attr->post_ops_; + if (attr_po.len() > max_post_ops) { return false; } + for (auto i = 0; i < attr_po.len(); ++i) { + if (allow_sum && attr_po.contain(sum, i)) { + } else if (attr_po.contain(eltwise, i)) { + if (!ref_eltwise_fwd_t::eltwise_ok(attr_po.entry_[i].eltwise)) { + return false; + } + } else if (allow_inputs && attr_po.contain(binary, i)) { + if (!ref_binary_op_t::binary_ok(attr_po.entry_[i].binary)) { + return false; + } + } else { + return false; + } + } + return true; + } + sycl_post_ops_t() = default; sycl_post_ops_t(const primitive_attr_t *attr, dnnl::impl::data_type_t dst_dt = dnnl_data_type_undef) { @@ -232,6 +268,8 @@ struct sycl_post_ops_t { inline float apply(float acc, const xpu::sycl::out_memory_arg_t &dst, dim_t dst_offset, const post_op_input_args &po_args, dims_t src_offset) const; + inline float apply(float acc, float dst, const post_op_input_args &po_args, + dims_t src_offset) const; inline float apply(float acc, const post_op_input_args &po_args, dims_t src_offset) const; inline float apply(float acc, const xpu::sycl::out_memory_arg_t &dst, @@ -295,6 +333,24 @@ float sycl_post_ops_t::apply(float acc, const xpu::sycl::out_memory_arg_t &dst, return acc; } +float sycl_post_ops_t::apply(float acc, float dst, + const post_op_input_args &po_args, dims_t src_offset) const { + using namespace primitive_kind; + + for (auto i = 0; i < n_post_ops_; ++i) { + switch (ops_[i].kind_) { + case eltwise: acc = ops_[i].eltwise_.compute(acc); break; + case binary: + acc = ops_[i].binary_.load_and_compute( + acc, po_args.args_[i], src_offset); + break; + case sum: acc = ops_[i].sum_.compute(acc, dst); break; + default: acc = ::sycl::nan(0u); + } + } + return acc; +} + float sycl_post_ops_t::apply( float acc, const post_op_input_args &po_args, dims_t src_offset) const { using namespace primitive_kind; diff --git a/src/gpu/intel/compute/compute_stream.hpp b/src/gpu/intel/compute/compute_stream.hpp index 574c0379858..08841d3070e 100644 --- a/src/gpu/intel/compute/compute_stream.hpp +++ b/src/gpu/intel/compute/compute_stream.hpp @@ -36,6 +36,10 @@ class compute_stream_t : public gpu::stream_t { status_t notify_profiling_complete() const override; + virtual status_t barrier() = 0; + virtual status_t enter_immediate_mode() { return status::success; } + virtual status_t exit_immediate_mode() { return status::success; } + protected: bool 
has_zero_pad_primitive() const { return engine()->kind() == dnnl_gpu; diff --git a/src/gpu/intel/compute/dispatch.cpp b/src/gpu/intel/compute/dispatch.cpp index 70a9211a197..8b043ee10cc 100644 --- a/src/gpu/intel/compute/dispatch.cpp +++ b/src/gpu/intel/compute/dispatch.cpp @@ -131,7 +131,7 @@ void dispatch_t::define_dim_with_nesting_level( di.nesting_level = nesting_level; di.vector_size = 1; di.gws_index = -1; - dims_[ndims_] = di; + dims_[ndims_] = std::move(di); ++ndims_; } @@ -311,7 +311,7 @@ void dispatch_t::generate(bool generate_lws) { for (int i = vec_dim_idx - 1; i >= group_beg; --i) { dims_[i + 1] = dims_[i]; } - dims_[group_beg] = vec_dim_info; + dims_[group_beg] = std::move(vec_dim_info); } } diff --git a/src/gpu/intel/compute/kernel_ctx.hpp b/src/gpu/intel/compute/kernel_ctx.hpp index c4a1c1e3b8d..d20003b2f4b 100644 --- a/src/gpu/intel/compute/kernel_ctx.hpp +++ b/src/gpu/intel/compute/kernel_ctx.hpp @@ -124,7 +124,7 @@ class kernel_ctx_t { void add_custom_header( const std::string &header_name, std::string &&source) { - custom_headers_[header_name] = source; + custom_headers_[header_name] = std::move(source); } const char *get_custom_header(const std::string &header_name) const { diff --git a/src/gpu/intel/compute/zero_pool.cpp b/src/gpu/intel/compute/zero_pool.cpp index e92c0e6b879..b25577f2092 100644 --- a/src/gpu/intel/compute/zero_pool.cpp +++ b/src/gpu/intel/compute/zero_pool.cpp @@ -19,6 +19,10 @@ #include "gpu/intel/compute/zero_pool.hpp" +#ifdef DNNL_WITH_SYCL +#include "gpu/intel/sycl/stream.hpp" +#endif + namespace dnnl { namespace impl { namespace gpu { @@ -27,8 +31,26 @@ namespace intel { static std::unordered_map zero_pool_cache; static std::mutex zero_pool_cache_mutex; +#ifdef DNNL_WITH_SYCL +// Unfortunately, weak_ptrs cannot be hashed, so unordered_map not possible here. +// SYCL is currently missing owner_less for command graphs, so define it ourselves. +struct weak_graph_owner_less { + bool operator()(const sycl::stream_t::weak_graph_t &lhs, + const sycl::stream_t::weak_graph_t &rhs) const noexcept { + return lhs.owner_before(rhs); + } +}; +static std::map + recorded_zero_pool_cache; +#endif + struct cleanup_sentinel_t { cleanup_sentinel_t(bool *ptr) : ptr_(ptr) {} + cleanup_sentinel_t(const cleanup_sentinel_t &) = delete; + cleanup_sentinel_t(cleanup_sentinel_t &&other) = delete; + cleanup_sentinel_t &operator=(const cleanup_sentinel_t &) = delete; + cleanup_sentinel_t &operator=(cleanup_sentinel_t &&other) = delete; ~cleanup_sentinel_t() { *ptr_ = true; } private: @@ -42,11 +64,39 @@ static bool in_cleanup() { return destroyed; } -status_t lookup_zero_pool(compute::compute_engine_t *engine, size_t chunk_size, +status_t lookup_zero_pool(compute::compute_engine_t *engine, + compute::compute_stream_t *stream, size_t chunk_size, zero_pool_t **out_pool) { status_t status = status::success; (void)in_cleanup(); + +#ifdef DNNL_WITH_SYCL + // If recording, get a per-graph zero pool. + const auto *sycl_stream + = utils::downcast(stream); + if (sycl_stream->recording()) { + { + std::lock_guard lock(zero_pool_cache_mutex); + auto &pool = recorded_zero_pool_cache + [sycl_stream->get_current_graph_weak()]; + if (!pool) { + pool = new zero_pool_t(engine, chunk_size, true, + stream->flags() & stream_flags::in_order); + status = pool->init(); + + // Short-term hack: intentionally leak pool as + // we cannot know when graph is no longer in use. 
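// (Why the leak works: zero_pool_t is client counted -- attach_client()
// increments the count and release_zero_pool() frees the pool only once it
// drops back to zero, per the `int clients_` member in zero_pool.hpp. The
// extra, never-released attach_client() therefore pins the per-graph pool
// for the process lifetime, since there is currently no hook on SYCL graph
// destruction for the final release. This is an inference from the
// surrounding code, not a statement of the full zero_pool_t contract.)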
+ pool->attach_client(); + } + *out_pool = pool; + } + (*out_pool)->attach_client(); + return status; + } +#endif + + // In regular mode, find a per-engine zero pool. auto engine_id = engine->engine_id(); { @@ -84,9 +134,14 @@ void release_zero_pool(zero_pool_t *pool) { } } -zero_pool_t::zero_pool_t( - compute::compute_engine_t *engine, size_t chunk_size, int chunk_count) - : engine_(engine), chunk_size_(chunk_size), chunk_count_(chunk_count) { +zero_pool_t::zero_pool_t(compute::compute_engine_t *engine, size_t chunk_size, + bool stream_private, bool in_order) + : engine_(engine) + , chunk_size_(chunk_size) + , stream_private_(stream_private) + , in_order_(in_order) { + + chunk_count_ = stream_private ? 1 : 16; assert(chunk_count_ <= max_chunks); } @@ -144,21 +199,28 @@ status_t zero_pool_t::claim(compute::compute_stream_t *stream, size_t len, if (!inited_) { // One-time zero initialization before first use. + // Use immediate mode to ensure zero initialization + // occurs now and is not recorded. + stream->enter_immediate_mode(); stream->fill(*mem_, 0, chunk_count_ * chunk_size_, stream->ctx().get_deps(), stream->ctx().get_deps()); + stream->exit_immediate_mode(); inited_ = true; } auto slot = next_slot_++; if (next_slot_ >= chunk_count_) next_slot_ = 0; - if (event_pending_[slot]) { + if (!stream_private_ && event_pending_[slot]) { // Rare case: another thread claimed this slot but has not yet registered a completion event for it. // No choice but to create a temporary allocation (or yield and wait for the completion event). return claim_unpooled(stream, len, out_mem); } - if (events_[slot]) { + if (stream_private_) { + // Per-stream zero pool. No event synchronization needed. + if (!in_order_) CHECK(stream->barrier()); + } else if (events_[slot]) { // Slot is claimed and has an outstanding event. Wait on it and delete it. 
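// (No events are needed in the stream_private_ case handled above: every
// user of a private pool is on one stream, so an in-order queue already
// serializes slot reuse, and for an out-of-order queue the explicit
// barrier() provides the same ordering. The event bookkeeping below is
// only required for the shared per-engine pools, whose claimants may be
// running on different streams.)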
stream->ctx().append_deps(*events_[slot]); events_[slot].reset(); @@ -175,7 +237,7 @@ void zero_pool_t::async_release(int token, const xpu::event_t &ev) { if (token >= 0) { std::lock_guard lock(mutex_); int slot = token; - events_[slot] = ev.clone(); + if (!stream_private_) events_[slot] = ev.clone(); event_pending_[slot] = false; } } diff --git a/src/gpu/intel/compute/zero_pool.hpp b/src/gpu/intel/compute/zero_pool.hpp index b7a9652d07d..66e2b42d4a1 100644 --- a/src/gpu/intel/compute/zero_pool.hpp +++ b/src/gpu/intel/compute/zero_pool.hpp @@ -32,7 +32,7 @@ namespace intel { class zero_pool_t { public: zero_pool_t(compute::compute_engine_t *engine, size_t chunk_size, - int chunk_count = 16); + bool stream_private = false, bool in_order = false); status_t init(); @@ -54,6 +54,8 @@ class zero_pool_t { std::unique_ptr mem_; size_t chunk_size_ = 0; int chunk_count_ = 0; + bool stream_private_ = false; + bool in_order_ = false; int clients_ = 0; int next_slot_ = 0; bool inited_ = false; @@ -66,7 +68,8 @@ class zero_pool_t { std::unique_ptr &out_mem); }; -status_t lookup_zero_pool(compute::compute_engine_t *engine, size_t chunk_size, +status_t lookup_zero_pool(compute::compute_engine_t *engine, + compute::compute_stream_t *stream, size_t chunk_size, zero_pool_t **out_pool); void release_zero_pool(zero_pool_t *pool); diff --git a/src/gpu/intel/jit/codegen/codegen.cpp b/src/gpu/intel/jit/codegen/codegen.cpp index f852c20e5f5..3f88cb8c35d 100644 --- a/src/gpu/intel/jit/codegen/codegen.cpp +++ b/src/gpu/intel/jit/codegen/codegen.cpp @@ -488,9 +488,7 @@ class ir_to_ngen_t : public ir_visitor_t { } else { ir_assert(dst.byte_offset() == src0.getByteOffset()) << "dst/src0 must be aligned to the same GRF offset."; - auto _src1 = src1; - auto _src2 = src2; - align_src_dst_offset(host_, scope, mod, dst, _src1, _src2); + align_src_dst_offset(host_, scope, mod, dst, src1, src2); if (hw < ngen::HW::XeLP && (ngen_is_dw(to_ngen(mad_func.dst_type)) || mad_func.dst_type == type_t::f64() @@ -501,15 +499,15 @@ class ir_to_ngen_t : public ir_visitor_t { (mad_func.exec_size * mad_func.dst_type.size()) / ngen::GRF::bytes(hw)); auto reg = tmp[0].setType(to_ngen(mad_func.dst_type)); - host_->mul(mod, reg, _src1, _src2); + host_->mul(mod, reg, src1, src2); host_->add(mod, dst, reg, src0); } else if (mad_func.dst_type == type_t::f64() - && _src1.reg_data().getHS() == 0 - && _src1.reg_data().getVS() == 0) { + && src1.reg_data().getHS() == 0 + && src1.reg_data().getVS() == 0) { // Workaround for sporadic f64 mad errors with broadcast src1 on XeHPC. 
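The workaround just flagged is possible because mad computes dst = src0 + src1 * src2 and multiplication commutes: when src1 is a broadcast scalar (zero horizontal and vertical stride), it can simply be emitted as the last multiplicand, which is exactly what the lines below do. A minimal sketch of the dispatch, with operand_t as a hypothetical stand-in for the nGEN operand types:

#include <functional>
#include <iostream>

// Hypothetical operand: hs/vs mirror the nGEN region strides.
struct operand_t {
    int hs = 1, vs = 1;
    bool is_broadcast() const { return hs == 0 && vs == 0; }
};

using emit_fn = std::function<void(const operand_t &, const operand_t &,
        const operand_t &, const operand_t &)>;

// dst = src0 + src1 * src2; swap the multiplicands when src1 is a broadcast
// scalar so it lands in the last operand position.
void emit_mad(const emit_fn &mad, const operand_t &dst, const operand_t &src0,
        const operand_t &src1, const operand_t &src2, bool is_f64) {
    if (is_f64 && src1.is_broadcast())
        mad(dst, src0, src2, src1); // a*b == b*a, result is unchanged
    else
        mad(dst, src0, src1, src2);
}

int main() {
    auto log = [](const operand_t &, const operand_t &, const operand_t &s1,
                       const operand_t &s2) {
        std::cout << "src1 broadcast: " << s1.is_broadcast()
                  << ", src2 broadcast: " << s2.is_broadcast() << "\n";
    };
    operand_t dst, s0, scalar {0, 0}, vec;
    emit_mad(log, dst, s0, scalar, vec, /*is_f64=*/true); // prints 0, 1
    return 0;
}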
- host_->mad(mod, dst, src0, _src2, _src1); + host_->mad(mod, dst, src0, src2, src1); } else { - host_->mad(mod, dst, src0, _src1, _src2); + host_->mad(mod, dst, src0, src1, src2); } } } @@ -1498,7 +1496,7 @@ class expr_evaluator_t : public ir_visitor_t { t_strided = tmp_strided.format(0, w_type, obj.elems(), w_stride); host_->emov(obj.elems(), t_strided, t); } else { - t_strided = t; + t_strided = std::move(t); } if (factor != 1) { host_->emul(obj.elems(), d, t_strided, ngen::Immediate(factor)); diff --git a/src/gpu/intel/jit/codegen/reorder.hpp b/src/gpu/intel/jit/codegen/reorder.hpp index c474c6575eb..bb76362b223 100644 --- a/src/gpu/intel/jit/codegen/reorder.hpp +++ b/src/gpu/intel/jit/codegen/reorder.hpp @@ -1308,7 +1308,7 @@ class reorder_2d_impl_t { x_stride, y_sub, y_stride); }); prev_layout = next_layout; - prev_rd = next_rd; + prev_rd = std::move(next_rd); } } diff --git a/src/gpu/intel/jit/conv/config.cpp b/src/gpu/intel/jit/conv/config.cpp index 369b0011e12..2f55822f641 100644 --- a/src/gpu/intel/jit/conv/config.cpp +++ b/src/gpu/intel/jit/conv/config.cpp @@ -30,6 +30,10 @@ #include "gpu/intel/jit/ir/tensor_config.hpp" #include "gpu/intel/jit/jit_eltwise_injector.hpp" +#define VDISPATCH_CHECK(pd, engine, cond, msg, ...) \ + VCONDCHECK(primitive, create, dispatch, convolution, (cond), \ + status::unimplemented, "%s," msg, pd->info(engine), ##__VA_ARGS__) + namespace dnnl { namespace impl { namespace gpu { @@ -132,10 +136,11 @@ bool is_small_oc(const conv_problem_t &prb) { } status_t conv_problem_t::init( - const impl::engine_t *engine, const convolution_pd_t *conv_pd) { + impl::engine_t *engine, const convolution_pd_t *conv_pd) { using namespace compute; - if (conv_pd->has_zero_dim_memory()) return status::unimplemented; + VDISPATCH_CHECK(conv_pd, engine, !conv_pd->has_zero_dim_memory(), + VERBOSE_EMPTY_TENSOR, ""); this->conv_pd = conv_pd; attr = conv_pd->attr(); @@ -205,81 +210,6 @@ status_t conv_problem_t::init( return status::success; } -bool can_reduce_to_1d(const memory_desc_t &out_md, const post_ops_t &post_ops) { - int ndims = out_md.ndims; - int sp_ndims = ndims - 2; - int non_one_sp_ndims = 0; - for (int i = ndims - sp_ndims; i < ndims; i++) { - if (out_md.dims[i] != 1) non_one_sp_ndims++; - } - if (non_one_sp_ndims == 1) return true; - for (int i = 0; i < post_ops.len(); i++) { - auto &po = post_ops.entry_[i]; - int mask = 0; - if (po.is_prelu()) { - mask = po.prelu.mask; - } else if (po.is_binary()) { - mask = utils::get_dims_mask( - out_md.dims, po.binary.src1_desc.dims, ndims); - } - // If the post-op is applied per D/H/W dimension then it cannot be - // transformed to 1D. - for (int i = ndims - sp_ndims; i < ndims; i++) { - if ((mask & (1 << i)) != 0) return false; - } - } - return true; -} - -void conv_problem_t::normalize_shape() { - bool is_1x1 = (kd * kh * kw == 1); - bool is_eq_oi = (od == id && oh == ih && ow == iw); - if (is_1x1 && is_stride1() && is_eq_oi - && can_reduce_to_1d(c_md(), conv_pd->attr()->post_ops_)) { - // Convert 3D to 1D convolution. - ir_assert(pd == 0 && ph == 0 && pw == 0); - ow = od * oh * ow; - iw = id * ih * iw; - od = id = kd = 1; - oh = ih = kh = 1; - dhw_map[0] = dhw_map[1] = dhw_map[2] = 2; - return; - } - // Propagate D -> H -> W. If the spatial dimension is not present, map it - // to the next present dimension. 
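The normalization body removed here continues below and reappears verbatim in conv/problem.cpp later in this patch (as normalize_conv_shape). Its 3D-to-1D branch rests on a simple identity: a 1x1x1-kernel, stride-1, unpadded convolution with matching input/output spatial sizes maps each spatial point independently, so D, H and W can be collapsed into a single W axis. A toy sketch of just that rule (it deliberately omits the stride, padding, and post-op mask checks the real code performs):

#include <iostream>

struct spatial_t {
    int od, oh, ow, id, ih, iw, kd, kh, kw;
};

// Collapse D/H/W into W when the convolution is pointwise: every output
// point depends on exactly one input point, so the layout of the spatial
// axes is irrelevant.
bool flatten_to_1d(spatial_t &s) {
    bool pointwise = (s.kd * s.kh * s.kw == 1);
    bool same_size = (s.od == s.id && s.oh == s.ih && s.ow == s.iw);
    if (!pointwise || !same_size) return false;
    s.ow = s.od * s.oh * s.ow;
    s.iw = s.id * s.ih * s.iw;
    s.od = s.id = s.kd = 1;
    s.oh = s.ih = s.kh = 1;
    return true;
}

int main() {
    spatial_t s {4, 8, 8, 4, 8, 8, 1, 1, 1}; // illustrative 4x8x8 shape
    if (flatten_to_1d(s)) std::cout << "1D width: " << s.ow << "\n"; // 256
    return 0;
}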
- std::vector xd = {&id, &od, &kd, &sd, &dd, &pd}; - std::vector xh = {&ih, &oh, &kh, &sh, &dh, &ph}; - std::vector xw = {&iw, &ow, &kw, &sw, &dw, &pw}; - std::vector x[3] = {xd, xh, xw}; - std::vector x_old[3]; - std::vector xdef = {1, 1, 1, 1, 0, 0}; - bool has_dim[3] = {false, false, false}; - for (int i = 0; i < 3; i++) { - x_old[i].resize(xdef.size()); - for (size_t j = 0; j < xdef.size(); j++) { - if (*x[i][j] != xdef[j]) has_dim[i] = true; - x_old[i][j] = *x[i][j]; - } - } - auto set = [](const std::vector &x, const std::vector &values) { - for (size_t i = 0; i < x.size(); i++) - *x[i] = values[i]; - }; - if (!has_dim[0] && !has_dim[1] && !has_dim[2]) has_dim[2] = true; - int sp_count = (int)has_dim[0] + (int)has_dim[1] + (int)has_dim[2]; - int shift = 3 - sp_count; - for (int i = 0, idx = 0; i < 3; i++) { - if (has_dim[i]) dhw_map[i] = shift + idx++; - set(x[i], xdef); - } - for (int i = 0; i < 3; i++) { - if (dhw_map[i] != -1) set(x[dhw_map[i]], x_old[i]); - } - if (!has_dim[2]) dhw_map[2] = 2; - if (!has_dim[1]) dhw_map[1] = dhw_map[2]; - if (!has_dim[0]) dhw_map[0] = dhw_map[1]; -} - std::string conv_problem_t::desc_str(bool print_mb) const { std::ostringstream oss; if (print_mb) oss << "mb" << mb; @@ -690,12 +620,13 @@ void init_data_tags(const conv_config_t &cfg, const memory_desc_t &src_md, // Use plain tags for user-facing activations for small-channel tensors. if (!matches_tag(src_md, src_tag) && is_small_ic_g1) - user_src_tag = (user_src_req.empty() ? "axb" : user_src_req); + user_src_tag = (user_src_req.empty() ? "axb" : std::move(user_src_req)); if (!matches_tag(dst_md, dst_tag) && is_small_oc_g1) - user_dst_tag = (user_dst_req.empty() ? "axb" : user_dst_req); + user_dst_tag = (user_dst_req.empty() ? "axb" : std::move(user_dst_req)); // Avoid reorder for small shapes - if (prb.g == 1 && prb.ic < 4 && prb.oc < 4 && prb.mb < 4 && prb.ksp == 1) { + if (!user_src_tag.empty() && !user_dst_tag.empty() && prb.g == 1 + && prb.ic < 4 && prb.oc < 4 && prb.mb < 4 && prb.ksp == 1) { src_tag = user_src_tag; dst_tag = user_dst_tag; } @@ -713,7 +644,8 @@ void init_data_tags(const conv_config_t &cfg, const memory_desc_t &src_md, if (dst_abx && !dst_matches) user_dst_tag = "abx"; } -status_t init_tensor_layouts(conv_config_t &cfg, convolution_pd_t *pd) { +status_t init_tensor_layouts( + conv_config_t &cfg, convolution_pd_t *pd, impl::engine_t *engine) { const auto &prb = cfg.prb(); // Compute layout tags and user layout tags. If a compute layout is // different from a user layout then an extra pre/post reorder will be @@ -773,12 +705,18 @@ status_t init_tensor_layouts(conv_config_t &cfg, convolution_pd_t *pd) { layout_t user_bia_layout; if (prb.with_bias) user_bia_layout = init_layout(bia_md, user_bia_tag); - if (!user_src_layout.is_strictly_equal(make_layout(src_md, user_src_tag))) - return status::unimplemented; - if (!user_dst_layout.is_strictly_equal(make_layout(dst_md, user_dst_tag))) - return status::unimplemented; - if (!user_wei_layout.is_strictly_equal(make_layout(wei_md, user_wei_tag))) - return status::unimplemented; + VDISPATCH_CHECK(pd, engine, + user_src_layout.is_strictly_equal( + make_layout(src_md, user_src_tag)), + VERBOSE_UNSUPPORTED_TAG); + VDISPATCH_CHECK(pd, engine, + user_dst_layout.is_strictly_equal( + make_layout(dst_md, user_dst_tag)), + VERBOSE_UNSUPPORTED_TAG); + VDISPATCH_CHECK(pd, engine, + user_wei_layout.is_strictly_equal( + make_layout(wei_md, user_wei_tag)), + VERBOSE_UNSUPPORTED_TAG); auto src_layout = (src_tag != user_src_tag) ? 
make_layout(src_md, src_tag) : user_src_layout; @@ -964,14 +902,16 @@ bool should_use_mad(const conv_problem_t &prb) { return prb.is_dw || small_ic_oc || grouped_small_ic_oc; } -status_t init_fma_kind(conv_config_t &cfg) { +status_t init_fma_kind( + conv_config_t &cfg, convolution_pd_t *pd, impl::engine_t *engine) { if (cfg.fma_kind_param().is_overridden()) return status::success; const auto &prb = cfg.prb(); auto fma_kind = get_supported_fma_kind( cfg.hw(), prb.a_data_type, prb.b_data_type, prb.acc_data_type); // Force mad for some cases if (should_use_mad(prb)) fma_kind = fma_kind_t::mad; - if (fma_kind == fma_kind_t::undef) return status::unimplemented; + VDISPATCH_CHECK(pd, engine, fma_kind != fma_kind_t::undef, + VERBOSE_UNSUPPORTED_DT_CFG); cfg.set_fma_kind(fma_kind); return status::success; } @@ -1074,14 +1014,15 @@ void init_bwd_d_optimize(conv_config_t &cfg) { } status_t init_pd_time_cfg(const conv_problem_t &prb, conv_config_t &cfg, - const impl::engine_t *engine, convolution_pd_t *pd, - primitive_attr_t *attr) { + impl::engine_t *engine, convolution_pd_t *pd, primitive_attr_t *attr) { hw_t hw(engine); - if (!hw_ok(hw)) return status::unimplemented; - if (!data_types_ok(prb, hw)) return status::unimplemented; - if (!post_ops_ok(prb, hw)) return status::unimplemented; - if (!zero_points_ok(prb)) return status::unimplemented; + VDISPATCH_CHECK(pd, engine, hw_ok(hw), VERBOSE_UNSUPPORTED_ISA); + VDISPATCH_CHECK(pd, engine, data_types_ok(prb, hw), VERBOSE_UNSUPPORTED_DT); + VDISPATCH_CHECK( + pd, engine, post_ops_ok(prb, hw), VERBOSE_UNSUPPORTED_POSTOP); + VDISPATCH_CHECK( + pd, engine, zero_points_ok(prb), VERBOSE_UNSUPPORTED_ZP_CFG); zero_points_config_t zp_cfg(pd); cfg.set_zp_cfg(zp_cfg); @@ -1089,14 +1030,15 @@ status_t init_pd_time_cfg(const conv_problem_t &prb, conv_config_t &cfg, cfg.set_exec_cfg(exec_config_t(hw)); cfg.maybe_override_from_env(); - CHECK(init_fma_kind(cfg)); + CHECK(init_fma_kind(cfg, pd, engine)); CHECK(init_simd(cfg)); CHECK(init_vec_size(cfg)); - CHECK(init_tensor_layouts(cfg, pd)); + CHECK(init_tensor_layouts(cfg, pd, engine)); CHECK(attr->set_default_formats(&prb.c_md())); - if (!post_op_layouts_ok(prb)) return status::unimplemented; + VDISPATCH_CHECK( + pd, engine, post_op_layouts_ok(prb), VERBOSE_UNSUPPORTED_POSTOP); init_bwd_d_optimize(cfg); @@ -1577,7 +1519,7 @@ walk_order_t compute_walk_order(const conv_config_t &cfg) { auto outer = grid_inner; outer[entry.dim] = std::min(rem_tile[entry.dim], entry.size); size_t ab_bytes = get_memory_footprint(cfg, inner, outer); - if (ab_bytes <= l3_size) grid_inner = outer; + if (ab_bytes <= l3_size) grid_inner = std::move(outer); } // Add the blocks in this order: // - Step 1. 
Add grid_inner blocks (fitting L3 cache) @@ -1813,7 +1755,7 @@ status_t init_cfg(conv_config_t &cfg, const primitive_t *prim) { auto try_cfg = cfg; auto status = try_init_cfg(try_cfg); if (status == status::success) { - cfg = try_cfg; + cfg = std::move(try_cfg); return status::success; } } diff --git a/src/gpu/intel/jit/conv/config.hpp b/src/gpu/intel/jit/conv/config.hpp index c468cddcf04..6141c353ca5 100644 --- a/src/gpu/intel/jit/conv/config.hpp +++ b/src/gpu/intel/jit/conv/config.hpp @@ -647,8 +647,7 @@ class bmnk_dim_helper_t { }; status_t init_pd_time_cfg(const conv_problem_t &prb, conv_config_t &cfg, - const impl::engine_t *engine, convolution_pd_t *pd, - primitive_attr_t *attr); + impl::engine_t *engine, convolution_pd_t *pd, primitive_attr_t *attr); status_t init_cfg(conv_config_t &cfg, const primitive_t *prim); status_t init_regs(conv_config_t &cfg); int slm_bufs_hint(const conv_problem_t &prb, int m_tg, int n_tg, diff --git a/src/gpu/intel/jit/conv/ir_builder.cpp b/src/gpu/intel/jit/conv/ir_builder.cpp index d08c5b0a41b..9aa62da1ea9 100644 --- a/src/gpu/intel/jit/conv/ir_builder.cpp +++ b/src/gpu/intel/jit/conv/ir_builder.cpp @@ -159,7 +159,7 @@ expr_t add_grid_guard( if (tg[i] == load[i]) continue; auto i_cond = (tg.idx(i) < load[i]); if (cond.is_empty()) { - cond = i_cond; + cond = std::move(i_cond); } else { cond = cond & i_cond; } @@ -651,7 +651,7 @@ void conv_ir_builder_t::build() { c_store_stmt = c_store_stmt.append(cb.c_store_stmt()); c_store_stmt = stmt_group_t::make(stmt_label_t::c_store(), c_store_stmt); - stmt_ = loop_stmt; + stmt_ = std::move(loop_stmt); stmt_ = stmt_seq_t::make(cb.zero_out_stmt(), stmt_); stmt_ = stmt_seq_t::make(stmt_, c_store_stmt); diff --git a/src/gpu/intel/jit/conv/pipeline.cpp b/src/gpu/intel/jit/conv/pipeline.cpp index 141dab885d4..da99085b95f 100644 --- a/src/gpu/intel/jit/conv/pipeline.cpp +++ b/src/gpu/intel/jit/conv/pipeline.cpp @@ -387,7 +387,7 @@ class compute_step_t { } std::reverse(new_lets.begin(), new_lets.end()); - inner_let_stmts_ = new_lets; + inner_let_stmts_ = std::move(new_lets); } template @@ -395,6 +395,7 @@ class compute_step_t { const std::vector &let_infos, bool is_preload, bool is_mul) { std::vector ret; + ret.reserve(vec.size()); for (auto &v : vec) ret.push_back(update_var(v, let_infos, is_preload, is_mul)); return ret; @@ -861,7 +862,7 @@ class sbid_manager_t { } } - entries_[old_idx] = entry_t({key, cur_time_++}); + entries_[old_idx] = entry_t({std::move(key), cur_time_++}); return ngen_proxy::SBID(old_idx); } @@ -1006,8 +1007,8 @@ struct pipeline_ctx_t { stmt_t body_; }; -pipeline_ctx_t pipeline( - int length, const loop_info_t &loop, stmt_t A_block, stmt_t B_block) { +pipeline_ctx_t pipeline(int length, const loop_info_t &loop, + const stmt_t &A_block, const stmt_t &B_block) { expr_t idx = loop.var; int bound = loop.bound(); @@ -1045,7 +1046,7 @@ class prefetch_pipeliner_t { auto &loops = loop_nest.loops(); // No loops to pipeline - if (loops.size() == 0) return root_; + if (loops.empty()) return root_; auto &loop_body = loops[0].body(); auto A_block_stmt @@ -1397,7 +1398,7 @@ class simple_slm_buffering_injector_t { g2s_store = g2s_store.append(slm_idx_update); auto s2r_mul_body = s2r_mul; - auto s2r_mul_tail = s2r_mul; + auto s2r_mul_tail = std::move(s2r_mul); auto slm_counter = slm_idx_load(2, 1); auto cond = (slm_counter >= cfg_.slm().bufs() - 1); @@ -1875,8 +1876,8 @@ class unrolling_injector_t { if (!seen_dst.insert(dst).second) continue; - auto new_call = func_call_t::make( - call.func, {dst, src0, src1, 
src2}, call.attr); + auto new_call = func_call_t::make(call.func, + {dst, std::move(src0), src1, src2}, call.attr); ret = substitute(ret, s, new_call, 1); } else if (is_func_call(s)) { auto &call = s.as(); @@ -1888,8 +1889,8 @@ class unrolling_injector_t { if (!seen_dst.insert(dst).second) continue; - auto new_call = func_call_t::make( - call.func, {dst, src0, src1, src2}, call.attr); + auto new_call = func_call_t::make(call.func, + {dst, std::move(src0), src1, src2}, call.attr); ret = substitute(ret, s, new_call, 1); } } diff --git a/src/gpu/intel/jit/conv/plan.cpp b/src/gpu/intel/jit/conv/plan.cpp index aeddbfb189a..5ee61263a21 100644 --- a/src/gpu/intel/jit/conv/plan.cpp +++ b/src/gpu/intel/jit/conv/plan.cpp @@ -102,10 +102,10 @@ static dim_tile_t create_tile(gemm_schedule_t &gemm_schedule, auto outer_name = (i == 1) ? dim_name + suffixes[i - 1] : std::string(); auto inner_name = dim_name + suffixes[i]; gemm_schedule.split(idx, dims[i], outer, inner, outer_name, inner_name); - if (has_block(i)) idxs[i] = inner; - idx = outer; + if (has_block(i)) idxs[i] = std::move(inner); + idx = std::move(outer); } - idxs[0] = idx; + idxs[0] = std::move(idx); tile.set_grid_idx(idxs[0]); tile.set_loop_idx(idxs[1]); @@ -300,8 +300,9 @@ void init_fwd(const conv_config_t &cfg_, gemm_schedule_t &gemm_schedule, gemm_schedule.tensorize(kw_tile.iter_idx()); gemm_schedule.tensorize(ic_tile.iter_idx()); - gemm_schedule.reorder({ic_tile.loop_idx(), kd, kh, kw_tile.loop_idx(), - oc_tile.tg_idx(), mb_ow_tg_idx, ic_tile.tg_idx()}); + gemm_schedule.reorder({ic_tile.loop_idx(), std::move(kd), std::move(kh), + kw_tile.loop_idx(), oc_tile.tg_idx(), std::move(mb_ow_tg_idx), + ic_tile.tg_idx()}); } void init_bwd_d(const conv_config_t &cfg_, gemm_schedule_t &gemm_schedule, @@ -482,7 +483,8 @@ void init_bwd_d(const conv_config_t &cfg_, gemm_schedule_t &gemm_schedule, switch (cfg_.bwd_d_optimize_kind()) { case bwd_d_optimize_kind_t::none: - gemm_schedule.reorder({oc_tile.loop_idx(), kd, kh, kw}); + gemm_schedule.reorder({oc_tile.loop_idx(), std::move(kd), + std::move(kh), std::move(kw)}); break; case bwd_d_optimize_kind_t::skip_strided_dhw: gemm_schedule.set_dynamic_bounds( @@ -493,7 +495,8 @@ void init_bwd_d(const conv_config_t &cfg_, gemm_schedule_t &gemm_schedule, gemm_schedule.set_dynamic_bounds( kh, (ih + prb_.ph) % prb_.sh, prb_.sh); // Put kd/kh/kw outermost to allow pipelining in oc loop. 
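The comment above, and the reorder below it, are about latency hiding: with the short kd/kh/kw reduction loops outermost, the innermost oc loop runs long uninterrupted streaks, so the load for iteration i+1 can be issued while iteration i computes. A schematic sketch of that software-pipelining shape (bounds and helpers are illustrative, not the generated IR):

#include <iostream>

int main() {
    const int KD = 1, KH = 3, KW = 3, OC_BLOCKS = 64;
    long issued = 0;
    auto issue_load = [&](int oc) { ++issued; return oc; }; // stand-in load
    auto compute = [](int) {}; // stand-in compute stage
    for (int kd = 0; kd < KD; ++kd)
        for (int kh = 0; kh < KH; ++kh)
            for (int kw = 0; kw < KW; ++kw) {
                int in_flight = issue_load(0); // software-pipeline prologue
                for (int oc = 0; oc < OC_BLOCKS; ++oc) {
                    int cur = in_flight;
                    if (oc + 1 < OC_BLOCKS)
                        in_flight = issue_load(oc + 1); // overlaps compute(cur)
                    compute(cur);
                }
            }
    std::cout << "loads issued: " << issued << "\n"; // 1 * 3 * 3 * 64 = 576
    return 0;
}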
-            gemm_schedule.reorder({kd, kh, kw, oc_tile.loop_idx()});
+            gemm_schedule.reorder({std::move(kd), std::move(kh), std::move(kw),
+                    oc_tile.loop_idx()});
             break;
         case bwd_d_optimize_kind_t::skip_out_of_bound_w:
             gemm_schedule.set_dynamic_bounds(kw,
@@ -2433,10 +2436,10 @@ class plan_builder_t {
             return plan_status_t::invalid_c_layout;
         }
 
-        plan.a_layout = a_layout;
-        plan.b_layout = b_layout;
-        plan.c_layout = c_layout;
-        plan.c_prb_layout = c_prb_layout;
+        plan.a_layout = std::move(a_layout);
+        plan.b_layout = std::move(b_layout);
+        plan.c_layout = std::move(c_layout);
+        plan.c_prb_layout = std::move(c_prb_layout);
         plan.fma_kind = fma_kind;
         plan.b_blk = b_blk;
         plan.m_blk = m_blk;
diff --git a/src/gpu/intel/jit/conv/problem.cpp b/src/gpu/intel/jit/conv/problem.cpp
index bd179698df1..05875727c46 100644
--- a/src/gpu/intel/jit/conv/problem.cpp
+++ b/src/gpu/intel/jit/conv/problem.cpp
@@ -168,6 +168,38 @@ const std::vector<prb_dim_t> &conv_padding_dims() {
     return _padding_dims;
 }
 
+bool can_reduce_to_1d(const memory_desc_t &out_md, const post_ops_t &post_ops) {
+    int ndims = out_md.ndims;
+    int sp_ndims = ndims - 2;
+    int non_one_sp_ndims = 0;
+    for (int i = ndims - sp_ndims; i < ndims; i++) {
+        if (out_md.dims[i] != 1) non_one_sp_ndims++;
+    }
+    if (non_one_sp_ndims == 1) return true;
+    for (int i = 0; i < post_ops.len(); i++) {
+        auto &po = post_ops.entry_[i];
+        int mask = 0;
+        if (po.is_prelu()) {
+            mask = po.prelu.mask;
+        } else if (po.is_binary()) {
+            mask = utils::get_dims_mask(
+                    out_md.dims, po.binary.src1_desc.dims, ndims);
+        }
+        // If the post-op is applied per D/H/W dimension then it cannot be
+        // transformed to 1D.
+        for (int i = ndims - sp_ndims; i < ndims; i++) {
+            if ((mask & (1 << i)) != 0) return false;
+        }
+    }
+    return true;
+}
+
+void conv_problem_t::normalize_shape() {
+    normalize_conv_shape(id, od, kd, sd, dd, pd, ih, oh, kh, sh, dh, ph, iw, ow,
+            kw, sw, dw, pw,
+            can_reduce_to_1d(c_md(), conv_pd->attr()->post_ops_), dhw_map);
+}
+
 const memory_desc_t &conv_problem_t::a_md() const {
     return *pick_a(conv_pd->invariant_src_md(), conv_pd->invariant_wei_md(),
             conv_pd->invariant_dst_md());
@@ -285,6 +317,60 @@ void conv_problem_t::init_transpose(const hw_t &hw) {
             = gpu_utils::dev_getenv("ab_swap_transpose", ab_swap_transpose);
 }
 
+void normalize_conv_shape(int &id, int &od, int &kd, int &sd, int &dd, int &pd,
+        int &ih, int &oh, int &kh, int &sh, int &dh, int &ph, int &iw, int &ow,
+        int &kw, int &sw, int &dw, int &pw, bool can_flatten_spatial,
+        std::array<int, 3> &dhw_map) {
+    for (int i = 0; i < 3; i++)
+        dhw_map[i] = -1;
+    bool is_1x1 = (kd * kh * kw == 1);
+    bool is_eq_oi = (od == id && oh == ih && ow == iw);
+    if (is_1x1 && sd == 1 && sh == 1 && sw == 1 && is_eq_oi
+            && can_flatten_spatial) {
+        // Convert 3D to 1D convolution.
+        ir_assert(pd == 0 && ph == 0 && pw == 0);
+        ow = od * oh * ow;
+        iw = id * ih * iw;
+        od = id = kd = 1;
+        oh = ih = kh = 1;
+        dhw_map[0] = dhw_map[1] = dhw_map[2] = 2;
+        return;
+    }
+    // Propagate D -> H -> W. If the spatial dimension is not present, map it
+    // to the next present dimension.
+    std::vector<int *> xd = {&id, &od, &kd, &sd, &dd, &pd};
+    std::vector<int *> xh = {&ih, &oh, &kh, &sh, &dh, &ph};
+    std::vector<int *> xw = {&iw, &ow, &kw, &sw, &dw, &pw};
+    std::vector<int *> x[3] = {std::move(xd), std::move(xh), std::move(xw)};
+    std::vector<int> x_old[3];
+    std::vector<int> xdef = {1, 1, 1, 1, 0, 0};
+    bool has_dim[3] = {false, false, false};
+    for (int i = 0; i < 3; i++) {
+        x_old[i].resize(xdef.size());
+        for (size_t j = 0; j < xdef.size(); j++) {
+            if (*x[i][j] != xdef[j]) has_dim[i] = true;
+            x_old[i][j] = *x[i][j];
+        }
+    }
+    auto set = [](const std::vector<int *> &x, const std::vector<int> &values) {
+        for (size_t i = 0; i < x.size(); i++)
+            *x[i] = values[i];
+    };
+    if (!has_dim[0] && !has_dim[1] && !has_dim[2]) has_dim[2] = true;
+    int sp_count = (int)has_dim[0] + (int)has_dim[1] + (int)has_dim[2];
+    int shift = 3 - sp_count;
+    for (int i = 0, idx = 0; i < 3; i++) {
+        if (has_dim[i]) dhw_map[i] = shift + idx++;
+        set(x[i], xdef);
+    }
+    for (int i = 0; i < 3; i++) {
+        if (dhw_map[i] != -1) set(x[dhw_map[i]], x_old[i]);
+    }
+    if (!has_dim[2]) dhw_map[2] = 2;
+    if (!has_dim[1]) dhw_map[1] = dhw_map[2];
+    if (!has_dim[0]) dhw_map[0] = dhw_map[1];
+}
+
 prb_dim_t to_gemm(const prb_dim_t &d, prop_kind_t prop, bool is_transpose) {
     const bool is_fwd = (prop == prop_kind::forward);
     const bool is_bwd_d = (prop == prop_kind::backward_data);
diff --git a/src/gpu/intel/jit/conv/problem.hpp b/src/gpu/intel/jit/conv/problem.hpp
index 3c0237f3366..29ae0dc625f 100644
--- a/src/gpu/intel/jit/conv/problem.hpp
+++ b/src/gpu/intel/jit/conv/problem.hpp
@@ -70,8 +70,7 @@ class conv_problem_t {
 public:
     conv_problem_t() = default;
 
-    status_t init(
-            const impl::engine_t *engine, const convolution_pd_t *conv_pd);
+    status_t init(impl::engine_t *engine, const convolution_pd_t *conv_pd);
 
     bool is_stride1() const { return sd == 1 && sh == 1 && sw == 1; }
 
@@ -200,6 +199,10 @@ class conv_problem_t {
     void init_transpose(const hw_t &hw);
 };
 
+void normalize_conv_shape(int &id, int &od, int &kd, int &sd, int &dd, int &pd,
+        int &ih, int &oh, int &kh, int &sh, int &dh, int &ph, int &iw, int &ow,
+        int &kw, int &sw, int &dw, int &pw, bool can_flatten_spatial,
+        std::array<int, 3> &dhw_map);
 bool is_small_ic(const conv_problem_t &prb);
 
 class conv_arg_helper_t {
diff --git a/src/gpu/intel/jit/conv/tiler.cpp b/src/gpu/intel/jit/conv/tiler.cpp
index ea122142d10..c48e07686bb 100644
--- a/src/gpu/intel/jit/conv/tiler.cpp
+++ b/src/gpu/intel/jit/conv/tiler.cpp
@@ -100,7 +100,7 @@ struct x2_tile_info_t {
         for (int i : factors) {
             int j = ij / i;
             if (d0.is_iter_ok(i) && d1.is_iter_ok(j)) {
-                ret.push_back(std::make_pair(i, j));
+                ret.emplace_back(i, j);
             }
         }
     }
@@ -804,7 +804,7 @@ class conv_blocking_checker_t : public blocking_checker_t {
             if (blk.thread_group().has(d))
                 blocks.emplace_back(
                         level_kind_t::thread_group, blk.thread_group_dim(d));
-            if (!layout_dim_ok(prop, tensor_kind, layout, d, blocks))
+            if (!layout_dim_ok(prop, tensor_kind, layout, d, std::move(blocks)))
                 return false;
         }
         return true;
     }
@@ -991,6 +991,7 @@ prb_dim_t select_non_blocked_iter_dim(
         const conv_config_t &cfg, const std::vector<prb_dim_t> &dims) {
     const auto shape = cfg.shape(/*pad=*/false);
     std::vector scores;
+    scores.reserve(dims.size());
     for (auto &d : dims)
         scores.push_back(get_iter_dim_score(d, cfg, shape[d]));
     auto max_it = std::max_element(scores.begin(), scores.end());
@@ -1011,6 +1012,7 @@ prb_dim_t select_iter_dim(
     if (dims.size() == 1) return dims[0];
 
     std::vector dim_blocks;
+    dim_blocks.reserve(dims.size());
     for (auto &d : dims) {
         dim_blocks.push_back(inner_block(cfg, d));
     }
@@ -1288,9
+1290,19 @@ class conv_tuner_t { params_generator_t params_gen( tune_level, simd_size, chk, level_tile_sets); - params_distance_t dist(params_gen.params_vec(), convert); - auto ret = conv2tuner_.emplace( - key, conv_tuner_t(key, ops, params_gen, dist)); + std::vector> tiles; + for (auto &p : params_gen.params_vec()) { + auto &b = p.blocking(); + std::vector p_tiles; + p_tiles.push_back(convert(b.iter())); + p_tiles.push_back(convert(b.thread_group())); + p_tiles.push_back(convert(b.loop())); + tiles.push_back(std::move(p_tiles)); + } + tile_to_vec_t tile_to_vec(tiles); + auto ret = conv2tuner_.emplace(key, + conv_tuner_t(key, ops, std::move(params_gen), + std::move(tile_to_vec))); return &ret.first->second; } @@ -1316,10 +1328,10 @@ class conv_tuner_t { private: conv_tuner_t(const conv_key_t &key, double ops, - params_generator_t params_gen, params_distance_t params_dist) + params_generator_t params_gen, tile_to_vec_t tile_vec) : key_(key) , params_gen_(std::move(params_gen)) - , params_dist_(std::move(params_dist)) + , tile_vec_(std::move(tile_vec)) , ops_(ops) { params_gen_.shuffle(conv_key_hash_t()(key_)); } @@ -1368,7 +1380,7 @@ class conv_tuner_t { auto &p = params_gen_.at(i); dists[p.id()] = std::numeric_limits::max(); for (int id : best_ids) { - float d = params_dist_.dist(id, p.id()); + float d = tile_vec_.dist(id, p.id()); dists[p.id()] = std::min(dists[p.id()], d); } } @@ -1384,7 +1396,7 @@ class conv_tuner_t { conv_key_t key_; params_generator_t params_gen_; - const params_distance_t params_dist_; + const tile_to_vec_t tile_vec_; tune_data_t tune_data_; blocking_params_t best_params_dbg_; diff --git a/src/gpu/intel/jit/conv/zp_plan.cpp b/src/gpu/intel/jit/conv/zp_plan.cpp index e454e677f4d..65d7cc0a785 100644 --- a/src/gpu/intel/jit/conv/zp_plan.cpp +++ b/src/gpu/intel/jit/conv/zp_plan.cpp @@ -674,7 +674,7 @@ class zp_comp_init_plan_t : public base_plan_t { auto _1x4_type = type_t::s32(); auto dp4a = dpas_t::make_dp4a(simd_, comp_type, wei_type, _1x4_type); auto zp_1x4 = buf_mgr.get("zp_1x4", grf_size()); - return dp4a.call({comp, comp, wei, zp_1x4}); + return dp4a.call({comp, comp, wei, std::move(zp_1x4)}); } stmt_t create_tile_wei_Xn4k_x8_zp_per_k(const expr_t &zp, const expr_t &wei, diff --git a/src/gpu/intel/jit/emulation.hpp b/src/gpu/intel/jit/emulation.hpp index ad20bb498e3..c71c6377f13 100644 --- a/src/gpu/intel/jit/emulation.hpp +++ b/src/gpu/intel/jit/emulation.hpp @@ -180,9 +180,11 @@ struct EmulationImplementation { if (isQ || isUQ) { outLo = uint32_t(static_cast(in)); + outLo = outLo.forceInt32(); outLo.setType(DataType::ud); outHi = uint32_t(static_cast(in) >> 32); + outHi = outHi.forceInt32(); outHi.setType(isQ ? 
DataType::d : DataType::ud); } else { outLo = in; @@ -588,8 +590,31 @@ struct EmulationImplementation { bool emulate64 = strategy.emulate64_mul; - if (s0Q || s1Q) { + if (s0Q) { stub(); + } else if (s1Q) { + if (!s0D || !dstQ) stub(); + auto s0Type = src0.getType(); + ngen::RegData dstLo, dstHi; + S1 s1Hi, s1Lo; + splitToDW(dst, dstLo, dstHi); + splitToDW(src1, s1Lo, s1Hi); + s1Hi = expandDW(s1Hi); + s1Lo = expandDW(s1Lo); + dstLo.setType(src0.getType()); + dstHi.setType(src0.getType()); + auto s1W0 = lowWord(s1Lo); + auto s1W2 = lowWord(s1Hi); + auto accLo + = g.acc0.retype(s0Type)[dstLo.getOffset()](dstLo.getHS()); + auto accHi + = g.acc0.retype(s0Type)[dstHi.getOffset()](dstHi.getHS()); + g.mul(mod, accHi, src0, s1W2); + g.macl(mod, dstHi, src0, s1Hi); + g.mul(mod, accLo, src0, s1W0); + g.mach(mod, dstLo, src0, s1Lo); + g.add(mod, dstHi, dstHi, dstLo); + g.mov(mod, dstLo, accLo); } else if (dstQ && s0W && s1W) { RegData dstLo, dstHi; splitToDW(dst, dstLo, dstHi); diff --git a/src/gpu/intel/jit/gemm/gen_gemm.cpp b/src/gpu/intel/jit/gemm/gen_gemm.cpp index 1bcd63c131c..1545e7ce2f1 100644 --- a/src/gpu/intel/jit/gemm/gen_gemm.cpp +++ b/src/gpu/intel/jit/gemm/gen_gemm.cpp @@ -31,17 +31,18 @@ namespace intel { namespace jit { status_t gen_gemm_t::launch_nocopy(const gemm_exec_ctx_t &ctx, - compute::compute_stream_t *compute_stream, const memory_storage_t &a, - const memory_storage_t &b, const memory_storage_t &c, - const memory_storage_t *ao, const memory_storage_t *bo, - const memory_storage_t *a_scales, const memory_storage_t *b_scales, - const memory_storage_t &co, const memory_storage_t *c_temp, - int po_count, const memory_storage_t **po_srcs, int64_t offset_a, - int64_t offset_b, int64_t offset_c, int32_t offset_aq, - int32_t offset_bq, int32_t offset_co, int32_t *offset_po_src, - int32_t lda, int32_t ldb, int32_t ldc, int32_t m, int32_t n, int32_t k, - int32_t k0, float alpha, float beta, int32_t cmask, bool last_k_block, - bool swapab, bool disable_hilbert) const { + compute::compute_stream_t *compute_stream, zero_pool_t *zero_pool, + const memory_storage_t &a, const memory_storage_t &b, + const memory_storage_t &c, const memory_storage_t *ao, + const memory_storage_t *bo, const memory_storage_t *a_scales, + const memory_storage_t *b_scales, const memory_storage_t &co, + const memory_storage_t *c_temp, int po_count, + const memory_storage_t **po_srcs, int64_t offset_a, int64_t offset_b, + int64_t offset_c, int32_t offset_aq, int32_t offset_bq, + int32_t offset_co, int32_t *offset_po_src, int32_t lda, int32_t ldb, + int32_t ldc, int32_t m, int32_t n, int32_t k, int32_t k0, float alpha, + float beta, int32_t cmask, bool last_k_block, bool swapab, + bool disable_hilbert) const { if (pd()->desc()->batch() == 0) return status::success; uint32_t flags = 0; @@ -121,7 +122,7 @@ status_t gen_gemm_t::launch_nocopy(const gemm_exec_ctx_t &ctx, std::unique_ptr zeros; int zp_token = 0; if (nocopy_info()->fusedBeta() || nocopy_info()->fusedPostOps()) { - CHECK(zero_pool_->claim( + CHECK(zero_pool->claim( compute_stream, zero_pool_bytes_, zeros, &zp_token)); arg_list.set(argn++, *zeros); } @@ -217,7 +218,7 @@ status_t gen_gemm_t::launch_nocopy(const gemm_exec_ctx_t &ctx, auto status = parallel_for(ctx, nd_range, nocopy_kernel_, arg_list); if (nocopy_info()->fusedBeta() || nocopy_info()->fusedPostOps()) - zero_pool_->async_release(zp_token, compute_stream->ctx().get_deps()); + zero_pool->async_release(zp_token, compute_stream->ctx().get_deps()); return status; } @@ -226,6 +227,17 @@ status_t 
gen_gemm_t::execute(const gemm_exec_ctx_t &ctx) const { auto *compute_stream = utils::downcast(ctx.stream()); + auto zero_pool = zero_pool_; + +#ifdef DNNL_WITH_SYCL + if (!zero_pool) { + auto *compute_engine = utils::downcast( + ctx.stream()->engine()); + CHECK(lookup_zero_pool(compute_engine, compute_stream, + zero_pool_chunk_size_, &zero_pool)); + } +#endif + const auto d = pd()->desc(); const auto &problem = *pd()->kernel_desc()->problem(); @@ -392,9 +404,9 @@ status_t gen_gemm_t::execute(const gemm_exec_ctx_t &ctx) const { if (k_parallel_global && !nocopy_info()->fusedBeta() && beta != 1.0f && (k > dim_t(k0) * pd()->kernel_desc()->aux_params()->wgK)) { - status = launch_nocopy(ctx, compute_stream, a, b, c, ao, bo, - a_scales, b_scales, *co, nullptr, po_count, po_srcs, off_a0, - off_b0, off_c0, int32_t(off_aq0), int32_t(off_bq0), + status = launch_nocopy(ctx, compute_stream, zero_pool, a, b, c, ao, + bo, a_scales, b_scales, *co, nullptr, po_count, po_srcs, + off_a0, off_b0, off_c0, int32_t(off_aq0), int32_t(off_bq0), int32_t(off_co0), po_offsets0, lda, ldb, ldc, m, n, 0, 1, 1.0f, beta, 0, false, swapab, true); if (status) return status; @@ -454,8 +466,8 @@ status_t gen_gemm_t::execute(const gemm_exec_ctx_t &ctx) const { } float eff_beta = (Bk == 0) ? beta : 1.0f; - status = launch_nocopy(ctx, compute_stream, a, b, c, ao, bo, - a_scales, b_scales, *co, c_temp.get(), po_count, + status = launch_nocopy(ctx, compute_stream, zero_pool, a, b, c, + ao, bo, a_scales, b_scales, *co, c_temp.get(), po_count, po_srcs, off_a_src, off_b_src, off_c, off_aq, off_bq, off_co, po_offsets, lda, ldb, ldc, size_m, size_n, size_k, k0, alpha, eff_beta, cmask, last_k_block, diff --git a/src/gpu/intel/jit/gemm/gen_gemm.hpp b/src/gpu/intel/jit/gemm/gen_gemm.hpp index 53a1661445b..0ed13782274 100644 --- a/src/gpu/intel/jit/gemm/gen_gemm.hpp +++ b/src/gpu/intel/jit/gemm/gen_gemm.hpp @@ -647,9 +647,6 @@ struct gen_gemm_t : public gpu_gemm_t { const auto *info = nocopy_info(); if (info->fusedBeta() || info->fusedPostOps()) { - auto *compute_engine - = utils::downcast(engine); - int zg_cl = 0; if (info->fusedBeta()) zg_cl++; if (info->fusedPostOps()) zg_cl++; @@ -657,9 +654,14 @@ struct gen_gemm_t : public gpu_gemm_t { zero_pool_bytes_ = pd()->max_k_sliced_groups() * 64 * zg_cl; auto zg_max = pd()->dev_info_->hw_threads(false); - auto zg_bytes_max = zg_max * 2 * 2 * 64; + zero_pool_chunk_size_ = zg_max * 2 * 2 * 64; - CHECK(lookup_zero_pool(compute_engine, zg_bytes_max, &zero_pool_)); +#ifndef DNNL_WITH_SYCL + auto *compute_engine + = utils::downcast(engine); + CHECK(lookup_zero_pool(compute_engine, nullptr, + zero_pool_chunk_size_, &zero_pool_)); +#endif nocopy_kernel_.save_output_events(); } @@ -671,17 +673,18 @@ struct gen_gemm_t : public gpu_gemm_t { private: status_t launch_nocopy(const gemm_exec_ctx_t &ctx, - compute::compute_stream_t *s, const memory_storage_t &a, - const memory_storage_t &b, const memory_storage_t &c, - const memory_storage_t *ao, const memory_storage_t *bo, - const memory_storage_t *a_scales, const memory_storage_t *b_scales, - const memory_storage_t &co, const memory_storage_t *c_temp, - int po_count, const memory_storage_t **po_src, int64_t offset_a, - int64_t offset_b, int64_t offset_c, int32_t offset_aq, - int32_t offset_bq, int32_t offset_co, int32_t *offset_po_src, - int32_t lda, int32_t ldb, int32_t ldc, int32_t m, int32_t n, - int32_t k, int32_t k0, float alpha, float beta, int32_t cmask, - bool last_k_block, bool swapab, bool disable_hilbert) const; + compute::compute_stream_t *s, 
zero_pool_t *zero_pool, + const memory_storage_t &a, const memory_storage_t &b, + const memory_storage_t &c, const memory_storage_t *ao, + const memory_storage_t *bo, const memory_storage_t *a_scales, + const memory_storage_t *b_scales, const memory_storage_t &co, + const memory_storage_t *c_temp, int po_count, + const memory_storage_t **po_src, int64_t offset_a, int64_t offset_b, + int64_t offset_c, int32_t offset_aq, int32_t offset_bq, + int32_t offset_co, int32_t *offset_po_src, int32_t lda, int32_t ldb, + int32_t ldc, int32_t m, int32_t n, int32_t k, int32_t k0, + float alpha, float beta, int32_t cmask, bool last_k_block, + bool swapab, bool disable_hilbert) const; const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } const CommonDriverInfo *nocopy_info() const { @@ -692,6 +695,7 @@ struct gen_gemm_t : public gpu_gemm_t { compute::scalar_type_t scalar_type_; zero_pool_t *zero_pool_ = nullptr; size_t zero_pool_bytes_ = 0; + size_t zero_pool_chunk_size_ = 0; }; } // namespace jit diff --git a/src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp b/src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp index ff56ba2f297..8eba841ea12 100644 --- a/src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp +++ b/src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp @@ -477,7 +477,7 @@ status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch, problem_.sumB = (reduce_ab == sum_ab::sum_a_row); // Select a kernel from the catalog. - MatchParams match_params[3]; + MatchParams match_params[4]; int npatterns = 1; match_params[0] = MatchParams(hw_, has_systolic, problem_); diff --git a/src/gpu/intel/jit/gemm/generator/pieces/address_setup.cxx b/src/gpu/intel/jit/gemm/generator/pieces/address_setup.cxx index e28c54f12fc..b66bc116a31 100644 --- a/src/gpu/intel/jit/gemm/generator/pieces/address_setup.cxx +++ b/src/gpu/intel/jit/gemm/generator/pieces/address_setup.cxx @@ -366,7 +366,8 @@ void BLASKernelGenerator::setupAddr(Type T, const GRFRange &addr, const BO & if (T.paddedSize() < widthAlign) or_(1, addr[0].ud(2), addr[0].ud(2), widthAlign - 1); } else if (remW.isInvalid() && remH.isInvalid()) - emov(1, addr[0].uq(1), uint64_t(bw * bcount * block.ebytes - 1) | (uint64_t(bh * block.ebytes - 1) << 32), strategy, state); + emov(1, addr[0].uq(1), (((uint64_t)bw * (uint64_t)bcount * block.ebytes - 1) + | ((uint64_t)bh * block.ebytes - 1) << 32), strategy, state); else { if (remW.isValid() && multiX > 1) stub(); remW.isValid() ? 
addScaled(1, addr[0].ud(2), -1, remW.uw(), T, state, true) diff --git a/src/gpu/intel/jit/gemm/generator/pieces/copy_plan.cpp b/src/gpu/intel/jit/gemm/generator/pieces/copy_plan.cpp index 758cdb85af8..468c51137fe 100644 --- a/src/gpu/intel/jit/gemm/generator/pieces/copy_plan.cpp +++ b/src/gpu/intel/jit/gemm/generator/pieces/copy_plan.cpp @@ -1364,7 +1364,7 @@ void CopyPlan::legalizeRegions() /* Check destination stride against execution channels */ int channelSize = 1; - for (auto op: {i.dst, i.src0, i.src1, i.src2}) + for (auto &op: {i.dst, i.src0, i.src1, i.src2}) if (op.kind == op.GRF) channelSize = std::max(channelSize, getBytes(op.type)); diff --git a/src/gpu/intel/jit/gemm/generator/pieces/copy_plan.hpp b/src/gpu/intel/jit/gemm/generator/pieces/copy_plan.hpp index ca657f42558..ad2d951a0ff 100644 --- a/src/gpu/intel/jit/gemm/generator/pieces/copy_plan.hpp +++ b/src/gpu/intel/jit/gemm/generator/pieces/copy_plan.hpp @@ -97,7 +97,7 @@ struct CopyTemporary { friend class CopyPlan; - int bytes, align = 0, offset = 0; + int bytes = 0, align = 0, offset = 0; bool flag = false; int16_t cnumMin = 0x7FFF; int16_t cnumMax = -1; diff --git a/src/gpu/intel/jit/gemm/generator/pieces/gemm_setup.cxx b/src/gpu/intel/jit/gemm/generator/pieces/gemm_setup.cxx index eb3ab3377cd..4ff9844dfe3 100644 --- a/src/gpu/intel/jit/gemm/generator/pieces/gemm_setup.cxx +++ b/src/gpu/intel/jit/gemm/generator/pieces/gemm_setup.cxx @@ -1674,7 +1674,7 @@ bool BLASKernelGenerator::gemmAccumulateCSetup(GEMMProblem &problem, GEMMStr i0q = state.ra.alloc_sub(); emad(1, i0q, state.i0, state.lidN, state.ma_slm, strategy, state); } - if (!aoTo2D && state.ka_slm < strategy.unrollKSLM && problem.aqGroupK < strategy.unrollKSLM) { + if ((ao2D || (as2D && !state.lateScale2DA)) && state.ka_slm < strategy.unrollKSLM && problem.aqGroupK < strategy.unrollKSLM) { if (state.lateScale2DA) A_h0s = copySubregister(A_h0q, state); if (A_h0q.isInvalid()) { @@ -1690,7 +1690,7 @@ bool BLASKernelGenerator::gemmAccumulateCSetup(GEMMProblem &problem, GEMMStr j0q = state.ra.alloc_sub(); emad(1, j0q, state.j0, state.lidM, state.nb_slm, strategy, state); } - if (!boTo2D && state.kb_slm < strategy.unrollKSLM && problem.bqGroupK < strategy.unrollKSLM) { + if ((bo2D || (bs2D && !state.lateScale2DB)) && state.kb_slm < strategy.unrollKSLM && problem.bqGroupK < strategy.unrollKSLM) { if (state.lateScale2DB) B_h0s = copySubregister(B_h0q, state); if (B_h0q.isInvalid()) { diff --git a/src/gpu/intel/jit/gemm/generator/pieces/loop_sequencer.cpp b/src/gpu/intel/jit/gemm/generator/pieces/loop_sequencer.cpp index 88b8322938b..8f4dca1bf85 100644 --- a/src/gpu/intel/jit/gemm/generator/pieces/loop_sequencer.cpp +++ b/src/gpu/intel/jit/gemm/generator/pieces/loop_sequencer.cpp @@ -43,7 +43,7 @@ void LoopSequencer::schedule(std::vector list) for (auto &entry: list) xlist.push_back(CheckedItem(entry)); - schedule_if(xlist); + schedule_if(std::move(xlist)); } } @@ -56,7 +56,7 @@ void LoopSequencer::schedule_if(std::vector list) { if (!list.empty()) { validate(list); - actions.push_back({list, NeverScheduled}); + actions.push_back({std::move(list), NeverScheduled}); } } @@ -69,7 +69,7 @@ void LoopSequencer::swapLast2() void LoopSequencer::setCallback(CallbackType type, Callback cb) { - callbacks[static_cast(type)] = cb; + callbacks[static_cast(type)] = std::move(cb); } void LoopSequencer::setRemainderHandling(RemainderHandling handling) diff --git a/src/gpu/intel/jit/gemm/generator/pieces/map.hpp b/src/gpu/intel/jit/gemm/generator/pieces/map.hpp index 
385eeb4a997..6bb4db9f7ae 100644 --- a/src/gpu/intel/jit/gemm/generator/pieces/map.hpp +++ b/src/gpu/intel/jit/gemm/generator/pieces/map.hpp @@ -155,8 +155,8 @@ static inline void map(ngen::HW hw, const GRFMultirange ®s, const std::vector // Variant that allow the type to be specified as a native Type, rather than an nGEN type. template -static inline void map(ngen::HW hw, Type T, Targs... args) { - map(hw, T.ngen(), args...); +static inline void map(ngen::HW hw, Type T, Targs &&...args) { + map(hw, T.ngen(), std::forward(args)...); } static inline bool canDualGRF(ngen::HW hw, ngen::DataType dt, const CommonStrategy &strategy) diff --git a/src/gpu/intel/jit/gemm/generator/pieces/matrix_access.cxx b/src/gpu/intel/jit/gemm/generator/pieces/matrix_access.cxx index c1e26786b7b..91c4ffcd535 100644 --- a/src/gpu/intel/jit/gemm/generator/pieces/matrix_access.cxx +++ b/src/gpu/intel/jit/gemm/generator/pieces/matrix_access.cxx @@ -516,7 +516,7 @@ void BLASKernelGenerator::loadLoadStoreDescriptors(bool load, bool store, Re descLoad.parts.responseLen = 0; int underlyingSIMD = std::max(block.simdSize, maxScatteredSIMD(hw, astrategy) >> 1); - int log2GRFs = ilog2(underlyingSIMD * block.ebytes) - GRF::log2Bytes(hw); + int log2GRFs = ilog2((uint64_t)underlyingSIMD * block.ebytes) - GRF::log2Bytes(hw); int log2Components = int(block.splitComplex); if (channel) mov(1, t2, 0x1000 << log2Components); diff --git a/src/gpu/intel/jit/gemm/generator/strategy_parser.cpp b/src/gpu/intel/jit/gemm/generator/strategy_parser.cpp index 41c668b35fe..9372b42ff7e 100644 --- a/src/gpu/intel/jit/gemm/generator/strategy_parser.cpp +++ b/src/gpu/intel/jit/gemm/generator/strategy_parser.cpp @@ -848,7 +848,7 @@ std::string unparseStrategy(HW hw, const GEMMProblem &problem, const GEMMStrateg } if (strategy.optAlignAB > 0) s << " l" << strategy.optAlignAB; - if (anyOptAlignAB) s << " l2d"; + if (strategy.optAlignAB2D) s << " l2d"; bool nq = false; for (auto &astrategy: {strategy.A, strategy.B, strategy.C, strategy.CO, diff --git a/src/gpu/intel/jit/gemm/selector/db/kernel.db b/src/gpu/intel/jit/gemm/selector/db/kernel.db index c847dbbb1db..78befe63dc1 100644 --- a/src/gpu/intel/jit/gemm/selector/db/kernel.db +++ b/src/gpu/intel/jit/gemm/selector/db/kernel.db @@ -991,18 +991,18 @@ auto _CATALOG_ = kcatalog::toFlatCatalog({ {{'F', "gemm", {"O", "S", "S"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {16, -1, -1}, {1, 1, 1}, ""}, "am8x2+m32@8 aB8x2+m8@8 aB wg 1x4x16 kr kc8 nse li pt sr sb256 bk0 kv afb l2d", {16, (LoopType) 255, 128, {(LoopType) 224, (LoopType) 255, (LoopType) 2}, {262144, 262144, 16777216}, {262144, 262144, 32}, {16, 16, 8}, {1, 4, 16}, 1, (WGType) 1, 413, 0, 4096, {4, 4, 4}, {true, true, true}}, {'E', 17, {1.13268e+06, -103657, 246.583, 142575, 3.21126e+06, 0, 1.60336, 0.933443, 0.504957, 0.918636, 0.0712742, 0.0670609, 0.015172, 0.992002, 1.14631, 0.0718492, 1.8422e-11}}}, {{'F', "gemm", {"O", "S", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, 4, -1}, {1, 1, 1}, ""}, "at16x2+m16@32 at16+m32@32 aB wg 16x1x2 kr kc16 nse nmk li pt sr sb256 bk0 sm grf256 kv afb l4 l2d", {16, (LoopType) 255, 256, {(LoopType) 225, (LoopType) 255, (LoopType) 2}, {262144, 65536, 16777216}, {262144, 65536, 32}, {16, 4, 16}, {16, 1, 2}, 1, (WGType) 1, 413, 0, 4096, {4, 4, 4}, {true, true, true}}, {'E', 17, {1.1734e+06, -264907, -109870, 485064, 2.21266e+06, 0, 0.856653, 15.807, 1.98085, 3.89882, 0.125049, 0.0139237, 0.143865, 1, 1.34573, 0.978713, 4.7619e-12}}}, {{'F', "gemm", {"O", 
"S", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, 16, -1}, {1, 1, 1}, ""}, "at8x2+m16@24 at8x2+m32@8 aB wg 16x1x4 kr kc8 nse nmk li pt sr sb256 bk0 sm sn kv afb l2d", {16, (LoopType) 255, 128, {(LoopType) 225, (LoopType) 255, (LoopType) 2}, {262144, 262144, 16777216}, {262144, 262144, 32}, {16, 16, 8}, {16, 1, 4}, 1, (WGType) 1, 413, 0, 16384, {4, 4, 4}, {true, true, true}}, {'E', 17, {1.18993e+06, -230103, -26635.9, 388995, 2.2528e+06, 0, 0.900793, 5.78162, 0.552809, 1.28255, 0.0627307, 0.0602325, 0.0232779, 1, 1.21284, 0.921396, 2.8065e-12}}}, -{{'F', "gemm", {"Q", "Q", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {8, 8, 1}, "ABIp"}, "av32+m32@48 am32+S32@64 aB wg 4x8 xaf st vav hi pt sb32 bk0 sn grf256 sys sr br kv afb rr", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {1048576, 524288, 16777216}, {1048576, 524288, 32}, {32, 16, 32}, {4, 8, 1}, 1, (WGType) 1, 441, 0, 0, {8, 8, 4}, {true, true, true}}, {'E', 17, {869760, 741196, 0, 0, 8.192e+06, 1.05431e+07, 0.731287, 0.777012, 0.881104, 1.51408, 0.00403024, 0.00403024, 0, 0.998184, 1.76752, 1.28618, 2.08947e-12}}}, -{{'F', "gemm", {"Q", "Q", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "Ip"}, "aB32 aB32 aB wg 8x4 cab3 ks32 af vav hi pt sr br bk0 sn nb 8x4 dm grf256 sys kv afb l4", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {524288, 786432, 16777216}, {524288, 786432, 32}, {16, 32, 16}, {8, 4, 1}, 1, (WGType) 1, 441, 86016, 0, {2, 2, 4}, {true, true, true}}, {'E', 17, {1.07153e+06, 922406, 0, 0, 5.48536e+06, 9.18323e+06, 0.85293, 1.19559, 1.04485, 1.64281, 0.00471518, 0.00471518, 0, 0.961495, 1.72318, 1.24713, 3.69593e-12}}}, -{{'F', "gemm", {"Q", "Q", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {8, 8, 1}, "ABI"}, "av16 am16+m16@16 aB wg 2x4x4 kr ca3x2 ks16 af vav hi pt sr br bk0 sn grf256 kv afb sys", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 2}, {524288, 262144, 16777216}, {524288, 262144, 32}, {32, 16, 16}, {2, 4, 4}, 1, (WGType) 1, 445, 6144, 16384, {8, 8, 4}, {true, true, true}}, {'E', 17, {1.27954e+06, -187821, -42333.9, 291644, 3.34234e+06, 2.63782e+06, 0.670967, 0.826166, 0.942564, 1.64083, 0.0148244, 0.00555253, 0.00975056, 0.806514, 1.26716, 0.788997, 1.48059e-11}}}, -{{'F', "gemm", {"Q", "Q", "S"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {8, 8, 1}, "ABI"}, "av32+m32@48 av32 aB wg 8x4 cb4 ks32 xaf st vav hi pt sr br bk0 grf256 sys rr kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {524288, 786432, 16777216}, {524288, 786432, 32}, {16, 48, 32}, {8, 4, 1}, 1, (WGType) 1, 441, 49152, 0, {8, 8, 4}, {true, true, true}}, {'E', 17, {1.00046e+06, 649003, 0, 0, 5.58776e+06, 8.89651e+06, 0.806799, 1.52159, 1.05017, 1.76588, 0.00548707, 0.00548707, 0, 0.843701, 1.54307, 1.23912, 1.78553e-12}}}, -{{'F', "gemm", {"Q", "Q", "S"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {8, 8, 1}, "ABIqp"}, "av16+m32@32 av16x2 aB wg 4x8 cb4x2 ks32 xaf vav hi pt sr br bk0 grf256 sys rr kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {1048576, 524288, 16777216}, {1048576, 524288, 32}, {32, 16, 32}, {4, 8, 1}, 1, (WGType) 1, 441, 65536, 0, {8, 8, 4}, {true, true, true}}, {'E', 17, {1.03762e+06, 706048, 0, 0, 6.93043e+06, 1.10019e+07, 
0.892859, 1.0972, 0.98165, 1.70677, 0.00434444, 0.00434444, 0, 0.705466, 1.6217, 1.19447, 5.03184e-12}}}, -{{'F', "gemm", {"Q", "Q", "S"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "aB8+B8@16 aB16+B16@16 aB wg 4x4 vav hi pt sb256 bk0 grf256 sr", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {524288, 262144, 16777216}, {524288, 262144, 16777216}, {16, 16, 16}, {4, 4, 1}, 1, (WGType) 1, 257, 0, 0, {1, 1, 4}, {true, true, true}}, {'E', 17, {1.10018e+06, 251905, 0, 0, 0, 0, 1.56408, 2.85947, 0.648851, 1.37611, 0.0629702, 0.000146865, 0.0632313, 0.517444, 1.16754, -0.0884205, 2.14696e-11}}}, -{{'F', "gemm", {"Q", "Q", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 8, 1}, "ABIpq"}, "at16x2+m64@48 am32+m32@64 aB wg 8x4 xaf vav hi pt sr br sb64 bk0 sm sn grf256 sys kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {524288, 1048576, 16777216}, {524288, 1048576, 32}, {32, 16, 32}, {8, 4, 1}, 1, (WGType) 1, 441, 0, 0, {4, 8, 4}, {true, true, true}}, {'E', 17, {877896, 641311, 0, 0, 7.70867e+06, 1.03055e+07, 0.79175, 0.746432, 0.882948, 1.4915, 0.00422556, 0.00422556, 0, 0.975411, 1.67727, 1.23964, 3.79962e-12}}}, -{{'F', "gemm", {"Q", "Q", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 8, 1}, "ABIqp"}, "at16+m64@48 am32+m32@56 aB wg 4x8 xaf st vav hi pt sr br sb64 bk0 sm sn grf256 sys kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {1048576, 655360, 16777216}, {1048576, 655360, 32}, {16, 40, 32}, {4, 8, 1}, 1, (WGType) 1, 441, 0, 0, {4, 8, 4}, {true, true, true}}, {'E', 17, {885722, 718272, 0, 0, 8.48691e+06, 1.26566e+07, 1.01026, 0.782981, 0.918959, 1.54496, 0.00414471, 0.00414471, 0, 1, 2.17926, 1.31409, 2.23082e-12}}}, -{{'F', "gemm", {"Q", "Q", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {16, -1, -1}, {1, 1, 1}, "I"}, "at16+m32@48 aB64 aB wg 1x4x8 kr af vav li pt sr br sb64 bk0 sm dm grf256 sys kv afb l4 l2d", {16, (LoopType) 255, 256, {(LoopType) 224, (LoopType) 255, (LoopType) 2}, {262144, 262144, 16777216}, {262144, 262144, 64}, {16, 16, 64}, {1, 4, 8}, 1, (WGType) 1, 445, 0, 4096, {2, 2, 4}, {true, true, true}}, {'E', 17, {1.15217e+06, -88485.5, -9852.59, 144015, 2.69517e+06, 1.69656e+06, 1.08886, 0.414631, 0.587045, 1.19593, 0.0202405, 0.0220064, 0.0155006, 1, 1.0809, 0.644827, 3.81346e-12}}}, -{{'F', "gemm", {"Q", "Q", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "aB8+S1,16@24 aS16+S32@16 aB wg 2x2x8 kr vav hi pt sr sb256 bk0 sm sn grf256 kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 2}, {524288, 262144, 16777216}, {8192, 8192, 0}, {16, 16, 16}, {2, 2, 8}, 1, (WGType) 1, 413, 0, 8192, {1, 1, 4}, {true, true, true}}, {'E', 17, {1.2289e+06, -127728, -18531.3, 192240, 3.35053e+06, 0, 0.932716, 1.33521, 0.665104, 1.39923, 0.0628179, 0.0675437, 0.0114361, 0.999809, 1.27564, 0.821381, 3.76564e-11}}}, -{{'F', "gemm", {"Q", "Q", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 8, 1}, "ABI"}, "at32+m32@48 am16x2+m64@32 aB wg 4x4x2 kr xaf st vav hi pt sr br sb64 bk0 sm sn grf256 sys kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 2}, {524288, 524288, 16777216}, {524288, 524288, 32}, {16, 32, 32}, {4, 4, 2}, 1, (WGType) 1, 445, 0, 65536, {4, 8, 4}, {true, 
true, true}}, {'E', 17, {1.03742e+06, -452509, 25423.9, 695882, 3.92397e+06, 3.94035e+06, 0.820319, 0.785135, 0.84902, 1.56923, 0.00809246, 0.00116078, 0.00745441, 0.526504, 1.60176, 1.07294, 4.15601e-12}}}, -{{'F', "gemm", {"Q", "Q", "S"}, {"T", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "aS8+S1,16@16 aB8+B8@16 aU vav wg 8x4 bo pt sm sb256 grf256 bk0 sr", {16, (LoopType) 255, 256, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 1048576, 16777216}, {8192, 8192, 16777216}, {16, 16, 8}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {1, 1, 4}, {true, true, true}}, {'E', 17, {875898, 864492, 0, 0, 0, 0, 2.55527, 2.12966, 0.857629, 1.85569, 0.0625314, 0.0625314, 0, 1, 1.01096, 1.00724, 3.06523e-14}}}, +{{'F', "gemm", {"Q", "Q", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {8, 8, 1}, "ABIp"}, "av32+m32@48 am32+S32@64 aB wg 4x8 xaf st vav hi pt sb32 bk0 sn grf256 sys sr br kv afb rr", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {524288, 262144, 16777216}, {524288, 262144, 32}, {32, 16, 32}, {4, 8, 1}, 1, (WGType) 1, 441, 0, 0, {8, 8, 4}, {true, true, true}}, {'E', 17, {869760, 741196, 0, 0, 8.192e+06, 1.05431e+07, 0.731287, 0.777012, 0.881104, 1.51408, 0.00403024, 0.00403024, 0, 0.998184, 1.76752, 1.28618, 2.08947e-12}}}, +{{'F', "gemm", {"Q", "Q", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "Ip"}, "aB32 aB32 aB wg 8x4 cab3 ks32 af vav hi pt sr br bk0 sn nb 8x4 dm grf256 sys kv afb l4", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {262144, 524288, 16777216}, {262144, 524288, 32}, {16, 32, 32}, {8, 4, 1}, 1, (WGType) 1, 441, 49152, 0, {1, 1, 4}, {true, true, true}}, {'E', 17, {1.07153e+06, 922406, 0, 0, 5.48536e+06, 9.18323e+06, 0.85293, 1.19559, 1.04485, 1.64281, 0.00471518, 0.00471518, 0, 0.961495, 1.72318, 1.24713, 3.69593e-12}}}, +{{'F', "gemm", {"Q", "Q", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {8, 8, 1}, "ABI"}, "av16 am16+m16@16 aB wg 2x4x4 kr ca3x2 ks16 af vav hi pt sr br bk0 sn grf256 kv afb sys", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 2}, {524288, 262144, 16777216}, {524288, 262144, 32}, {32, 16, 32}, {2, 4, 4}, 1, (WGType) 1, 445, 6144, 16384, {8, 8, 4}, {true, true, true}}, {'E', 17, {1.27954e+06, -187821, -42333.9, 291644, 3.34234e+06, 2.63782e+06, 0.670967, 0.826166, 0.942564, 1.64083, 0.0148244, 0.00555253, 0.00975056, 0.806514, 1.26716, 0.788997, 1.48059e-11}}}, +{{'F', "gemm", {"Q", "Q", "S"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {8, 8, 1}, "ABI"}, "av32+m32@48 av32 aB wg 8x4 cb4 ks32 xaf st vav hi pt sr br bk0 grf256 sys rr kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {262144, 786432, 16777216}, {262144, 786432, 32}, {16, 48, 32}, {8, 4, 1}, 1, (WGType) 1, 441, 49152, 0, {8, 8, 4}, {true, true, true}}, {'E', 17, {1.00046e+06, 649003, 0, 0, 5.58776e+06, 8.89651e+06, 0.806799, 1.52159, 1.05017, 1.76588, 0.00548707, 0.00548707, 0, 0.843701, 1.54307, 1.23912, 1.78553e-12}}}, +{{'F', "gemm", {"Q", "Q", "S"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {8, 8, 1}, "ABIqp"}, "av16+m32@32 av16x2 aB wg 4x8 cb4x2 ks32 xaf vav hi pt sr br bk0 grf256 sys rr kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {524288, 262144, 16777216}, {524288, 
262144, 32}, {32, 16, 32}, {4, 8, 1}, 1, (WGType) 1, 441, 32768, 0, {8, 8, 4}, {true, true, true}}, {'E', 17, {1.03762e+06, 706048, 0, 0, 6.93043e+06, 1.10019e+07, 0.892859, 1.0972, 0.98165, 1.70677, 0.00434444, 0.00434444, 0, 0.705466, 1.6217, 1.19447, 5.03184e-12}}}, +{{'F', "gemm", {"Q", "Q", "S"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "aB8+B8@16 aB16+B16@16 aB wg 4x4 vav hi pt sb256 bk0 grf256 sr", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {262144, 262144, 16777216}, {262144, 262144, 16777216}, {16, 16, 16}, {4, 4, 1}, 1, (WGType) 1, 257, 0, 0, {1, 1, 4}, {true, true, true}}, {'E', 17, {1.10018e+06, 251905, 0, 0, 0, 0, 1.56408, 2.85947, 0.648851, 1.37611, 0.0629702, 0.000146865, 0.0632313, 0.517444, 1.16754, -0.0884205, 2.14696e-11}}}, +{{'F', "gemm", {"Q", "Q", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 8, 1}, "ABIpq"}, "at16x2+m64@48 am32+m32@64 aB wg 8x4 xaf vav hi pt sr br sb64 bk0 sm sn grf256 sys kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {524288, 262144, 16777216}, {524288, 262144, 32}, {32, 16, 32}, {8, 4, 1}, 1, (WGType) 1, 441, 0, 0, {4, 8, 4}, {true, true, true}}, {'E', 17, {877896, 641311, 0, 0, 7.70867e+06, 1.03055e+07, 0.79175, 0.746432, 0.882948, 1.4915, 0.00422556, 0.00422556, 0, 0.975411, 1.67727, 1.23964, 3.79962e-12}}}, +{{'F', "gemm", {"Q", "Q", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 8, 1}, "ABIqp"}, "at16+m64@48 am32+m32@56 aB wg 4x8 xaf st vav hi pt sr br sb64 bk0 sm sn grf256 sys kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {262144, 655360, 16777216}, {262144, 655360, 32}, {16, 40, 32}, {4, 8, 1}, 1, (WGType) 1, 441, 0, 0, {4, 8, 4}, {true, true, true}}, {'E', 17, {885722, 718272, 0, 0, 8.48691e+06, 1.26566e+07, 1.01026, 0.782981, 0.918959, 1.54496, 0.00414471, 0.00414471, 0, 1, 2.17926, 1.31409, 2.23082e-12}}}, +{{'F', "gemm", {"Q", "Q", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {16, -1, -1}, {1, 1, 1}, "I"}, "at16+m32@48 aB64 aB wg 1x4x8 kr af vav li pt sr br sb64 bk0 sm dm grf256 sys kv afb l4 l2d", {16, (LoopType) 255, 256, {(LoopType) 224, (LoopType) 255, (LoopType) 2}, {262144, 262144, 16777216}, {262144, 262144, 64}, {16, 16, 64}, {1, 4, 8}, 1, (WGType) 1, 445, 0, 4096, {1, 1, 4}, {true, true, true}}, {'E', 17, {1.15217e+06, -88485.5, -9852.59, 144015, 2.69517e+06, 1.69656e+06, 1.08886, 0.414631, 0.587045, 1.19593, 0.0202405, 0.0220064, 0.0155006, 1, 1.0809, 0.644827, 3.81346e-12}}}, +{{'F', "gemm", {"Q", "Q", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "aB8+S1,16@24 aS16+S32@16 aB wg 2x2x8 kr vav hi pt sr sb256 bk0 sm sn grf256 kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 2}, {262144, 262144, 16777216}, {262144, 262144, 32}, {16, 16, 16}, {2, 2, 8}, 1, (WGType) 1, 445, 0, 4096, {1, 1, 4}, {true, true, true}}, {'E', 17, {1.2289e+06, -127728, -18531.3, 192240, 3.35053e+06, 0, 0.932716, 1.33521, 0.665104, 1.39923, 0.0628179, 0.0675437, 0.0114361, 0.999809, 1.27564, 0.821381, 3.76564e-11}}}, +{{'F', "gemm", {"Q", "Q", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 8, 1}, "ABI"}, "at32+m32@48 am16x2+m64@32 aB wg 4x4x2 kr xaf st vav hi pt sr br sb64 bk0 sm sn grf256 sys kv afb", {16, (LoopType) 255, 256, 
{(LoopType) 208, (LoopType) 255, (LoopType) 2}, {262144, 524288, 16777216}, {262144, 524288, 32}, {16, 32, 32}, {4, 4, 2}, 1, (WGType) 1, 445, 0, 32768, {4, 8, 4}, {true, true, true}}, {'E', 17, {1.03742e+06, -452509, 25423.9, 695882, 3.92397e+06, 3.94035e+06, 0.820319, 0.785135, 0.84902, 1.56923, 0.00809246, 0.00116078, 0.00745441, 0.526504, 1.60176, 1.07294, 4.15601e-12}}}, +{{'F', "gemm", {"Q", "Q", "S"}, {"T", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "aS8+S1,16@16 aB8+B8@16 aU vav wg 8x4 bo pt sm sb256 grf256 bk0 sr", {16, (LoopType) 255, 256, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {262144, 262144, 16777216}, {262144, 262144, 16777216}, {16, 16, 8}, {8, 4, 1}, 1, (WGType) 1, 257, 0, 0, {1, 1, 4}, {true, true, true}}, {'E', 17, {875898, 864492, 0, 0, 0, 0, 2.55527, 2.12966, 0.857629, 1.85569, 0.0625314, 0.0625314, 0, 1, 1.01096, 1.00724, 3.06523e-14}}}, {{'F', "gemm", {"S", "F", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "aB8+B8@16 aB8+m32@24 aB wg 1x4 kc8 nse hi pt sr sb256 bk0 sn grf256 l4", {16, (LoopType) 255, 128, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {524288, 131072, 16777216}, {524288, 131072, 16777216}, {1, 1, 1}, {1, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'E', 17, {1.16538e+06, 40635.2, 0, 0, 0, 0, 1.30731, 1.53858, 0.584971, 1.42067, 0.0634061, 0.0581975, 0.0161667, 1, 1.44276, 1.00478, 2.34818e-11}}}, {{'F', "gemm", {"S", "O", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "aB8+B8@16 aB8+m32@24 aB wg 1x4 kc8 nse hi pt sr sb256 bk0 sn grf256 l4", {16, (LoopType) 255, 128, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {524288, 131072, 16777216}, {524288, 131072, 16777216}, {1, 1, 1}, {1, 4, 1}, 1, (WGType) 1, 257, 0, 0, {4, 4, 4}, {true, true, true}}, {'E', 17, {1.16538e+06, 40635.2, 0, 0, 0, 0, 1.30731, 1.53858, 0.584971, 1.42067, 0.0634061, 0.0581975, 0.0161667, 1, 1.44276, 1.00478, 2.34818e-11}}}, {{'F', "gemm", {"S", "S", "S"}, {"A", "B", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "aB8+B8@8 aB8+B8@8 aB nse wg 4x8 bo pt sb256 kc8 bk0 sr", {16, (LoopType) 255, 128, {(LoopType) 192, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {8192, 8192, 16777216}, {32, 32, 8}, {4, 8, 1}, 1, (WGType) 1, 256, 0, 0, {128, 128, 4}, {true, true, true}}, {'W', 1, {1024}}}, diff --git a/src/gpu/intel/jit/ir/blocking.cpp b/src/gpu/intel/jit/ir/blocking.cpp index 1c7049f21fd..69dafd55982 100644 --- a/src/gpu/intel/jit/ir/blocking.cpp +++ b/src/gpu/intel/jit/ir/blocking.cpp @@ -369,34 +369,36 @@ const tiler_params_t &tiler_params() { return params; } -params_distance_t::params_distance_t( - const std::vector ¶ms_vec, - const std::function &convert) { - indexed_tile_t iter; - indexed_tile_t tg; - indexed_tile_t loop; - for (auto &p : params_vec) { - auto &b = p.blocking(); - iter.add(convert(b.iter())); - tg.add(convert(b.thread_group())); - loop.add(convert(b.loop())); +tile_to_vec_t::tile_to_vec_t(const std::vector> &tiles, + const std::vector &_ids) { + if (tiles.empty()) return; + int ntiles = (int)tiles.size(); + int nsubtiles = (int)tiles[0].size(); + std::vector indexed_tiles(nsubtiles); + std::vector ids = _ids; + if (ids.empty()) { + ids.resize(ntiles); + std::iota(ids.begin(), ids.end(), 0); } - iter.finalize(); - tg.finalize(); - loop.finalize(); - - std::vector> ret; - for (auto &p : params_vec) { - auto &b 
= p.blocking();
-        auto v0 = iter.to_index(convert(b.iter()));
-        auto v1 = tg.to_index(convert(b.thread_group()));
-        auto v2 = loop.to_index(convert(b.loop()));
+    ir_assert(ids.size() == tiles.size());
+    int max_id = 0;
+    for (int i = 0; i < ntiles; i++) {
+        for (int j = 0; j < nsubtiles; j++) {
+            indexed_tiles[j].add(tiles[i][j]);
+        }
+        max_id = std::max(max_id, ids[i]);
+    }
+    for (auto &it : indexed_tiles)
+        it.finalize();
+
+    vecs_.resize(max_id + 1);
+    for (int i = 0; i < ntiles; i++) {
         std::vector<int> v;
-        v.insert(v.end(), v0.begin(), v0.end());
-        v.insert(v.end(), v1.begin(), v1.end());
-        v.insert(v.end(), v2.begin(), v2.end());
-        if (p.id() >= int(dists_.size())) dists_.resize(p.id() + 1);
-        dists_[p.id()] = v;
+        for (int j = 0; j < nsubtiles; j++) {
+            auto vi = indexed_tiles[j].to_index(tiles[i][j]);
+            v.insert(v.end(), vi.begin(), vi.end());
+        }
+        vecs_[ids[i]] = std::move(v);
     }
 }
diff --git a/src/gpu/intel/jit/ir/blocking.hpp b/src/gpu/intel/jit/ir/blocking.hpp
index 75e9ab46936..ae6c02a1cfe 100644
--- a/src/gpu/intel/jit/ir/blocking.hpp
+++ b/src/gpu/intel/jit/ir/blocking.hpp
@@ -109,7 +109,7 @@ class blocking_t {
     // Returns the ratio of all operations (with padding) to "useful" operations
     float get_efficiency(const prb_tile_t &shape) const {
         float ret = 1;
-        for (auto d : shape) {
+        for (auto &d : shape) {
             int loop = loop_.get(d, 1);
             int tg = thread_group_.get(d, 1);
             int iter = iter_.get(d, 1);
@@ -633,7 +633,8 @@ struct tiler_params_t {
 
 const tiler_params_t &tiler_params();
 
-// Helper class to compute the distance between two blocking schemes.
+// Helper class to compute the distance between tiles with sizes.
+//
 // During initialization the number of blocking dimensions might be
 // reduced for simplicity - e.g. for convolutions by converting them
 // to BMNK values typical for GEMM - then these dims are converted to
@@ -651,19 +652,19 @@ const tiler_params_t &tiler_params();
 // L1(B1, B3) = L1(B3, B1) = 1 + 2 + 0 = 3
 // L1(B2, B3) = L1(B3, B2) = 2 + 1 + 0 = 3
 // L1(B1, B1) = L1(B2, B2) = L1(B3, B3) = 0
-class params_distance_t {
+class tile_to_vec_t {
 public:
-    params_distance_t() = default;
-    params_distance_t(const std::vector<params_t> &params_vec,
-            const std::function<prb_tile_t(const prb_tile_t &)> &convert);
+    tile_to_vec_t() = default;
+    tile_to_vec_t(const std::vector<std::vector<prb_tile_t>> &tiles,
+            const std::vector<int> &ids = {});
 
     float dist(int id0, int id1) const {
-        auto &d0 = dists_[id0];
-        auto &d1 = dists_[id1];
+        auto &v0 = vecs_[id0];
+        auto &v1 = vecs_[id1];
         float ret = 0;
         // Use L1 distance between coordinates.
-        for (int i = 0; i < (int)d0.size(); i++) {
-            ret += std::abs(d0[i] - d1[i]);
+        for (int i = 0; i < (int)v0.size(); i++) {
+            ret += std::abs(v0[i] - v1[i]);
         }
         return ret;
     }
@@ -707,7 +708,7 @@ class params_distance_t {
        void add(prb_dim_t d, int value) { dim_mappers_[d.id()].add(value); }
 
        void add(const prb_tile_t &t) {
-            for (auto d : t) {
+            for (auto &d : t) {
                 add(d, t[d]);
             }
         }
@@ -733,7 +734,7 @@ class params_distance_t {
         std::array dim_mappers_;
     };
 
-    std::vector<std::vector<int>> dists_;
+    std::vector<std::vector<int>> vecs_;
 };
 
 // Helper class to track performance data collected during tuning.
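[Note, not part of the patch] The tile_to_vec_t comment above encodes each tile configuration as a vector of per-dimension indices and compares configurations by the L1 distance between those vectors. A minimal self-contained sketch of that idea, assuming hypothetical helper names to_index and l1_dist (not oneDNN API):

    #include <algorithm>
    #include <cstdlib>
    #include <vector>

    // Rank of a size among the distinct sizes observed for one dimension.
    // E.g. with observed sizes {1, 2, 4}, size 4 maps to index 2.
    int to_index(int size, const std::vector<int> &distinct_sorted) {
        return (int)(std::lower_bound(distinct_sorted.begin(),
                             distinct_sorted.end(), size)
                - distinct_sorted.begin());
    }

    // L1 distance over concatenated per-dimension index vectors, mirroring
    // tile_to_vec_t::dist() in the hunk above.
    float l1_dist(const std::vector<int> &a, const std::vector<int> &b) {
        float ret = 0;
        for (std::size_t i = 0; i < a.size(); i++)
            ret += (float)std::abs(a[i] - b[i]);
        return ret;
    }

For the comment's example, two configurations whose per-dimension indices differ by 1, 2, and 0 give l1_dist = 3, matching L1(B1, B3) above.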
diff --git a/src/gpu/intel/jit/ir/core.cpp b/src/gpu/intel/jit/ir/core.cpp
index d6be06b1d40..3de929cc762 100644
--- a/src/gpu/intel/jit/ir/core.cpp
+++ b/src/gpu/intel/jit/ir/core.cpp
@@ -360,7 +360,7 @@ DEFINE_BINARY_OPERATOR(&, op_kind_t::_and)
 #define DEFINE_BINARY_ASSIGN_OPERATOR(op) \
     expr_t &expr_t::operator op##=(const expr_t &rhs) { \
         auto tmp = (*this)op rhs; \
-        *this = tmp; \
+        *this = std::move(tmp); \
         return *this; \
     }
diff --git a/src/gpu/intel/jit/ir/core.hpp b/src/gpu/intel/jit/ir/core.hpp
index 4f1332dadb3..bcbf237c506 100644
--- a/src/gpu/intel/jit/ir/core.hpp
+++ b/src/gpu/intel/jit/ir/core.hpp
@@ -617,8 +617,7 @@ HANDLE_TRAVERSE_TARGETS()
 // the reference counter stored inside the object.
 class object_impl_t {
 public:
-    object_impl_t(type_info_t type_info)
-        : ref_count_(), type_info_(type_info) {};
+    object_impl_t(type_info_t type_info) : type_info_(type_info) {};
 
     object_impl_t(const object_impl_t &) = delete;
 
@@ -713,6 +712,7 @@ class object_t {
 #endif
 
     object_t &operator=(const object_t &other) {
+        if (&other == this) return *this;
         auto *other_impl = other.impl();
         increment(other_impl);
         decrement_and_maybe_destroy(impl_);
@@ -1878,7 +1878,7 @@ class stmt_t : public object_t {
     stmt_t() = default;
     stmt_t(const object_t &obj) : object_t(obj) {}
-    stmt_t(object_t &&obj) : object_t(obj) {}
+    stmt_t(object_t &&obj) : object_t(std::move(obj)) {}
     stmt_t &operator=(const object_t &obj) {
         object_t::operator=(obj);
         return *this;
diff --git a/src/gpu/intel/jit/ir/epilogue.cpp b/src/gpu/intel/jit/ir/epilogue.cpp
index f798d714b04..7eb9dd4d0bb 100644
--- a/src/gpu/intel/jit/ir/epilogue.cpp
+++ b/src/gpu/intel/jit/ir/epilogue.cpp
@@ -79,6 +79,7 @@ class zero_pad_builder_t {
             const layout_t &layout) const {
         if (idx == layout.ndims()) {
             std::vector<expr_t> vargs;
+            vargs.reserve(layout.ndims());
             for (int i = 0; i < layout.ndims(); i++)
                 vargs.push_back(view.vstart(i) + args[i]);
             expr_t mask = full_mem_view_.vmask(vargs);
@@ -295,8 +296,8 @@ class post_op_tensor_t {
                 reg_layout_, f32_layout, reg_buf_, f32_buf);
 
         // Assign new f32 layout and buffer.
-        reg_layout_ = f32_layout;
-        reg_buf_ = f32_buf;
+        reg_layout_ = std::move(f32_layout);
+        reg_buf_ = std::move(f32_buf);
 
         return ret;
     }
@@ -339,7 +340,7 @@ class post_op_tensor_t {
             stmt = stmt.append(
                     create_reduce_stmt(reg_layout_, reduced_layout, reg_buf_,
                             reg_buf_, tensor_t(), mask(), /*drop_dims=*/false));
-            reg_layout_ = reduced_layout;
+            reg_layout_ = std::move(reduced_layout);
         }
 
         return stmt;
diff --git a/src/gpu/intel/jit/ir/gemm_schedule.hpp b/src/gpu/intel/jit/ir/gemm_schedule.hpp
index 72f5e218a5f..b961b8c3db0 100644
--- a/src/gpu/intel/jit/ir/gemm_schedule.hpp
+++ b/src/gpu/intel/jit/ir/gemm_schedule.hpp
@@ -345,7 +345,7 @@ class loop_t {
                 auto e = fused_loop.expand_var(
                                  all_loops, skip_fused, filter_kind) / denom;
-                return (i == 0 ? e : e % bound);
+                return (i == 0 ? std::move(e) : e % bound);
             }
             denom *= bound;
         }
     }
diff --git a/src/gpu/intel/jit/ir/ir.cpp b/src/gpu/intel/jit/ir/ir.cpp
index 54be3ed71db..20ef75168be 100644
--- a/src/gpu/intel/jit/ir/ir.cpp
+++ b/src/gpu/intel/jit/ir/ir.cpp
@@ -207,7 +207,6 @@ class ir_printer_t : public ir_visitor_t {
         remove_indent();
         print_indent();
         out_ << "}\n";
-        return;
     }
 
     void _visit(const stmt_seq_t &obj) override {
@@ -229,7 +228,6 @@ class ir_printer_t : public ir_visitor_t {
     void _visit(const ternary_op_t &obj) override {
         out_ << to_string(obj.op_kind) << "(" << obj.a << ", " << obj.b << ", "
              << obj.c << ")";
-        return;
     }
 
     void _visit(const unary_op_t &obj) override {
@@ -364,7 +362,7 @@ class stmt_flattener_t : public ir_visitor_t {
         size_t old_size = stmts.size(); \
         ir_visitor_t::_visit(obj); \
         if (stmts.size() > old_size) return; \
-        if (obj.is_stmt()) stmts.push_back(obj); \
+        if (obj.is_stmt()) stmts.emplace_back(obj); \
     }
 
     HANDLE_ALL_IR_OBJECTS()
@@ -592,7 +590,7 @@ std::vector<stmt_t> find_stmt_groups(
     auto groups = find_objects<stmt_group_t>(root);
     std::vector<stmt_t> ret;
     for (auto &g : groups) {
-        if (g.as<stmt_group_t>().label == label) ret.push_back(g);
+        if (g.as<stmt_group_t>().label == label) ret.emplace_back(g);
     }
     return ret;
 }
@@ -736,7 +734,7 @@ bool relation_t::implies(const relation_t &other) const {
 }
 
 relation_t relation_t::transform(
-        const linear_transform_t &t, const expr_t &new_var) {
+        const linear_transform_t &t, const expr_t &new_var) const {
     ir_assert(t.a == 1) << "Not implemented.";
     return relation_t(binary_op_t::make(op_kind(), new_var, rhs() + t.b));
 }
@@ -1012,14 +1010,18 @@ bool constraint_set_t::is_single_value(const expr_t &e, expr_t &value) const {
             case op_kind_t::_gt: {
                 auto cur_lo = (rel.op_kind() == op_kind_t::_ge ? rel.rhs()
                                                                : rel.rhs() + 1);
-                if (lo.is_empty() || to_cpp<bool>(cur_lo > lo)) { lo = cur_lo; }
+                if (lo.is_empty() || to_cpp<bool>(cur_lo > lo)) {
+                    lo = std::move(cur_lo);
+                }
                 break;
             }
             case op_kind_t::_le:
             case op_kind_t::_lt: {
                 auto cur_hi = (rel.op_kind() == op_kind_t::_le ? rel.rhs()
                                                                : rel.rhs() - 1);
-                if (hi.is_empty() || to_cpp<bool>(cur_hi < hi)) { hi = cur_hi; }
+                if (hi.is_empty() || to_cpp<bool>(cur_hi < hi)) {
+                    hi = std::move(cur_hi);
+                }
                 break;
             }
             default: ir_error_not_expected() << rel;
@@ -1027,7 +1029,7 @@ bool constraint_set_t::is_single_value(const expr_t &e, expr_t &value) const {
         if (do_break) break;
     }
     bool ret = !lo.is_empty() && lo.is_equal(hi);
-    if (ret) value = lo;
+    if (ret) value = std::move(lo);
     return ret;
 }
diff --git a/src/gpu/intel/jit/ir/ir.hpp b/src/gpu/intel/jit/ir/ir.hpp
index ce4073e5190..76accbc7b69 100644
--- a/src/gpu/intel/jit/ir/ir.hpp
+++ b/src/gpu/intel/jit/ir/ir.hpp
@@ -646,6 +646,7 @@ struct mem_usage_guard_t {
     }
 
     mem_usage_guard_t &operator=(mem_usage_guard_t &&other) {
+        if (&other == this) return *this;
         usage = other.usage;
         peak_usage = other.peak_usage;
         size = other.size;
@@ -696,7 +697,8 @@ class relation_t {
     bool implies(const relation_t &other) const;
 
     // Applies linear transformation to left and right hand sides of the relation.
-    relation_t transform(const linear_transform_t &t, const expr_t &new_var);
+    relation_t transform(
+            const linear_transform_t &t, const expr_t &new_var) const;
 
     std::string str() const {
         std::ostringstream oss;
diff --git a/src/gpu/intel/jit/ir/linear_expr.cpp b/src/gpu/intel/jit/ir/linear_expr.cpp
index 996465344f0..0607d1a1110 100644
--- a/src/gpu/intel/jit/ir/linear_expr.cpp
+++ b/src/gpu/intel/jit/ir/linear_expr.cpp
@@ -180,7 +180,7 @@ expr_t linear_normalize_const_factor_out(const expr_t &_e) {
     }
 
     std::vector<expr_t> v_common;
-    v_common.push_back(const_factor);
+    v_common.emplace_back(const_factor);
     for (auto &kv : common) {
         for (int i = 0; i < kv.second; i++)
             v_common.push_back(kv.first);
@@ -256,7 +256,8 @@ class linear_coef_t {
         auto lhs = op_combine(op_kind_t::_mul, factors_);
         auto rhs = op_combine(op_kind_t::_mul, other.factors_);
         int const_factor = 1;
-        auto common = find_common_factors({lhs, rhs}, const_factor);
+        auto common = find_common_factors(
+                {std::move(lhs), std::move(rhs)}, const_factor);
         ir_assert(const_factor == 1);
         factors_.clear();
         for (auto &kv : common) {
@@ -295,6 +296,7 @@ class linear_coef_t {
     static std::vector<expr_t> div(const std::vector<expr_t> &v, int factor) {
         std::vector<expr_t> ret;
+        ret.reserve(v.size());
         for (auto &e : v)
             ret.push_back(div(e, factor));
         return ret;
diff --git a/src/gpu/intel/jit/ir/message.cpp b/src/gpu/intel/jit/ir/message.cpp
index fc9f21fe6ce..00708d9457c 100644
--- a/src/gpu/intel/jit/ir/message.cpp
+++ b/src/gpu/intel/jit/ir/message.cpp
@@ -69,7 +69,7 @@ stmt_t send_t::create_offset_store(const expr_t &header_buf,
         }
         off += mem_off;
     } else {
-        off = mem_off;
+        off = std::move(mem_off);
     }
     off = cast(off, address_type(is_signed_offset, off.type().elems()));
     return store_t::make(header_sub_buf, 0, off);
@@ -280,7 +280,7 @@ class memory_walker_t {
         ir_assert(block_idx >= 0 && block_idx < int(block_offs_.size()));
         base = block_offs_[block_idx];
         auto prev_base = block_offs_[block_idx == 0 ?
0 : block_idx - 1]; - auto get_const_summand = [&](expr_t expr) -> int64_t { + auto get_const_summand = [&](const expr_t &expr) -> int64_t { if (!expr.type().is_int()) return 0; auto binary_op = expr.as_ptr(); if (binary_op && binary_op->op_kind == op_kind_t::_add @@ -567,7 +567,7 @@ static stmt_t try_promote_to_lsc(const stmt_t &_call) { if (mask.is_empty()) return call; auto new_args = call.args; - send_t::arg_mask(new_args) = mask; + send_t::arg_mask(new_args) = std::move(mask); auto lsc_send = send_t::make(send.hw, send.op, send.address, send.type, send.slots, /*is_lsc=*/true, send.zero_out, send.cache_hint); @@ -999,13 +999,13 @@ stmt_t access_builder_t::create_send_stmt( auto off = mem_walker.get_offset( i * send.type.size(), off_base, off_const); if (off_base0.is_empty()) { - off_base0 = off_base; + off_base0 = std::move(off_base); off_const0 = off_const; } else if (!off_base.is_equal(off_base0)) { is_same_base = false; } off_vec.push_back(off); - off_const_vec.push_back(off_const - off_const0); + off_const_vec.emplace_back(off_const - off_const0); } expr_t off; if (send.slots == 1 || !is_same_base) { diff --git a/src/gpu/intel/jit/ir/message_patterns.hpp b/src/gpu/intel/jit/ir/message_patterns.hpp index 4fc3b2da109..1d8bde660bb 100644 --- a/src/gpu/intel/jit/ir/message_patterns.hpp +++ b/src/gpu/intel/jit/ir/message_patterns.hpp @@ -228,7 +228,7 @@ struct send_hint_t { }; dim_t surface_pitch() const { dim_t val = 0; - for (auto s : strides_) { + for (auto &s : strides_) { if (is_h_dim(s.dim)) { val = s.stride; } } return val * type_size_; @@ -236,7 +236,7 @@ struct send_hint_t { dim_t surface_width() const { dim_t val = 0; - for (auto s : strides_) { + for (auto &s : strides_) { if (is_w_dim(s.dim)) val = hint_[s.dim.id()] * s.stride; } return val * type_size_; diff --git a/src/gpu/intel/jit/ir/problem.hpp b/src/gpu/intel/jit/ir/problem.hpp index 2f56e5c746f..baca1ad0123 100644 --- a/src/gpu/intel/jit/ir/problem.hpp +++ b/src/gpu/intel/jit/ir/problem.hpp @@ -257,7 +257,7 @@ class dim_map_t { std::vector keys() const { std::vector ret; - for (auto key : *this) + for (auto &key : *this) ret.push_back(key); return ret; } @@ -290,7 +290,7 @@ class dim_map_t { std::unordered_map to_map() const { std::unordered_map ret; - for (auto d : (*this)) { + for (auto &d : (*this)) { ret[d.name()] = at(d); } return ret; diff --git a/src/gpu/intel/jit/ir/send_plan.cpp b/src/gpu/intel/jit/ir/send_plan.cpp index fcd252a7f8f..3b92b236438 100644 --- a/src/gpu/intel/jit/ir/send_plan.cpp +++ b/src/gpu/intel/jit/ir/send_plan.cpp @@ -198,8 +198,9 @@ expr_t to_vec(const vec_off_t &off, int elems) { ir_assert(off.size() == elems); if (off.size() == 1) return off[0]; std::vector e_off; - for (auto o : off) - e_off.push_back(o); + e_off.reserve(off.size()); + for (auto &o : off) + e_off.emplace_back(o); return shuffle_t::make(e_off); } @@ -859,9 +860,6 @@ class split_bounds_t { auto tile = layout.split_exact(factor); if (tile.is_empty()) return; - std::vector step = tile.dims(); - std::vector idx(layout.ndims()); - layout.for_each_tile(tile, [&](const std::vector &start) { int off = layout.offset_in_bytes(start); offs_.push_back(off); @@ -977,7 +975,7 @@ struct send_group_t { if (!has_mask(md.tidx())) continue; auto md_mask = md.to_expr(mask_inc.slice(idx) + inc[idx]); if (ret.is_empty()) { - ret = md_mask; + ret = std::move(md_mask); } else { ret &= md_mask; } @@ -1106,7 +1104,7 @@ struct send_group_t { } auto ret = *this; - ret.blocks = new_blocks; + ret.blocks = std::move(new_blocks); return ret; } @@ 
-2507,8 +2505,8 @@ send_group_t init_scattered(const view_info_t &info, it.next(mask_base, addr_base, it.inner_elems(), inner_slots, slot_size, ret.mask_bits); } - ret.addr_inc = addr_base; - ret.mask_inc = mask_base; + ret.addr_inc = std::move(addr_base); + ret.mask_inc = std::move(mask_base); reg_layout = layout_t(vlayout.type(), vlayout.ndims(), 0, std::vector( blocks.begin(), blocks.begin() + info.outer_idx())); diff --git a/src/gpu/intel/jit/ir/slm_reduce_builder.cpp b/src/gpu/intel/jit/ir/slm_reduce_builder.cpp index 9712df3599d..04feb34a13b 100644 --- a/src/gpu/intel/jit/ir/slm_reduce_builder.cpp +++ b/src/gpu/intel/jit/ir/slm_reduce_builder.cpp @@ -89,7 +89,7 @@ void slm_reduce_builder_t::build() { if (split_grid.dim(i) == full_grid.dim(i)) continue; auto cond = full_grid.idx(i) < split_grid.dim(i); if (reduce_cond_.is_empty()) - reduce_cond_ = cond; + reduce_cond_ = std::move(cond); else reduce_cond_ &= cond; } @@ -120,7 +120,7 @@ void slm_reduce_builder_t::build() { tmp_reg_buf_size_ = std::max(tmp_reg_buf_size_, read.reg_buf_size()); - auto read_layout = read.reg_layout(); + auto &read_layout = read.reg_layout(); load_stmt_ = load_stmt_.append(create_reduce_stmt(read_layout, reg_layout_, tmp_reg_buf_, reg_buf_, tensor_t(), reduction_mask())); diff --git a/src/gpu/intel/jit/ir/tensor.hpp b/src/gpu/intel/jit/ir/tensor.hpp index 8b16e0a25d1..1c0c35d593f 100644 --- a/src/gpu/intel/jit/ir/tensor.hpp +++ b/src/gpu/intel/jit/ir/tensor.hpp @@ -271,6 +271,7 @@ class grid_info_t { private: static std::vector make_idxs(const std::string &prefix, int n) { std::vector ret; + ret.reserve(n); for (int i = 0; i < n; i++) ret.push_back( var_t::make(type_t::s32(), prefix + std::to_string(i))); @@ -497,7 +498,7 @@ class layout_t { const auto other_blocks = other.normalize().blocks(); const auto self_blocks = normalize().blocks(); if (self_blocks.size() > other_blocks.size()) return false; - if (self_blocks.size() == 0) return true; + if (self_blocks.empty()) return true; int i = 0; for (; i < (int)self_blocks.size() - 1; i++) { @@ -575,6 +576,7 @@ class layout_t { // The innermost block (first) has index 0. 
std::vector> enumerated_blocks() const { std::vector> ret; + ret.reserve(blocks_.size()); for (int i = 0; i < int(blocks_.size()); i++) { ret.emplace_back(i, blocks_[i]); } @@ -805,8 +807,8 @@ class layout_t { auto tile = split_exact(sub_grid); if (tile.is_empty()) continue; if (min_tile.is_empty() || tile.elems() < min_tile.elems()) { - min_tile = tile; - if (out_grid) { *out_grid = sub_grid; } + min_tile = std::move(tile); + if (out_grid) { *out_grid = std::move(sub_grid); } } } return min_tile; @@ -1443,7 +1445,7 @@ class view_t { return tdims_[idx]; } - void set_tdim(int tidx, const expr_t &_texpr, expr_t mask = {}) { + void set_tdim(int tidx, const expr_t &_texpr, const expr_t &mask = {}) { ir_assert(tdims_[tidx].is_empty()); auto texpr = simplify(_texpr); @@ -1457,7 +1459,7 @@ class view_t { << "Tensor dimension must have at least one view dimension " "that maps to it."; } - tdims_[tidx] = tdim; + tdims_[tidx] = std::move(tdim); } void set_vdim( @@ -1736,7 +1738,7 @@ class view_t { auto &vvar = vvars()[vidx]; int vdim = vdims()[vidx]; if (vdim == 1) continue; - auto A = tdim.expr(); + const auto &A = tdim.expr(); auto B = jit::substitute(A, vvar, vvar + 1); auto C = simplify(B - A); if (!is_const(C)) { diff --git a/src/gpu/intel/jit/ngen/ngen_asm.hpp b/src/gpu/intel/jit/ngen/ngen_asm.hpp deleted file mode 100644 index faebf6cfea8..00000000000 --- a/src/gpu/intel/jit/ngen/ngen_asm.hpp +++ /dev/null @@ -1,1787 +0,0 @@ -/******************************************************************************* -* Copyright 2019-2024 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef NGEN_ASM_HPP -#define NGEN_ASM_HPP - -#include "ngen_config.hpp" - -#include -#include -#include -#include - -#define NGEN_ASM -#include "ngen.hpp" - - -namespace NGEN_NAMESPACE { - - -inline void RegData::outputText(std::ostream &str, PrintDetail detail, LabelManager &man) const -{ -#ifdef NGEN_SAFE - if (isInvalid()) throw invalid_object_exception(); -#endif - auto vs = getVS(); - if (detail == PrintDetail::vs_hs) - if (vs > 8 && (getHS() != 0)) - vs = 8; - - if (getNeg()) str << '-'; - if (getAbs()) str << "(abs)"; - - if (isARF()) { - str << getARFType(); - switch (getARFType()) { - case ARFType::null: - case ARFType::sp: - case ARFType::ip: - break; - default: - str << getARFBase(); - } - } else if (isIndirect()) { - str << "r["; - getIndirectReg().outputText(str, PrintDetail::sub_no_type, man); - if (getOffset()) - str << ',' << getOffset(); - str << ']'; - } else - str << 'r' << base; - - if (detail <= PrintDetail::base) return; - - if (!isIndirect() && !isNull()) - str << '.' 
<< getLogicalOffset(); - - if (detail <= PrintDetail::sub_no_type) return; - - if (detail >= PrintDetail::hs && !isNull()) { - str << '<'; - if (detail >= PrintDetail::vs_hs && !isVxIndirect()) - str << vs << ';'; - if (detail == PrintDetail::full) - str << getWidth() << ','; - str << getHS(); - str << '>'; - } - - str << ':' << getType(); -} - -static inline std::ostream& operator<<(std::ostream &str, const RegData &r) -{ - LabelManager man; - r.outputText(str, PrintDetail::full, man); - return str; -} - -inline void Immediate::outputText(std::ostream &str, PrintDetail detail, LabelManager &man) const -{ - uint64_t nbytes = getBytes(getType()); - uint64_t val; - - if (nbytes == 8) - val = payload; - else - val = payload & ((uint64_t(1) << (nbytes * 8)) - 1); - - str << "0x" << std::hex << val << std::dec; - if (!hiddenType && detail >= PrintDetail::sub) - str << ':' << type; -} - -inline void ExtendedReg::outputText(std::ostream &str, PrintDetail detail, LabelManager &man) const -{ -#ifdef NGEN_SAFE - if (isInvalid()) throw invalid_object_exception(); -#endif - - if (base.getNeg()) str << '-'; - if (base.getAbs()) str << "(abs)"; - - str << 'r' << base.getBase() << '.'; - if (mmeNum == 8) - str << "nomme"; - else - str << "mme" << int(mmeNum); - - if (detail >= PrintDetail::sub) - str << ':' << base.getType(); -} - -inline void Align16Operand::outputText(std::ostream &str, PrintDetail detail, LabelManager &man) const -{ -#ifdef NGEN_SAFE - if (isInvalid()) throw invalid_object_exception(); - throw iga_align16_exception(); -#else - str << ""; -#endif -} - -inline void GRFRange::outputText(std::ostream &str, PrintDetail detail, LabelManager &man) const -{ - str << 'r' << int(base) << ':' << int(len); -} - -inline void Label::outputText(std::ostream &str, PrintDetail detail, LabelManager &man) { - str << 'L' << getID(man); -} - -struct NoOperand { - static const bool emptyOp = true; - void fixup(HW hw, int esize, int ewidth, DataType defaultType, int srcN, int arity) const {} - constexpr DataType getType() const { return DataType::invalid; } - constexpr bool isScalar() const { return false; } - - void outputText(std::ostream &str, PrintDetail detail, LabelManager &man) const {} -}; - -struct AsmOperand { - union { - RegData reg; - ExtendedReg ereg; - Immediate imm; - Label label; - GRFRange range; - }; - enum class Type : uint8_t { - none = 0, - reg = 1, - ereg = 2, - imm = 3, - label = 4, - range = 5 - } type; - - AsmOperand() : type{Type::none} {} - AsmOperand(NoOperand) : AsmOperand() {} - AsmOperand(RegData reg_) : reg{reg_}, type{Type::reg} {} - AsmOperand(ExtendedReg ereg_) : ereg{ereg_}, type{Type::ereg} {} - AsmOperand(Immediate imm_) : imm{imm_}, type{Type::imm} {} - AsmOperand(Label label_) : label{label_}, type{Type::label} {} - AsmOperand(GRFRange range_) : range{range_}, type{Type::range} {} - AsmOperand(uint32_t imm_) : imm{imm_}, type{Type::imm} {} - - void outputText(std::ostream &str, PrintDetail detail, LabelManager &man) const { - switch (type) { - case Type::none: break; - case Type::ereg: ereg.outputText(str, detail, man); break; - case Type::reg: reg.outputText(str, detail, man); break; - case Type::imm: imm.outputText(str, detail, man); break; - case Type::label: { - auto clone = label; - clone.outputText(str, detail, man); - break; - } - case Type::range: range.outputText(str, detail, man); break; - } - } -}; - -struct AsmInstruction { - Opcode op; - uint16_t ext; - uint32_t inum; - InstructionModifier mod; - AsmOperand dst, src[5]; - LabelManager *labelManager; - 
std::string comment; - - AsmInstruction(Opcode op_, uint16_t ext_, uint32_t inum_, InstructionModifier mod_, LabelManager *man, - AsmOperand dst_ = NoOperand(), AsmOperand src0 = NoOperand(), AsmOperand src1 = NoOperand(), - AsmOperand src2 = NoOperand(), AsmOperand src3 = NoOperand(), AsmOperand src4 = NoOperand()) - : op(op_), ext(ext_), inum(inum_), mod(mod_), dst(dst_), src{src0, src1, src2, src3, src4}, labelManager{man}, comment{} {} - - explicit AsmInstruction(uint32_t inum_, const std::string &comment_) - : op(Opcode::illegal), ext(0), inum(inum_), mod{}, dst{}, src{}, labelManager{nullptr}, comment{comment_} {} - inline AsmInstruction(const autoswsb::SyncInsertion &si); - - bool isLabel() const { return (op == Opcode::illegal) && (dst.type == AsmOperand::Type::label); } - bool isComment() const { return (op == Opcode::illegal) && !comment.empty(); } - - // Auto-SWSB interface. - bool autoSWSB() const { return mod.isAutoSWSB(); } - SWSBInfo swsb() const { return mod.getSWSB(); } - void setSWSB(SWSBInfo swsb) { mod.setSWSB(swsb); } - void clearAutoSWSB() { mod.setAutoSWSB(false); } - Opcode opcode() const { return op; } - SyncFunction syncFC() const { return static_cast(ext & 0xF); } - SharedFunction sfid() const { return static_cast(ext & 0xF); } - bool eot() const { return mod.isEOT(); } - bool predicated() const { return !mod.isWrEn() || (mod.getPredCtrl() != PredCtrl::None); } - bool atomic() const { return mod.isAtomic(); } - - inline unsigned dstTypecode() const { return getTypecode(dst); } - inline unsigned src0Typecode() const { return getTypecode(src[0]); } - inline unsigned src1Typecode() const { return getTypecode(src[1]); } - inline autoswsb::DestinationMask destinations(int &jip, int &uip) const; - inline bool getOperandRegion(autoswsb::DependencyRegion ®ion, int opNum) const; - - void shiftJIP(int32_t shift) const {} - void shiftUIP(int32_t shift) const {} - - bool getImm32(uint32_t &imm, int opNum = 0) const { - if (src[opNum].type == AsmOperand::Type::imm) { - imm = uint32_t(static_cast(src[opNum].imm)); - return true; - } else - return false; - } - bool getARFType(ARFType &arfType, int opNum, HW hw) const { - auto &opd = (opNum < 0) ? dst : src[opNum]; - if (opd.type == AsmOperand::Type::reg && opd.reg.isARF()) { - arfType = opd.reg.getARFType(); - return true; - } else - return false; - } - bool getSendDesc(MessageDescriptor &desc) const { return getImm32(desc.all, 3); } - int getFencedepJIP() const { - if (src[0].type == AsmOperand::Type::label) { - auto label = src[0].label; - return labelManager->getTarget(label.getID(*labelManager)) - inum + 1; - } else - return 0; - } - -protected: - static inline unsigned getTypecode(const AsmOperand &op); -}; - -AsmInstruction::AsmInstruction(const autoswsb::SyncInsertion &si) -{ - op = Opcode::sync; - ext = static_cast(si.fc); - mod = InstructionModifier::createMaskCtrl(true); - mod.setSWSB(si.swsb); - dst = NoOperand(); - for (auto n = 0; n < 4; n++) - src[n] = NoOperand(); - if (si.mask) - src[0] = Immediate::ud(si.mask); - else - src[0] = NullRegister(); -} - -unsigned AsmInstruction::getTypecode(const AsmOperand &op) -{ - DataType dt = DataType::invalid; - - switch (op.type) { - case AsmOperand::Type::reg: dt = op.reg.getType(); break; - case AsmOperand::Type::ereg: dt = op.ereg.getType(); break; - default: break; - } - - return getTypecode12(dt); -} - -autoswsb::DestinationMask AsmInstruction::destinations(int &jip, int &uip) const -{ - using namespace autoswsb; - - if (!isBranch(op)) - return eot() ? 
DestNone : DestNextIP; - - if (src[0].type == AsmOperand::Type::reg) - return DestUnknown; - - DestinationMask mask = DestNextIP; - if (src[0].type == AsmOperand::Type::label) { - auto label = src[0].label; - mask |= DestJIP; - jip = labelManager->getTarget(label.getID(*labelManager)) - inum; - } - - if (src[1].type == AsmOperand::Type::label) { - auto label = src[1].label; - mask |= DestUIP; - uip = labelManager->getTarget(label.getID(*labelManager)) - inum; - } - - if (op == Opcode::jmpi && mod.getPredCtrl() == PredCtrl::None) - mask &= ~DestNextIP; - - return mask; -} - -bool AsmInstruction::getOperandRegion(autoswsb::DependencyRegion ®ion, int opNum) const -{ - using namespace autoswsb; - const AsmOperand &operand = (opNum < 0) ? dst : src[opNum]; - RegData rd; - auto hw = region.hw; - - switch (operand.type) { - case AsmOperand::Type::reg: rd = operand.reg; break; - case AsmOperand::Type::ereg: rd = operand.ereg.getBase(); break; - case AsmOperand::Type::range: region = DependencyRegion(hw, operand.range); return true; - default: return false; - } - - if (rd.isARF() && !autoswsb::trackableARF(rd.getARFType())) - return false; - - if (rd.isIndirect()) - region = DependencyRegion(); - else if (op == Opcode::send || op == Opcode::sendc) { - int len = 0; - if (opNum <= 0) { - if (src[3].type == AsmOperand::Type::imm) { - MessageDescriptor desc; - desc.all = static_cast(src[3].imm); - len = (opNum < 0) ? desc.parts.responseLen : desc.parts.messageLen; - if (len == 31) len++; // 32 GRF responses are encoded as 31. Conservatively use the higher value. - } else - len = -1; - } else if (opNum == 1) { - bool exdescImm = (src[2].type == AsmOperand::Type::imm); - if (exdescImm && (hw >= HW::XeHPG)) - len = ext >> 8; - else if (exdescImm) { - ExtendedMessageDescriptor exdesc; - exdesc.all = static_cast(src[2].imm); - len = exdesc.parts.extMessageLen; - } else - len = -1; - } - if (len == 0) - return false; - else if (len == -1) - region = DependencyRegion(); - else - region = DependencyRegion(hw, GRFRange(rd.getBase(), len)); - } else if (op == Opcode::dpas || op == Opcode::dpasw) { - unsigned sdepth = ext >> 8; - unsigned rcount = ext & 0xFF; - unsigned len; - - switch (opNum) { - case -1: - case 0: len = GRF::bytesToGRFs(hw, rcount * operand.reg.getBytes() * mod.getExecSize()); break; - case 1: len = sdepth; break; - case 2: - if (op == Opcode::dpasw) rcount = (rcount + 1) >> 1; - len = GRF::bytesToGRFs(hw, operand.reg.getByteOffset() + sdepth * rcount * 4); - break; - default: return false; - } - - region = DependencyRegion(hw, GRFRange(operand.reg.getBase(), len)); - } else - region = DependencyRegion(hw, mod.getExecSize(), rd); - - return true; -} - -#if defined(NGEN_GLOBAL_REGS) && !defined(NGEN_GLOBAL_REGS_DEFINED) -#include "ngen_registers.hpp" -#endif - -class AsmCodeGenerator { -private: -#include "ngen_compiler_fix.hpp" -public: - explicit AsmCodeGenerator(Product product_) : hardware(getCore(product_.family)), product(product_), defaultOutput{nullptr}, - sync{this}, load{this}, store{this}, atomic{this} - { - isGen12 = (hardware >= HW::Gen12LP); - _workaround_(); - streamStack.push_back(new InstructionStream()); - } - - explicit AsmCodeGenerator(HW hardware_, int stepping_ = 0) : AsmCodeGenerator({genericProductFamily(hardware_), 0}) {} - - AsmCodeGenerator(HW hardware_, std::ostream &defaultOutput_, int stepping_ = 0) : AsmCodeGenerator(hardware_, stepping_) { - defaultOutput = &defaultOutput_; - } - ~AsmCodeGenerator() noexcept(false) { - if (defaultOutput != nullptr) - 
getCode(*defaultOutput); - for (auto &s : streamStack) - delete s; - } - inline void getCode(std::ostream &out); - void enableLineNumbers(bool enable = true) { lineNumbers = enable; } - - Product getProduct() const { return product; } - ProductFamily getProductFamily() const { return product.family; } - int getStepping() const { return product.stepping; } - - void setProduct(Product product_) { product = product_; } - void setProductFamily(ProductFamily family_) { product.family = family_; } - void setStepping(int stepping_) { product.stepping = stepping_; } - -protected: - struct InstructionStream { - std::vector buffer; - std::vector labels; - - template - AsmInstruction &append(Opcode op, uint16_t ext, Remaining&&... args) { - buffer.emplace_back(op, ext, 0, std::forward(args)...); - return buffer.back(); - } - - void appendComment(const std::string &str) { buffer.emplace_back(0, str); } - - void mark(Label &label, LabelManager &man) { - uint32_t id = label.getID(man); - - man.setTarget(id, buffer.size()); - labels.push_back(id); - buffer.emplace_back(Opcode::illegal, 0, 0, InstructionModifier(), &man, label); - } - - void append(InstructionStream &other, LabelManager &man) { - for (uint32_t id : other.labels) - man.offsetTarget(id, buffer.size()); - - buffer.insert(buffer.end(), other.buffer.begin(), other.buffer.end()); - labels.insert(labels.end(), other.labels.begin(), other.labels.end()); - } - }; - - HW hardware; - Product product; - bool isGen12; - int declaredGRFs = 128; - std::ostream *defaultOutput; - bool lineNumbers = false; - - Label _labelLocalIDsLoaded; - Label _labelArgsLoaded; - Label _lastFenceLabel; - RegData _lastFenceDst; - -private: - InstructionModifier defaultModifier; - LabelManager labelManager; - std::vector streamStack; - - inline void unsupported(); - - // Output functions. 
- template - inline void opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, S1 src1, S2 src2, uint16_t ext); - - template void opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, S1 src1, S2 src2) { - opX(op, defaultType, mod, dst, src0, src1, src2, 0); - } - template void opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, S1 src1) { - opX(op, defaultType, mod, dst, src0, src1, NoOperand()); - } - template void opX(Opcode op, const InstructionModifier &mod, D dst, S0 src0, S1 src1) { - opX(op, DataType::invalid, mod, dst, src0, src1); - } - template void opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0) { - opX(op, defaultType, mod, dst, src0, NoOperand()); - } - template void opX(Opcode op, const InstructionModifier &mod, D dst, S0 src0) { - opX(op, DataType::invalid, mod, dst, src0); - } - template void opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst) { - opX(op, defaultType, mod, dst, NoOperand()); - } - template void opX(Opcode op, const InstructionModifier &mod, D dst) { - opX(op, DataType::invalid, mod, dst); - } - void opX(Opcode op) { - opX(op, InstructionModifier(), NoOperand()); - } - void opX(Opcode op, const InstructionModifier &mod, Label &jip) { - (void) jip.getID(labelManager); - opX(op, DataType::invalid, mod, NoOperand(), jip); - } - void opX(Opcode op, const InstructionModifier &mod, Label &jip, Label &uip) { - (void) jip.getID(labelManager); - (void) uip.getID(labelManager); - opX(op, DataType::invalid, mod, NoOperand(), jip, uip, NoOperand()); - } - - template - void opSend(Opcode op, const InstructionModifier &mod, SharedFunction sf, RegData dst, RegData src0, S1 src1, ED exdesc, D desc) { - if (src1.emptyOp && (isGen12 || op == Opcode::sends || op == Opcode::sendsc)) { - opSend(op, mod, sf, dst, src0, null, exdesc, desc); - return; - } - - auto &i = streamStack.back()->append(op, static_cast(sf), mod | defaultModifier, &labelManager, dst, src0, src1, exdesc, desc); - if (i.src[2].type == AsmOperand::Type::imm && i.src[1].type != AsmOperand::Type::none) { - uint32_t exdesc = static_cast(i.src[2].imm); - if (isGen12) { - if (hardware >= HW::XeHPG) { - i.ext |= 0x80 | (((exdesc >> 6) & 0x1F) << 8); - i.src[2].imm = uint32_t(exdesc & ~0x7EF); - } else - i.src[2].imm = uint32_t(exdesc & ~0x2F); - } else - i.src[2].imm = uint32_t(exdesc | static_cast(sf)); - } - } - void opDpas(Opcode op, DataType defaultType, const InstructionModifier &mod, int sdepth, int rcount, RegData dst, RegData src0, RegData src1, RegData src2) { - dst.fixup(hardware, 1, 0, defaultType, -1, 3); - src0.fixup(hardware, 1, 0, defaultType, 0, 3); - src1.fixup(hardware, 1, 0, defaultType, 1, 3); - src2.fixup(hardware, 1, 0, defaultType, 2, 3); - (void) streamStack.back()->append(op, (sdepth << 8) | rcount, mod | defaultModifier, &labelManager, dst, src0, src1, src2); - } - template void opCall(Opcode op, const InstructionModifier &mod, D dst, S0 src0) { - (void) streamStack.back()->append(op, 0, mod | defaultModifier | NoMask, &labelManager, dst, src0); - } - template void opJmpi(Opcode op, const InstructionModifier &mod, S1 src1) { - (void) streamStack.back()->append(op, 0, mod | defaultModifier | NoMask, &labelManager, NoOperand(), src1); - } - template void opSync(Opcode op, SyncFunction fc, const InstructionModifier &mod, S0 src0) { - (void) streamStack.back()->append(op, static_cast(fc), mod | defaultModifier, &labelManager, NoOperand(), 
src0); - } - - inline void finalize(); - - enum class ModPlacementType {Pre, Mid, Post}; - inline void outX(std::ostream &out, const AsmInstruction &i, int lineNo); - inline void outExt(std::ostream &out, const AsmInstruction &i); - inline void outMods(std::ostream &out, const InstructionModifier &mod, Opcode op, ModPlacementType location); - inline void outSync(std::ostream &out, const autoswsb::SyncInsertion &si); - -protected: - // Configuration. - void setDefaultNoMask(bool def = true) { defaultModifier.setWrEn(def); } - void setDefaultAutoSWSB(bool def = true) { defaultModifier.setAutoSWSB(def); } - bool getDefaultNoMask() const { return defaultModifier.isWrEn(); } - bool getDefaultAutoSWSB() const { return defaultModifier.isAutoSWSB(); } - - // Stream handling. - void pushStream() { pushStream(new InstructionStream()); } - void pushStream(InstructionStream &s) { pushStream(&s); } - void pushStream(InstructionStream *s) { streamStack.push_back(s); } - - inline InstructionStream *popStream(); - - void appendStream(InstructionStream *s) { appendStream(*s); } - void appendStream(InstructionStream &s) { streamStack.back()->append(s, labelManager); } - void appendCurrentStream() { InstructionStream *s = popStream(); appendStream(s); delete s; } - - void discardStream() { delete popStream(); } - - void comment(const std::string &str) { streamStack.back()->appendComment(str); } - - void requireGRF(int grfs) { declaredGRFs = grfs; } - - // Instructions. - template - void add(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(Opcode::add, getDataType
(), mod, dst, src0, src1); - } - template - void add(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(Opcode::add, getDataType
(), mod, dst, src0, src1); - } - template - void addc(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(Opcode::addc, getDataType
(), (hardware >= HW::XeHPC) ? mod : (mod | AccWrEn), dst, src0, src1); - } - template - void addc(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(Opcode::addc, getDataType
(), (hardware >= HW::XeHPC) ? mod : (mod | AccWrEn), dst, src0, src1); - } - template - void add3(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2) { - opX(Opcode::add3, getDataType
(), mod, dst, src0, src1, src2); - } - template - void add3(const InstructionModifier &mod, const RegData &dst, const Immediate &src0, const RegData &src1, const RegData &src2) { - opX(Opcode::add3, getDataType
(), mod, dst, src0, src1, src2); - } - template - void add3(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const Immediate &src2) { - opX(Opcode::add3, getDataType
(), mod, dst, src0, src1, src2); - } - template - void add3(const InstructionModifier &mod, const RegData &dst, const Immediate &src0, const RegData &src1, const Immediate &src2) { - opX(Opcode::add3, getDataType
(), mod, dst, src0, src1, src2); - } - template - void and_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::and_gen12 : Opcode::and_, getDataType
(), mod, dst, src0, src1); - } - template - void and_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::and_gen12 : Opcode::and_, getDataType
(), mod, dst, src0, src1); - } -#ifndef NGEN_NO_OP_NAMES - template - void and(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - and_
(mod, dst, src0, src1); - } - template - void and(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - and_
(mod, dst, src0, src1); - } -#endif - template - void asr(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::asr_gen12 : Opcode::asr, getDataType
(), mod, dst, src0, src1); - } - template - void asr(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::asr_gen12 : Opcode::asr, getDataType
(), mod, dst, src0, src1); - } - template - void avg(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(Opcode::avg, getDataType
(), mod, dst, src0, src1); - } - template - void avg(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(Opcode::avg, getDataType
(), mod, dst, src0, src1); - } - template - void bfe(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2) { - opX(isGen12 ? Opcode::bfe_gen12 : Opcode::bfe, getDataType
(), mod, dst, src0, src1, src2); - } - template - void bfe(const InstructionModifier &mod, const RegData &dst, const Immediate &src0, const RegData &src1, const RegData &src2) { - opX(isGen12 ? Opcode::bfe_gen12 : Opcode::bfe, getDataType
(), mod, dst, src0, src1, src2); - } - template - void bfe(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const Immediate &src2) { - opX(isGen12 ? Opcode::bfe_gen12 : Opcode::bfe, getDataType
(), mod, dst, src0, src1, src2); - } - template - void bfe(const InstructionModifier &mod, const RegData &dst, const Immediate &src0, const RegData &src1, const Immediate &src2) { - opX(isGen12 ? Opcode::bfe_gen12 : Opcode::bfe, getDataType
(), mod, dst, src0, src1, src2); - } - template - void bfi1(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::bfi1_gen12 : Opcode::bfi1, getDataType
(), mod, dst, src0, src1); - } - template - void bfi1(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::bfi1_gen12 : Opcode::bfi1, getDataType
(), mod, dst, src0, src1); - } - template - void bfi2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2) { - opX(isGen12 ? Opcode::bfi2_gen12 : Opcode::bfi2, getDataType
(), mod, dst, src0, src1, src2); - } - template - void bfi2(const InstructionModifier &mod, const RegData &dst, const Immediate &src0, const RegData &src1, const RegData &src2) { - opX(isGen12 ? Opcode::bfi2_gen12 : Opcode::bfi2, getDataType
(), mod, dst, src0, src1, src2); - } - template - void bfi2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const Immediate &src2) { - opX(isGen12 ? Opcode::bfi2_gen12 : Opcode::bfi2, getDataType
(), mod, dst, src0, src1, src2); - } - template - void bfi2(const InstructionModifier &mod, const RegData &dst, const Immediate &src0, const RegData &src1, const Immediate &src2) { - opX(isGen12 ? Opcode::bfi2_gen12 : Opcode::bfi2, getDataType
(), mod, dst, src0, src1, src2); - } - template - void bfn(const InstructionModifier &mod, uint8_t ctrl, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2) { - opX(Opcode::bfn, getDataType
(), mod, dst, src0, src1, src2, ctrl); - } - template - void bfn(const InstructionModifier &mod, uint8_t ctrl, const RegData &dst, const Immediate &src0, const RegData &src1, const RegData &src2) { - opX(Opcode::bfn, getDataType
(), mod, dst, src0, src1, src2, ctrl); - } - template - void bfn(const InstructionModifier &mod, uint8_t ctrl, const RegData &dst, const RegData &src0, const RegData &src1, const Immediate &src2) { - opX(Opcode::bfn, getDataType
(), mod, dst, src0, src1, src2, ctrl); - } - template - void bfn(const InstructionModifier &mod, uint8_t ctrl, const RegData &dst, const Immediate &src0, const RegData &src1, const Immediate &src2) { - opX(Opcode::bfn, getDataType
(), mod, dst, src0, src1, src2, ctrl); - } - template - void bfrev(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(isGen12 ? Opcode::bfrev_gen12 : Opcode::bfrev, getDataType
(), mod, dst, src0); - } - template - void bfrev(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - opX(isGen12 ? Opcode::bfrev_gen12 : Opcode::bfrev, getDataType
(), mod, dst, src0); - } - void brc(const InstructionModifier &mod, Label &jip, Label &uip) { - (void) jip.getID(labelManager); - (void) uip.getID(labelManager); - opX(Opcode::brc, mod, jip, uip); - } - void brc(const InstructionModifier &mod, const RegData &src0) { - opCall(Opcode::brc, mod, NoOperand(), src0); - } - void brd(const InstructionModifier &mod, Label &jip) { - (void) jip.getID(labelManager); - opX(Opcode::brd, mod, jip); - } - void brd(const InstructionModifier &mod, const RegData &src0) { - opCall(Opcode::brd, mod, NoOperand(), src0); - } - void break_(const InstructionModifier &mod, Label &jip, Label &uip) { - (void) jip.getID(labelManager); - (void) uip.getID(labelManager); - opX(Opcode::break_, mod, jip, uip); - } - void call(const InstructionModifier &mod, const RegData &dst, Label &jip) { - (void) jip.getID(labelManager); - opCall(Opcode::call, mod, dst, jip); - } - void call(const InstructionModifier &mod, const RegData &dst, const RegData &jip) { - opCall(Opcode::call, mod, dst, jip); - } - void calla(const InstructionModifier &mod, const RegData &dst, int32_t jip) { - opCall(Opcode::calla, mod, dst, Immediate::ud(jip)); - } - void calla(const InstructionModifier &mod, const RegData &dst, const RegData &jip) { - opCall(Opcode::calla, mod, dst, jip); - } - template - void cbit(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(Opcode::cbit, getDataType
(), mod, dst, src0); - } - template - void cbit(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - opX(Opcode::cbit, getDataType
(), mod, dst, src0); - } - template - void cmp(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::cmp_gen12 : Opcode::cmp, getDataType
(), mod, dst, src0, src1); - } - template - void cmp(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::cmp_gen12 : Opcode::cmp, getDataType
(), mod, dst, src0, src1); - } - template - void cmpn(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::cmpn_gen12 : Opcode::cmpn, getDataType
(), mod, dst, src0, src1); - } - template - void csel(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2) { - opX(isGen12 ? Opcode::csel_gen12 : Opcode::csel, getDataType
(), mod, dst, src0, src1, src2); - } - template - void csel(const InstructionModifier &mod, const RegData &dst, const Immediate &src0, const RegData &src1, const RegData &src2) { - opX(isGen12 ? Opcode::csel_gen12 : Opcode::csel, getDataType
(), mod, dst, src0, src1, src2); - } - template - void csel(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const Immediate &src2) { - opX(isGen12 ? Opcode::csel_gen12 : Opcode::csel, getDataType
(), mod, dst, src0, src1, src2); - } - template - void csel(const InstructionModifier &mod, const RegData &dst, const Immediate &src0, const RegData &src1, const Immediate &src2) { - opX(isGen12 ? Opcode::csel_gen12 : Opcode::csel, getDataType
(), mod, dst, src0, src1, src2); - } - void cont(const InstructionModifier &mod, Label &jip, Label &uip) { - (void) jip.getID(labelManager); - (void) uip.getID(labelManager); - opX(Opcode::cont, mod, jip, uip); - } - template - void dp2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(Opcode::dp2, getDataType
(), mod, dst, src0, src1); - } - template - void dp2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(Opcode::dp2, getDataType
(), mod, dst, src0, src1); - } - template - void dp3(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(Opcode::dp3, getDataType
(), mod, dst, src0, src1); - } - template - void dp3(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(Opcode::dp3, getDataType
(), mod, dst, src0, src1); - } - template - void dp4(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(Opcode::dp4, getDataType
(), mod, dst, src0, src1); - } - template - void dp4(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(Opcode::dp4, getDataType
(), mod, dst, src0, src1); - } - template - void dp4a(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2) { - opX(Opcode::dp4a, getDataType
(), mod, dst, src0, src1, src2); - } - template - void dp4a(const InstructionModifier &mod, const RegData &dst, const Immediate &src0, const RegData &src1, const RegData &src2) { - opX(Opcode::dp4a, getDataType
(), mod, dst, src0, src1, src2); - } - template - void dp4a(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const Immediate &src2) { - opX(Opcode::dp4a, getDataType
(), mod, dst, src0, src1, src2); - } - template - void dp4a(const InstructionModifier &mod, const RegData &dst, const Immediate &src0, const RegData &src1, const Immediate &src2) { - opX(Opcode::dp4a, getDataType
(), mod, dst, src0, src1, src2); - } - template - void dpas(const InstructionModifier &mod, uint8_t sdepth, uint8_t rcount, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2) { - opDpas(Opcode::dpas, getDataType
(), mod, sdepth, rcount, dst, src0, src1, src2); - } - template - void dpasw(const InstructionModifier &mod, uint8_t sdepth, uint8_t rcount, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2) { - opDpas(Opcode::dpasw, getDataType
(), mod, sdepth, rcount, dst, src0, src1, src2); - } - template - void dph(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(Opcode::dph, getDataType
(), mod, dst, src0, src1); - } - template - void dph(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(Opcode::dph, getDataType
(), mod, dst, src0, src1); - } - void else_(const InstructionModifier &mod, Label &jip, Label &uip, bool branchCtrl = false) { - (void) jip.getID(labelManager); - (void) uip.getID(labelManager); - opX(Opcode::else_, DataType::invalid, mod, NoOperand(), jip, uip, NoOperand(), branchCtrl); - } - void else_(InstructionModifier mod, Label &jip) { - else_(mod, jip, jip); - } - void endif(const InstructionModifier &mod, Label &jip) { - (void) jip.getID(labelManager); - opX(Opcode::endif, mod, NoOperand(), jip); - } - void endif(const InstructionModifier &mod) { - Label next; - endif(mod, next); - mark(next); - } - template - void fbh(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(Opcode::fbh, getDataType
(), mod, dst, src0); - } - template - void fbh(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - opX(Opcode::fbh, getDataType
(), mod, dst, src0); - } - template - void fbl(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(Opcode::fbl, getDataType
(), mod, dst, src0); - } - template - void fbl(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - opX(Opcode::fbl, getDataType
(), mod, dst, src0); - } - template - void frc(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(Opcode::frc, getDataType
(), mod, dst, src0); - } - void goto_(const InstructionModifier &mod, Label &jip, Label &uip, bool branchCtrl = false) { - (void) jip.getID(labelManager); - (void) uip.getID(labelManager); - opX(Opcode::goto_, DataType::invalid, mod, NoOperand(), jip, uip, NoOperand(), branchCtrl); - } - void goto_(const InstructionModifier &mod, Label &jip) { - goto_(mod, jip, jip); - } - void halt(const InstructionModifier &mod, Label &jip, Label &uip) { - (void) jip.getID(labelManager); - (void) uip.getID(labelManager); - opX(Opcode::halt, mod, jip, uip); - } - void halt(const InstructionModifier &mod, Label &jip) { - halt(mod, jip, jip); - } - void if_(const InstructionModifier &mod, Label &jip, Label &uip, bool branchCtrl = false) { - (void) jip.getID(labelManager); - (void) uip.getID(labelManager); - opX(Opcode::if_, DataType::invalid, mod, NoOperand(), jip, uip, NoOperand(), branchCtrl); - } - void if_(const InstructionModifier &mod, Label &jip) { - if_(mod, jip, jip); - } - void illegal() { - opX(Opcode::illegal); - } - void join(const InstructionModifier &mod, Label &jip) { - opX(Opcode::join, mod, jip); - } - void join(const InstructionModifier &mod) { - Label next; - join(mod, next); - mark(next); - } - void jmpi(const InstructionModifier &mod, Label &jip) { - (void) jip.getID(labelManager); - opJmpi(Opcode::jmpi, mod, jip); - } - void jmpi(const InstructionModifier &mod, const RegData &jip) { - opJmpi(Opcode::jmpi, mod, jip); - } - template - void line(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(Opcode::line, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void line(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(Opcode::line, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void lrp(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2) { - opX(Opcode::lrp, getDataType
<DT>(), mod, dst, src0, src1, src2); - } - template <typename DT = void> - void lzd(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(Opcode::lzd, getDataType
<DT>(), mod, dst, src0); - } - template <typename DT = void> - void lzd(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - opX(Opcode::lzd, getDataType
<DT>(), mod, dst, src0); - } - template <typename DT = void> - void mac(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(Opcode::mac, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void mac(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(Opcode::mac, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void mach(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(Opcode::mach, getDataType
<DT>(), (hardware >= HW::XeHPC) ? mod : (mod | AccWrEn), dst, src0, src1); - } - template <typename DT = void> - void mach(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(Opcode::mach, getDataType
<DT>(), (hardware >= HW::XeHPC) ? mod : (mod | AccWrEn), dst, src0, src1); - } - template <typename DT = void> - void macl(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { -#ifdef NGEN_SAFE - if (hardware < HW::Gen10) unsupported(); -#endif - opX((hardware >= HW::XeHPC) ? Opcode::macl : Opcode::mach, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void macl(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { -#ifdef NGEN_SAFE - if (hardware < HW::Gen10) unsupported(); -#endif - opX((hardware >= HW::XeHPC) ? Opcode::macl : Opcode::mach, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void mad(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &src2) { - opX(Opcode::mad, getDataType
<DT>(), mod, dst, src0, src1, src2); - } - template <typename DT = void> - void mad(const InstructionModifier &mod, const Align16Operand &dst, const Align16Operand &src0, const Align16Operand &src1, const Align16Operand &src2) { - opX(Opcode::mad, getDataType
<DT>(), mod, dst, src0, src1, src2); - } - template <typename DT = void> - void mad(const InstructionModifier &mod, const RegData &dst, const Immediate &src0, const RegData &src1, const RegData &src2) { - opX(Opcode::mad, getDataType
<DT>(), mod, dst, src0, src1, src2); - } - template <typename DT = void> - void mad(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const Immediate &src2) { - opX(Opcode::mad, getDataType
<DT>(), mod, dst, src0, src1, src2); - } - template <typename DT = void> - void mad(const InstructionModifier &mod, const RegData &dst, const Immediate &src0, const RegData &src1, const Immediate &src2) { - opX(Opcode::mad, getDataType
<DT>(), mod, dst, src0, src1, src2); - } - template <typename DT = void> - void madm(const InstructionModifier &mod, const ExtendedReg &dst, const ExtendedReg &src0, const ExtendedReg &src1, const ExtendedReg &src2) { - opX(Opcode::madm, getDataType
<DT>(), mod, dst, src0, src1, src2); - } - template <typename DT = void> - void math(const InstructionModifier &mod, MathFunction fc, const RegData &dst, const RegData &src0) { -#ifdef NGEN_SAFE - if (mathArgCount(fc) != 1) throw invalid_operand_count_exception(); -#endif - if (fc == MathFunction::rsqtm) - math
<DT>(mod, fc, dst | nomme, src0 | nomme); - else - opX(Opcode::math, getDataType
<DT>(), mod, dst, src0, NoOperand(), NoOperand(), static_cast<uint16_t>(fc)); - } - template <typename DT = void> - void math(const InstructionModifier &mod, MathFunction fc, const RegData &dst, const RegData &src0, const RegData &src1) { -#ifdef NGEN_SAFE - if (mathArgCount(fc) != 2) throw invalid_operand_count_exception(); -#endif - if (fc == MathFunction::invm) - math
<DT>(mod, fc, dst | nomme, src0 | nomme, src1 | nomme); - else - opX(Opcode::math, getDataType
<DT>(), mod, dst, src0, src1, NoOperand(), static_cast<uint16_t>(fc)); - } - template <typename DT = void> - void math(const InstructionModifier &mod, MathFunction fc, const RegData &dst, const RegData &src0, const Immediate &src1) { -#ifdef NGEN_SAFE - if (fc == MathFunction::invm || fc == MathFunction::rsqtm) throw invalid_operand_exception(); -#endif - opX(Opcode::math, getDataType
<DT>(), mod, dst, src0, src1.forceInt32(), NoOperand(), static_cast<uint16_t>(fc)); - } - template <typename DT = void> - void math(InstructionModifier mod, MathFunction fc, const ExtendedReg &dst, const ExtendedReg &src0) { -#ifdef NGEN_SAFE - if (fc != MathFunction::rsqtm) throw invalid_operand_exception(); -#endif - mod.setCMod(ConditionModifier::eo); - opX(Opcode::math, getDataType
<DT>(), mod, dst, src0, NoOperand(), NoOperand(), static_cast<uint16_t>(fc)); - } - template <typename DT = void> - void math(InstructionModifier mod, MathFunction fc, const ExtendedReg &dst, const ExtendedReg &src0, const ExtendedReg &src1) { -#ifdef NGEN_SAFE - if (fc != MathFunction::invm) throw invalid_operand_exception(); -#endif - mod.setCMod(ConditionModifier::eo); - opX(Opcode::math, getDataType
<DT>(), mod, dst, src0, src1, NoOperand(), static_cast<uint16_t>(fc)); - } - template <typename DT = void> - void mov(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(isGen12 ? Opcode::mov_gen12 : Opcode::mov, getDataType
<DT>(), mod, dst, src0); - } - template <typename DT = void> - void mov(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - opX(isGen12 ? Opcode::mov_gen12 : Opcode::mov, getDataType
<DT>(), mod, dst, src0); - } - template <typename DT = void> - void movi(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - if (hardware >= HW::Gen10) - movi
<DT>(mod, dst, src0, null); - else - opX(isGen12 ? Opcode::movi_gen12 : Opcode::movi, getDataType
<DT>(), mod, dst, src0); - } - template <typename DT = void> - void movi(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { -#ifdef NGEN_SAFE - if (hardware < HW::Gen10) throw unsupported_instruction(); -#endif - opX(isGen12 ? Opcode::movi_gen12 : Opcode::movi, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void movi(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { -#ifdef NGEN_SAFE - if (hardware < HW::Gen10) throw unsupported_instruction(); -#endif - opX(isGen12 ? Opcode::movi_gen12 : Opcode::movi, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void mul(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(Opcode::mul, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void mul(const InstructionModifier &mod, const RegData &dst, const RegData &src0, Immediate src1) { - if (dst.getBytes() == 8) - src1 = src1.forceInt32(); - opX(Opcode::mul, getDataType
<DT>(), mod, dst, src0, src1); - } - void nop() { - opX(isGen12 ? Opcode::nop_gen12 : Opcode::nop); - } - template <typename DT = void> - void not_(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(isGen12 ? Opcode::not_gen12 : Opcode::not_, getDataType
<DT>(), mod, dst, src0); - } - template <typename DT = void> - void not_(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - opX(isGen12 ? Opcode::not_gen12 : Opcode::not_, getDataType
<DT>(), mod, dst, src0); - } -#ifndef NGEN_NO_OP_NAMES - template <typename DT = void> - void not(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - not_
<DT>(mod, dst, src0); - } - template <typename DT = void> - void not(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - not_
<DT>(mod, dst, src0); - } -#endif - template <typename DT = void> - void or_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::or_gen12 : Opcode::or_, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void or_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::or_gen12 : Opcode::or_, getDataType
<DT>(), mod, dst, src0, src1); - } -#ifndef NGEN_NO_OP_NAMES - template <typename DT = void> - void or(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - or_
<DT>(mod, dst, src0, src1); - } - template <typename DT = void> - void or(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - or_
<DT>(mod, dst, src0, src1); - } -#endif - template <typename DT = void> - void pln(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(Opcode::pln, getDataType
<DT>(), mod, dst, src0, src1); - } - void ret(const InstructionModifier &mod, const RegData &src0) { - opJmpi(Opcode::ret, mod, src0); - } - template <typename DT = void> - void rndd(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(Opcode::rndd, getDataType
<DT>(), mod, dst, src0); - } - template <typename DT = void> - void rndd(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - opX(Opcode::rndd, getDataType
<DT>(), mod, dst, src0); - } - template <typename DT = void> - void rnde(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(Opcode::rnde, getDataType
<DT>(), mod, dst, src0); - } - template <typename DT = void> - void rnde(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - opX(Opcode::rnde, getDataType
<DT>(), mod, dst, src0); - } - template <typename DT = void> - void rndu(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(Opcode::rndu, getDataType
<DT>(), mod, dst, src0); - } - template <typename DT = void> - void rndu(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - opX(Opcode::rndu, getDataType
<DT>(), mod, dst, src0); - } - template <typename DT = void> - void rndz(const InstructionModifier &mod, const RegData &dst, const RegData &src0) { - opX(Opcode::rndz, getDataType
<DT>(), mod, dst, src0); - } - template <typename DT = void> - void rndz(const InstructionModifier &mod, const RegData &dst, const Immediate &src0) { - opX(Opcode::rndz, getDataType
<DT>(), mod, dst, src0); - } - template <typename DT = void> - void rol(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::rol_gen12 : Opcode::rol, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void rol(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::rol_gen12 : Opcode::rol, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void ror(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::ror_gen12 : Opcode::ror, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void ror(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::ror_gen12 : Opcode::ror, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void sad2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(Opcode::sad2, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void sad2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(Opcode::sad2, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void sada2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(Opcode::sada2, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void sada2(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(Opcode::sada2, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void sel(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::sel_gen12 : Opcode::sel, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void sel(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::sel_gen12 : Opcode::sel, getDataType<DT>
(), mod, dst, src0, src1); - } - - /* Gen12-style sends */ - void send(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, uint32_t desc) { - opSend(isGen12 ? Opcode::send : Opcode::sends, mod, sf, dst, src0, src1, Immediate::ud(exdesc), Immediate::ud(desc)); - } - void send(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, uint32_t desc) { - opSend(isGen12 ? Opcode::send : Opcode::sends, mod, sf, dst, src0, src1, exdesc, Immediate::ud(desc)); - } - void send(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const GRFRange &src1, const RegData &exdesc, uint32_t desc) { - opSend(isGen12 ? Opcode::send : Opcode::sends, mod, sf, dst, src0, src1, exdesc, Immediate::ud(desc)); - } - void send(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, uint32_t exdesc, uint32_t desc) { - opSend(isGen12 ? Opcode::send : Opcode::sends, mod, sf, dst, src0, NoOperand(), Immediate::ud(exdesc), Immediate::ud(desc)); - } - void send(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &exdesc, uint32_t desc) { - opSend(isGen12 ? Opcode::send : Opcode::sends, mod, sf, dst, src0, NoOperand(), exdesc, Immediate::ud(desc)); - } - void send(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, const RegData &desc) { - opSend(isGen12 ? Opcode::send : Opcode::sends, mod, sf, dst, src0, src1, Immediate::ud(exdesc), desc); - } - void send(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, const RegData &desc) { - opSend(isGen12 ? Opcode::send : Opcode::sends, mod, sf, dst, src0, src1, exdesc, desc); - } - void send(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const GRFRange &src1, const RegData &exdesc, const RegData &desc) { - opSend(isGen12 ? Opcode::send : Opcode::sends, mod, sf, dst, src0, src1, exdesc, desc); - } - void send(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, uint32_t exdesc, const RegData &desc) { - opSend(isGen12 ? Opcode::send : Opcode::sends, mod, sf, dst, src0, NoOperand(), Immediate::ud(exdesc), desc); - } - void send(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &exdesc, const RegData &desc) { - opSend(isGen12 ? Opcode::send : Opcode::sends, mod, sf, dst, src0, NoOperand(), exdesc, desc); - } - void sendc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, uint32_t desc) { - opSend(isGen12 ? Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, src1, Immediate::ud(exdesc), Immediate::ud(desc)); - } - void sendc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, uint32_t desc) { - opSend(isGen12 ? Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, src1, exdesc, Immediate::ud(desc)); - } - void sendc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const GRFRange &src1, const RegData &exdesc, uint32_t desc) { - opSend(isGen12 ? 
Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, src1, exdesc, Immediate::ud(desc)); - } - void sendc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, uint32_t exdesc, uint32_t desc) { - opSend(isGen12 ? Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, NoOperand(), Immediate::ud(exdesc), Immediate::ud(desc)); - } - void sendc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &exdesc, uint32_t desc) { - opSend(isGen12 ? Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, NoOperand(), exdesc, Immediate::ud(desc)); - } - void sendc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, const RegData &desc) { - opSend(isGen12 ? Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, src1, Immediate::ud(exdesc), desc); - } - void sendc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, const RegData &desc) { - opSend(isGen12 ? Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, src1, exdesc, desc); - } - void sendc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const GRFRange &src1, const RegData &exdesc, const RegData &desc) { - opSend(isGen12 ? Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, src1, exdesc, desc); - } - void sendc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, uint32_t exdesc, const RegData &desc) { - opSend(isGen12 ? Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, NoOperand(), Immediate::ud(exdesc), desc); - } - void sendc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, const RegData &exdesc, const RegData &desc) { - opSend(isGen12 ? 
Opcode::sendc : Opcode::sendsc, mod, sf, dst, src0, NoOperand(), exdesc, desc); - } - template <typename T1, typename T2> void send(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, NoOperand src1, T1 exdesc, T2 desc) { - opSend(Opcode::send, mod, sf, dst, src0, src1, exdesc, desc); - } - template <typename T1, typename T2> void sendc(const InstructionModifier &mod, SharedFunction sf, const RegData &dst, const RegData &src0, NoOperand src1, T1 exdesc, T2 desc) { - opSend(Opcode::sendc, mod, sf, dst, src0, src1, exdesc, desc); - } - /* Pre-Gen12 style sends */ - void send(const InstructionModifier &mod, const RegData &dst, const RegData &src0, uint32_t exdesc, uint32_t desc) { - if (isGen12) - send(mod, static_cast<SharedFunction>(exdesc & 0xF), dst, src0, null, exdesc, desc); - else - send(mod, SharedFunction::null, dst, src0, NoOperand(), Immediate::ud(exdesc), Immediate::ud(desc)); - } - void send(const InstructionModifier &mod, const RegData &dst, const RegData &src0, uint32_t exdesc, const RegData &desc) { - if (isGen12) - send(mod, static_cast<SharedFunction>(exdesc & 0xF), dst, src0, null, exdesc, desc); - else - send(mod, SharedFunction::null, dst, src0, NoOperand(), Immediate::ud(exdesc), desc); - } - void sendc(const InstructionModifier &mod, const RegData &dst, const RegData &src0, uint32_t exdesc, uint32_t desc) { - if (isGen12) - sendc(mod, static_cast<SharedFunction>(exdesc & 0xF), dst, src0, null, exdesc, desc); - else - sendc(mod, SharedFunction::null, dst, src0, NoOperand(), Immediate::ud(exdesc), Immediate::ud(desc)); - } - void sendc(const InstructionModifier &mod, const RegData &dst, const RegData &src0, uint32_t exdesc, const RegData &desc) { - if (isGen12) - sendc(mod, static_cast<SharedFunction>(exdesc & 0xF), dst, src0, null, exdesc, desc); - else - sendc(mod, SharedFunction::null, dst, src0, NoOperand(), Immediate::ud(exdesc), desc); - } - void sends(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, uint32_t desc) { - send(mod, static_cast<SharedFunction>(exdesc & 0xF), dst, src0, src1, exdesc, desc); - } - void sends(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, const RegData &desc) { - send(mod, static_cast<SharedFunction>(exdesc & 0xF), dst, src0, src1, exdesc, desc); - } - void sends(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, uint32_t desc) { -#ifdef NGEN_SAFE - if (isGen12) throw sfid_needed_exception(); -#endif - send(mod, static_cast<SharedFunction>(0), dst, src0, src1, exdesc, desc); - } - void sends(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, const RegData &desc) { -#ifdef NGEN_SAFE - if (isGen12) throw sfid_needed_exception(); -#endif - send(mod, static_cast<SharedFunction>(0), dst, src0, src1, exdesc, desc); - } - void sendsc(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, uint32_t desc) { - sendc(mod, static_cast<SharedFunction>(exdesc & 0xF), dst, src0, src1, exdesc, desc); - } - void sendsc(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, uint32_t exdesc, const RegData &desc) { - sendc(mod, static_cast<SharedFunction>(exdesc & 0xF), dst, src0, src1, exdesc, desc); - } - void sendsc(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, uint32_t desc) { -#ifdef NGEN_SAFE - if (isGen12) throw sfid_needed_exception(); -#endif - sendc(mod, static_cast<SharedFunction>(0), dst, src0, src1, exdesc, desc); - }
- void sendsc(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1, const RegData &exdesc, const RegData &desc) { -#ifdef NGEN_SAFE - if (isGen12) throw sfid_needed_exception(); -#endif - sendc(mod, static_cast<SharedFunction>(0), dst, src0, src1, exdesc, desc); - } - - template <typename DT = void> - void shl(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::shl_gen12 : Opcode::shl, getDataType
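// Illustrative sketch, not from the original file or patch: the pre-Gen12
// wrappers above recover the shared function ID (SFID) from the low nibble of
// the legacy extended descriptor, since Gen12-style sends carry the SFID as a
// separate instruction field. Descriptor values below are made-up
// placeholders; dc0 = 0xA follows the SharedFunction table in ngen_core.hpp.
//
//     uint32_t exdesc = 0x0000080A;  // low nibble 0xA would select SFID dc0
//     uint32_t desc   = 0x02000001;  // placeholder message descriptor
//     auto sfid = static_cast<SharedFunction>(exdesc & 0xF);
//     // Gen12 path: the SFID becomes an explicit operand:
//     //     send(mod, sfid, dst, src0, null, exdesc, desc);
//     // Pre-Gen12 path: exdesc is emitted verbatim, SFID stays inside it:
//     //     send(mod, SharedFunction::null, dst, src0, NoOperand(),
//     //          Immediate::ud(exdesc), Immediate::ud(desc));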
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void shl(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::shl_gen12 : Opcode::shl, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void shr(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::shr_gen12 : Opcode::shr, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void shr(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::shr_gen12 : Opcode::shr, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void smov(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::smov_gen12 : Opcode::smov, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void srnd(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(Opcode::srnd, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void srnd(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(Opcode::srnd, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void subb(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(Opcode::subb, getDataType
<DT>(), (hardware >= HW::XeHPC) ? mod : (mod | AccWrEn), dst, src0, src1); - } - template <typename DT = void> - void subb(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(Opcode::subb, getDataType
<DT>(), (hardware >= HW::XeHPC) ? mod : (mod | AccWrEn), dst, src0, src1); - } - void wait(const InstructionModifier &mod, const RegData &nreg) { - opX(Opcode::wait, mod, NoOperand(), nreg); - } - void while_(const InstructionModifier &mod, Label &jip) { - (void) jip.getID(labelManager); - opX(Opcode::while_, mod, jip); - } - template <typename DT = void> - void xor_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - opX(isGen12 ? Opcode::xor_gen12 : Opcode::xor_, getDataType
<DT>(), mod, dst, src0, src1); - } - template <typename DT = void> - void xor_(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - opX(isGen12 ? Opcode::xor_gen12 : Opcode::xor_, getDataType
<DT>(), mod, dst, src0, src1); - } -#ifndef NGEN_NO_OP_NAMES - template <typename DT = void> - void xor(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const RegData &src1) { - xor_
<DT>(mod, dst, src0, src1); - } - template <typename DT = void> - void xor(const InstructionModifier &mod, const RegData &dst, const RegData &src0, const Immediate &src1) { - xor_<DT>
(mod, dst, src0, src1); - } -#endif - -private: - struct Sync { - AsmCodeGenerator &parent; - - Sync(AsmCodeGenerator *parent_) : parent(*parent_) {} - - void operator()(SyncFunction fc, const InstructionModifier &mod = InstructionModifier()) { - parent.opSync(Opcode::sync, fc, mod, null); - } - void operator()(SyncFunction fc, const RegData &src0) { - this->operator()(fc, InstructionModifier(), src0); - } - void operator()(SyncFunction fc, const InstructionModifier &mod, const RegData &src0) { - parent.opSync(Opcode::sync, fc, mod, src0); - } - void operator()(SyncFunction fc, int src0) { - this->operator()(fc, InstructionModifier(), src0); - } - void operator()(SyncFunction fc, const InstructionModifier &mod, int src0) { - parent.opSync(Opcode::sync, fc, mod, Immediate::ud(src0)); - } - void allrd() { - allrd(null); - } - void allrd(const InstructionModifier &mod) { - allrd(mod, null); - } - void allrd(const RegData &src0) { - allrd(InstructionModifier(), src0); - } - void allrd(const InstructionModifier &mod, const RegData &src0) { - this->operator()(SyncFunction::allrd, mod, src0); - } - void allrd(uint32_t src0) { - allrd(InstructionModifier(), src0); - } - void allrd(const InstructionModifier &mod, uint32_t src0) { - this->operator()(SyncFunction::allrd, mod, src0); - } - void allwr() { - allwr(null); - } - void allwr(const InstructionModifier &mod) { - allwr(mod, null); - } - void allwr(const RegData &src0) { - allwr(InstructionModifier(), src0); - } - void allwr(const InstructionModifier &mod, const RegData &src0) { - this->operator()(SyncFunction::allwr, mod, src0); - } - void allwr(uint32_t src0) { - allwr(InstructionModifier(), src0); - } - void allwr(const InstructionModifier &mod, uint32_t src0) { - this->operator()(SyncFunction::allwr, mod, src0); - } - void bar(const InstructionModifier &mod = InstructionModifier()) { - this->operator()(SyncFunction::bar, mod); - } - void bar(const InstructionModifier &mod, uint32_t src0) { - this->operator()(SyncFunction::bar, mod, src0); - } - void bar(const InstructionModifier &mod, const RegData &src0) { - this->operator()(SyncFunction::bar, mod, src0); - } - void bar(uint32_t src0) { - this->operator()(SyncFunction::bar, InstructionModifier(), src0); - } - void bar(const RegData &src0) { - this->operator()(SyncFunction::bar, InstructionModifier(), src0); - } - void flush() { - flush(InstructionModifier()); - } - void flush(const InstructionModifier &mod) { - this->operator()(SyncFunction::flush, InstructionModifier(), null); - } - void host(const InstructionModifier &mod = InstructionModifier()) { - this->operator()(SyncFunction::host, mod); - } - void nop(const InstructionModifier &mod = InstructionModifier()) { - this->operator()(SyncFunction::nop, mod); - } - }; -public: - Sync sync; - - void ignoredep(Operand op) { - if (hardware >= HW::Gen12LP) - opX(Opcode::directive, DataType::ud, InstructionModifier(), GRF(static_cast(op)), NoOperand()); - } - void subdep(Operand op, const GRFRange &r) { - ignoredep(op); - wrdep(r); - } - void subdep(Operand op, const GRF &r) { - ignoredep(op); - wrdep(r); - } - void wrdep(const GRFRange &r) { -#ifdef NGEN_SAFE - if (hardware < HW::Gen12LP) throw unsupported_instruction(); -#endif - int len = r.getLen(); - for (int o = 0; o < len; o += 32) { - int thisLen = std::min(len - o, 32); - opX(Opcode::directive, DataType::ud, InstructionModifier::createAutoSWSB(), GRF(static_cast(Directive::wrdep)), r[o] - r[o + thisLen - 1]); - } - } - void wrdep(const GRF &r) { - wrdep(r-r); - } - void fencedep(Label 
&fenceLocation) { - opX(Opcode::directive, DataType::ud, InstructionModifier::createAutoSWSB(), GRF(static_cast(Directive::fencedep)), fenceLocation); - } - - inline void mark(Label &label) { streamStack.back()->mark(label, labelManager); } - - using _self = AsmCodeGenerator; - -#include "ngen_pseudo.hpp" -#ifndef NGEN_GLOBAL_REGS -#include "ngen_registers.hpp" -#endif -}; - - -void AsmCodeGenerator::unsupported() -{ -#ifdef NGEN_SAFE - throw unsupported_instruction(); -#endif -} - -AsmCodeGenerator::InstructionStream *AsmCodeGenerator::popStream() -{ -#ifdef NGEN_SAFE - if (streamStack.size() <= 1) throw stream_stack_underflow(); -#endif - - InstructionStream *result = streamStack.back(); - streamStack.pop_back(); - return result; -} - -void AsmCodeGenerator::finalize() -{ -#ifdef NGEN_SAFE - if (streamStack.size() > 1) throw unfinished_stream_exception(); -#endif - auto &buffer = streamStack.back()->buffer; - int inum = 0; - for (auto &i : buffer) - i.inum = inum++; -} - -void AsmCodeGenerator::getCode(std::ostream &out) -{ - finalize(); - - autoswsb::BasicBlockList analysis = autoswsb::autoSWSB(hardware, declaredGRFs, streamStack.back()->buffer); - std::multimap syncs; // Syncs inserted by auto-SWSB. - - for (auto &bb : analysis) - for (auto &sync : bb.syncs) - syncs.insert(std::make_pair(sync.inum, &sync)); - - auto nextSync = syncs.begin(); - int lineNo = 0; - - for (auto &i : streamStack.back()->buffer) { - while ((nextSync != syncs.end()) && (nextSync->second->inum == i.inum)) - outX(out, *(nextSync++)->second, lineNo++); - - if (i.isLabel()) { - i.dst.label.outputText(out, PrintDetail::full, labelManager); - out << ':' << std::endl; - if (i.dst.label == _labelLocalIDsLoaded) - lineNo = 0; - } else if (i.isComment()) - out << "// " << i.comment << std::endl; - else if (i.op != Opcode::directive) - outX(out, i, lineNo++); - } -} - -template -void AsmCodeGenerator::opX(Opcode op, DataType defaultType, const InstructionModifier &mod, D dst, S0 src0, S1 src1, S2 src2, uint16_t ext) -{ - bool is2Src = !S1::emptyOp; - bool is3Src = !S2::emptyOp; - int arity = 1 + is2Src + is3Src; - - InstructionModifier emod = mod | defaultModifier; - auto esize = emod.getExecSize(); - - if (is3Src && hardware < HW::Gen10) - esize = std::min(esize, 8); // WA for IGA Align16 emulation issue - -#ifdef NGEN_SAFE - if (esize > 1 && dst.isScalar()) - throw invalid_execution_size_exception(); -#endif - - auto ewidth = getExecWidth({defaultType, dst.getType(), src0.getType(), src1.getType(), src2.getType()}); - dst.fixup(hardware, esize, ewidth, defaultType, -1, arity); - src0.fixup(hardware, esize, ewidth, defaultType, 0, arity); - src1.fixup(hardware, esize, ewidth, defaultType, 1, arity); - src2.fixup(hardware, esize, ewidth, defaultType, 2, arity); - - streamStack.back()->append(op, ext, emod, &labelManager, dst, src0, src1, src2); -} - -void AsmCodeGenerator::outX(std::ostream &out, const AsmInstruction &i, int lineNo) -{ - bool ternary = (i.src[2].type != AsmOperand::Type::none); - PrintDetail ddst = PrintDetail::hs; - PrintDetail dsrc01 = ternary ? 
PrintDetail::vs_hs : PrintDetail::full; - PrintDetail dsrc[5] = {dsrc01, dsrc01, PrintDetail::hs, PrintDetail::base, PrintDetail::base}; - - switch (i.op) { - case Opcode::send: - case Opcode::sends: - case Opcode::sendc: - case Opcode::sendsc: - ddst = dsrc[0] = dsrc[1] = PrintDetail::base; - dsrc[2] = dsrc[3] = PrintDetail::sub_no_type; - break; - case Opcode::brc: - case Opcode::brd: - case Opcode::call: - case Opcode::calla: - ddst = PrintDetail::sub; - dsrc[0] = PrintDetail::sub_no_type; - break; - case Opcode::jmpi: - case Opcode::ret: - dsrc[0] = PrintDetail::sub_no_type; - break; - case Opcode::dpas: - case Opcode::dpasw: - if (isGen12) ddst = dsrc[0] = dsrc[1] = dsrc[2] = PrintDetail::sub; - break; - case Opcode::sync: - if (isGen12) { - if (i.src[0].type == AsmOperand::Type::reg) - dsrc[0] = PrintDetail::sub; - else - dsrc[0] = PrintDetail::sub_no_type; - } - break; - default: break; - } - - outMods(out, i.mod, i.op, ModPlacementType::Pre); - - out << getMnemonic(i.op, hardware); - outExt(out, i); - out << '\t'; - - outMods(out, i.mod, i.op, ModPlacementType::Mid); - - i.dst.outputText(out, ddst, labelManager); out << '\t'; - for (int n = 0; n <= 4; n++) { - i.src[n].outputText(out, dsrc[n], labelManager); - bool showLen = false; - if (i.ext & 0x80) { - showLen |= (n == 1 && (i.op == Opcode::send || i.op == Opcode::sendc) && hardware >= HW::XeHPG); - } - - if (showLen) - out << ':' << (i.ext >> 8); - out << '\t'; - } - - outMods(out, i.mod, i.op, ModPlacementType::Post); - if (lineNumbers) - out << "\t// " << lineNo * 2; - out << std::endl; -} - -void AsmCodeGenerator::outExt(std::ostream &out, const AsmInstruction &i) -{ - switch (i.opcode()) { - case Opcode::else_: - case Opcode::goto_: - case Opcode::if_: if (i.ext) out << ".b"; break; - case Opcode::math: out << '.' << static_cast(i.ext); break; - default: break; - } - - if (isGen12) switch (i.opcode()) { - case Opcode::send: - case Opcode::sendc: - case Opcode::sends: - case Opcode::sendsc: out << '.' << getMnemonic(static_cast(i.ext & 0xF), hardware); break; - case Opcode::sync: out << '.' << static_cast(i.ext); break; - case Opcode::bfn: out << ".0x" << std::hex << i.ext << std::dec; break; - case Opcode::dpas: - case Opcode::dpasw: { - int sdepth = i.ext >> 8; - int rcount = i.ext & 0xFF; - out << '.' << sdepth << 'x' << rcount; - } - default: break; - } -} - -void AsmCodeGenerator::outMods(std::ostream &out,const InstructionModifier &mod, Opcode op, AsmCodeGenerator::ModPlacementType location) -{ - ConditionModifier cmod = mod.getCMod(); - PredCtrl ctrl = mod.getPredCtrl(); - bool wrEn = mod.isWrEn(); - bool havePred = (ctrl != PredCtrl::None) && (cmod != ConditionModifier::eo); - - switch (location) { - case ModPlacementType::Pre: - if (wrEn || havePred) { - out << '('; - if (wrEn) { - out << 'W'; - if (havePred) out << '&'; - } - if (havePred) { - if (mod.isPredInv()) out << '~'; - mod.getFlagReg().outputText(out, PrintDetail::sub_no_type, labelManager); - if (ctrl != PredCtrl::Normal) - out << '.' 
<< toText(ctrl, mod.isAlign16()); - } - out << ')'; - } - out << '\t'; - break; - case ModPlacementType::Mid: - if (mod.getExecSize() > 0) - out << '(' << mod.getExecSize() << "|M" << mod.getChannelOffset() << ')' << '\t'; - - if (cmod != ConditionModifier::none) { - out << '(' << cmod << ')'; - mod.getFlagReg().outputText(out, PrintDetail::sub_no_type, labelManager); - out << '\t'; - } - - if (mod.isSaturate()) out << "(sat)"; - break; - case ModPlacementType::Post: - { - bool havePostMod = false; - auto startPostMod = [&]() { - out << (havePostMod ? ',' : '{'); - havePostMod = true; - }; - auto printPostMod = [&](const char *name) { - startPostMod(); out << name; - }; - - SWSBInfo swsb = mod.getSWSB(); - if (swsb.hasToken()) { - startPostMod(); out << '$' << swsb.parts.token; - if (swsb.parts.src && !swsb.parts.dst) out << ".src"; - if (swsb.parts.dst && !swsb.parts.src) out << ".dst"; - } - if (swsb.hasDist()) { - startPostMod(); - if (hardware > HW::Gen12LP && (op == Opcode::send || op == Opcode::sendc) && swsb.getPipe() == Pipe::Default) - out << Pipe::A; - else if (hardware > HW::Gen12LP || !swsb.hasToken()) - out << swsb.getPipe(); - out << '@' << swsb.parts.dist; - } - - if (swsb.parts.noacc) printPostMod("NoAccSBSet"); - if (mod.isAlign16()) printPostMod("Align16"); - if (mod.isNoDDClr()) printPostMod("NoDDClr"); - if (mod.isNoDDChk()) printPostMod("NoDDChk"); - if (mod.getThreadCtrl() == ThreadCtrl::Atomic) printPostMod("Atomic"); - if (!isGen12 && mod.getThreadCtrl() == ThreadCtrl::Switch) printPostMod("Switch"); - if (!isGen12 && mod.getThreadCtrl() == ThreadCtrl::NoPreempt) printPostMod("NoPreempt"); - if (mod.isAccWrEn() && hardware < HW::XeHPC) printPostMod("AccWrEn"); - if (mod.isCompact()) printPostMod("Compact"); - if (mod.isBreakpoint()) printPostMod("Breakpoint"); - if (mod.isSerialized()) printPostMod("Serialize"); - if (mod.isEOT()) printPostMod("EOT"); - if (mod.isExBSO()) printPostMod("ExBSO"); - - if (havePostMod) out << '}'; - } - break; - } -} - -} /* namespace NGEN_NAMESPACE */ - -#endif diff --git a/src/gpu/intel/jit/ngen/ngen_core.hpp b/src/gpu/intel/jit/ngen/ngen_core.hpp index d6960f92042..abf7a356c55 100644 --- a/src/gpu/intel/jit/ngen/ngen_core.hpp +++ b/src/gpu/intel/jit/ngen/ngen_core.hpp @@ -32,10 +32,6 @@ #endif #endif -#ifdef NGEN_ASM -#include -#endif - #ifdef NGEN_SAFE #include #endif @@ -323,16 +319,6 @@ enum class DataType : uint8_t { invalid = 0x60 }; -#ifdef NGEN_ASM -static inline std::ostream &operator<<(std::ostream &str, DataType type) -{ - static const char *names[32] = {"ud", "d", "uw", "w", "ub", "b", "df", "f", "uq", "q", "hf", "bf", "bf8", "uv", "v", "vf", - "tf32", "", "", "", "", "", "", "", "", "", "", "", "u4", "s4", "u2", "s2"}; - str << names[static_cast(type) & 0x1F]; - return str; -} -#endif - static inline constexpr int getLog2Bits(DataType type) { return static_cast(type) >> 5; } static inline constexpr14 int getLog2Bytes(DataType type) { return std::max(getLog2Bits(type) - 3, 0); } static inline constexpr14 int getLog2Dwords(DataType type) { return std::max(getLog2Bits(type) - 5, 0); } @@ -409,15 +395,6 @@ static inline int mathArgCount(MathFunction func) return argCounts[static_cast(func) & 0xF]; } -#ifdef NGEN_ASM -static inline std::ostream &operator<<(std::ostream &str, MathFunction func) -{ - static const char *names[16] = {"", "inv", "log", "exp", "sqt", "rsqt", "sin", "cos", "", "fdiv", "pow", "idiv", "iqot", "irem", "invm", "rsqtm"}; - str << names[static_cast(func) & 0xF]; - return str; -} -#endif - static inline 
bool hasIEEEMacro(HW hw) { if (hw == HW::Gen11) return false; if (hw == HW::Gen12LP) return false; @@ -435,15 +412,6 @@ enum class SyncFunction : uint8_t { host = 15 }; -#ifdef NGEN_ASM -static inline std::ostream &operator<<(std::ostream &str, SyncFunction func) -{ - static const char *names[16] = {"nop", "", "allrd", "allwr", "", "", "", "", "", "", "", "", "flush", "", "bar", "host"}; - str << names[static_cast<uint8_t>(func) & 0xF]; - return str; -} -#endif - // Shared function IDs (SFIDs). enum class SharedFunction : uint8_t { null = 0x0,
static constexpr class Invalid {} invalid{}; @@ -591,16 +531,6 @@ class Label { void fixup(HW hw, int execSize, int execWidth, DataType defaultType, int srcN, int arity) {} constexpr DataType getType() const { return DataType::invalid; } constexpr bool isScalar() const { return false; } - -#ifdef NGEN_ASM - static const bool emptyOp = false; - inline void outputText(std::ostream &str, PrintDetail detail, LabelManager &man); - - friend inline bool operator==(const Label &r1, const Label &r2) { - return !std::memcmp(&r1, &r2, sizeof(Label)); - } - friend inline bool operator!=(const Label &r1, const Label &r2) { return !(r1 == r2); } -#endif }; static inline bool operator==(const RegData &r1, const RegData &r2); @@ -626,10 +556,6 @@ class RegData { : base(base_), arf(arf_), off(off_), mods(0), type(static_cast(type_)), indirect(indirect_), vs(vs_), width(width_), hs(hs_), _pad2(0), invalid(0) {} public: -#ifdef NGEN_ASM - static const bool emptyOp = false; -#endif - constexpr RegData() : base(0), arf(0), off(0), mods(0), type(0), indirect(0), vs(0), width(0), hs(0), _pad2(0), invalid(1) {} @@ -684,10 +610,6 @@ class RegData { friend inline bool operator!=(const RegData &r1, const RegData &r2); friend inline RegData abs(const RegData &r); - -#ifdef NGEN_ASM - inline void outputText(std::ostream &str, PrintDetail detail, LabelManager &man) const; -#endif }; static_assert(sizeof(RegData) == 8, "RegData structure is not laid out correctly in memory."); @@ -795,11 +717,6 @@ class Align16Operand { void fixup(HW hw, int execSize, int execWidth, DataType defaultType, int srcN, int arity) { rd.fixup(hw, execSize, execWidth, defaultType, srcN, arity); } - -#ifdef NGEN_ASM - inline void outputText(std::ostream &str, PrintDetail detail, LabelManager &man) const; - static const bool emptyOp = false; -#endif }; // Register regions. 
@@ -1149,11 +1066,6 @@ class ExtendedReg { constexpr14 RegData &getBase() { return base; } constexpr RegData getBase() const { return base; } constexpr uint8_t getMMENum() const { return mmeNum; } - -#ifdef NGEN_ASM - inline void outputText(std::ostream &str, PrintDetail detail, LabelManager &man) const; - static const bool emptyOp = false; -#endif }; static inline ExtendedReg operator|(const RegData &base, const SpecialAccumulatorRegister &acc) @@ -1390,11 +1302,6 @@ class GRFRange { void fixup(HW hw, int execSize, int execWidth, DataType defaultType, int srcN, int arity) {} constexpr DataType getType() const { return DataType::invalid; } - -#ifdef NGEN_ASM - static const bool emptyOp = false; - inline void outputText(std::ostream &str, PrintDetail detail, LabelManager &man) const; -#endif }; static inline GRFRange operator-(const GRF ®1, const GRF ®2) @@ -1439,15 +1346,6 @@ enum class ConditionModifier { eo = 0xF }; -#ifdef NGEN_ASM -static inline std::ostream &operator<<(std::ostream &str, ConditionModifier cmod) -{ - static const char *names[16] = {"", "eq", "ne", "gt", "ge", "lt", "le", "", "ov", "un", "", "", "", "", "", "eo"}; - str << names[static_cast(cmod) & 0xF]; - return str; -} -#endif - enum class ChannelMask { rgba = 0, gba = 1, @@ -1489,14 +1387,6 @@ enum class PredCtrl { w = 5, }; -#ifdef NGEN_ASM -static const char *toText(PredCtrl ctrl, bool align16) { - const char *names[2][16] = {{"", "", "anyv", "allv", "any2h", "all2h", "any4h", "all4h", "any8h", "all8h", "any16h", "all16h", "any32h", "all32h", "any", "all"}, - {"", "", "x", "y", "z", "w", "", "", "", "", "", "", "", "", "", ""}}; - return names[align16][static_cast(ctrl) & 0xF]; -} -#endif - enum class ThreadCtrl { Normal = 0, Atomic = 1, @@ -1651,42 +1541,6 @@ static inline bool isBranch(Opcode op) return (static_cast(op) >> 4) == 2; } -#ifdef NGEN_ASM -static const char *getMnemonic(Opcode op, HW hw) -{ - const char *names[0x80] = { - "illegal", "sync", "sel", "movi", "not", "and", "or", "xor", - "shr", "shl", "smov", "", "asr", "", "ror", "rol", - "cmp", "cmpn", "csel", "", "", "", "", "bfrev", - "bfe", "bfi1", "bfi2", "", "", "", "", "", - "jmpi", "brd", "if", "brc", "else", "endif", "", "while", - "break", "cont", "halt", "calla", "call", "ret", "goto", "join", - "wait", "send", "sendc", "sends", "sendsc", "", "", "", - "math", "", "", "", "", "", "", "", - "add", "mul", "avg", "frc", "rndu", "rndd", "rnde", "rndz", - "mac", "mach", "lzd", "fbh", "fbl", "cbit", "addc", "subb", - "sad2", "sada2", "add3", "macl", "srnd", "dph", "dp3", "dp2", - "dp4a", "dpas", "dpasw", "mad", "lrp", "madm", "", "", - "nop", "mov", "sel", "movi", "not", "and", "or", "xor", - "shr", "shl", "smov", "bfn", "asr", "", "ror", "rol", - "cmp", "cmpn", "csel", "", "", "", "", "bfrev", - "bfe", "bfi1", "bfi2", "", "", "", "nop", "" - }; - - const char *mnemonic = names[static_cast(op) & 0x7F]; - - if (hw < HW::Gen12LP) switch (op) { - case Opcode::mov: mnemonic = "mov"; break; - case Opcode::line: mnemonic = "line"; break; - case Opcode::pln: mnemonic = "pln"; break; - case Opcode::dp4: mnemonic = "dp4"; break; - default: break; - } - - return mnemonic; -} -#endif - class AllPipes {}; enum class Pipe : uint8_t { Default = 0, @@ -1697,15 +1551,6 @@ enum class Pipe : uint8_t { M = 5, Math = M, }; -#ifdef NGEN_ASM -static inline std::ostream &operator<<(std::ostream &str, Pipe pipe) -{ - static const char *names[8] = {"", "A", "F", "I", "L", "M", "", ""}; - str << names[static_cast(pipe) & 7]; - return str; -} -#endif - class SWSBInfo { friend 
class InstructionModifier; @@ -2014,10 +1859,6 @@ class Immediate { public: Immediate() : payload(0), type(DataType::invalid) {} -#ifdef NGEN_ASM - static const bool emptyOp = false; -#endif - constexpr14 DataType getType() const { return type; } explicit constexpr14 operator uint64_t() const { return payload; } constexpr14 int getMods() const { return 0; } @@ -2170,10 +2011,6 @@ class Immediate { result.set(int16_t(payload)); return result; } - -#ifdef NGEN_ASM - inline void outputText(std::ostream &str, PrintDetail detail, LabelManager &man) const; -#endif }; // Compute ctrl field for bfn instruction. diff --git a/src/gpu/intel/jit/ngen/ngen_interface.hpp b/src/gpu/intel/jit/ngen/ngen_interface.hpp index 89cb7b390d6..0275007993e 100644 --- a/src/gpu/intel/jit/ngen/ngen_interface.hpp +++ b/src/gpu/intel/jit/ngen/ngen_interface.hpp @@ -137,10 +137,6 @@ class InterfaceHandler inline void generateDummyCL(std::ostream &stream) const; inline std::string generateZeInfo() const; -#ifdef NGEN_ASM - inline void dumpAssignments(std::ostream &stream) const; -#endif - static constexpr int noSurface = 0x80; // Returned by getArgumentSurfaceIfExists in case of no surface assignment protected: @@ -742,25 +738,6 @@ std::string InterfaceHandler::generateZeInfo() const return md.str(); } -#ifdef NGEN_ASM -void InterfaceHandler::dumpAssignments(std::ostream &stream) const -{ - LabelManager manager; - - for (auto &assignment : assignments) { - stream << "// "; - if (assignment.reg.isValid()) - assignment.reg.outputText(stream, PrintDetail::sub, manager); - else - stream << "(none)"; - stream << '\t' << assignment.name; - if (assignment.surface != noSurface) - stream << "\t(BTI " << assignment.surface << ')'; - stream << std::endl; - } -} -#endif - } /* namespace NGEN_NAMESPACE */ #endif /* header guard */ diff --git a/src/gpu/intel/jit/ngen/ngen_register_decl.hpp b/src/gpu/intel/jit/ngen/ngen_register_decl.hpp index f56175be978..47827e56b21 100644 --- a/src/gpu/intel/jit/ngen/ngen_register_decl.hpp +++ b/src/gpu/intel/jit/ngen/ngen_register_decl.hpp @@ -504,11 +504,6 @@ NGEN_REGISTER_DECL_EXTRA4(CG,PREFIX) #include "ngen.hpp" NGEN_REGISTER_DECL(NGEN_NAMESPACE::BinaryCodeGenerator, template ) -#ifdef NGEN_ASM -#include "ngen_asm.hpp" -NGEN_REGISTER_DECL(NGEN_NAMESPACE::AsmCodeGenerator, /* nothing */) -#endif - template class NGEN_NAMESPACE::BinaryCodeGenerator; template class NGEN_NAMESPACE::BinaryCodeGenerator; template class NGEN_NAMESPACE::BinaryCodeGenerator; diff --git a/src/gpu/intel/jit/pass/dpas.cpp b/src/gpu/intel/jit/pass/dpas.cpp index 7c7a04e023f..1621e90e49a 100644 --- a/src/gpu/intel/jit/pass/dpas.cpp +++ b/src/gpu/intel/jit/pass/dpas.cpp @@ -93,8 +93,10 @@ class dpas_atomic_mutator_t : public mul_mutator_t { } }; -stmt_t inject_dpas_atomic(const stmt_t &stmt) { - return dpas_atomic_mutator_t().mutate(stmt); +stmt_t inject_dpas_atomic(const stmt_t &stmt, bool filter_by_label) { + if (filter_by_label) return dpas_atomic_mutator_t().mutate(stmt); + auto ret = dpas_atomic_mutator_t().mutate_mul(stmt); + return ret; } } // namespace jit diff --git a/src/gpu/intel/jit/pass/dpas.hpp b/src/gpu/intel/jit/pass/dpas.hpp index 449b9268792..556524defc1 100644 --- a/src/gpu/intel/jit/pass/dpas.hpp +++ b/src/gpu/intel/jit/pass/dpas.hpp @@ -26,7 +26,7 @@ namespace intel { namespace jit { // Adds {Atomic} modifier to dpas/dpasw instructions when applicable. 
-stmt_t inject_dpas_atomic(const stmt_t &stmt); +stmt_t inject_dpas_atomic(const stmt_t &stmt, bool filter_by_label = true); } // namespace jit } // namespace intel diff --git a/src/gpu/intel/jit/pass/simplify.cpp b/src/gpu/intel/jit/pass/simplify.cpp index 6c2356d41f7..d8f330c62a7 100644 --- a/src/gpu/intel/jit/pass/simplify.cpp +++ b/src/gpu/intel/jit/pass/simplify.cpp @@ -201,7 +201,7 @@ bool match_binary( match_context_t ctx_copy = ctx; if (match(ptrn_op.a, expr_op.a, ctx_copy) && match(ptrn_op.b, expr_op.b, ctx_copy)) { - ctx = ctx_copy; + ctx = std::move(ctx_copy); return true; } return false; @@ -220,7 +220,7 @@ bool match_iif(const expr_t &ptrn, const expr_t &expr, match_context_t &ctx) { if (match(ptrn_iif.cond, expr_iif.cond, ctx_copy) && match(ptrn_iif.true_expr, expr_iif.true_expr, ctx_copy) && match(ptrn_iif.false_expr, expr_iif.false_expr, ctx_copy)) { - ctx = ctx_copy; + ctx = std::move(ctx_copy); return true; } @@ -697,14 +697,14 @@ class nary_op_visitor_t : public ir_visitor_t { public: using ir_visitor_t::_visit; - virtual void _visit(const nary_op_t &obj) { visit(obj.args); } + void _visit(const nary_op_t &obj) override { visit(obj.args); } }; class nary_op_mutator_t : public ir_mutator_t { public: using ir_mutator_t::_mutate; - virtual object_t _mutate(const nary_op_t &obj) { + object_t _mutate(const nary_op_t &obj) override { auto args = mutate(obj.args); if (ir_utils::is_equal(args, obj.args)) return obj; return make_nary_op(obj.op_kind, args); @@ -730,7 +730,7 @@ class nary_op_transformer_t : public nary_op_mutator_t { b *= -1; } b = mutate(b); - return make_nary_op(nary_op_kind, {a, b}); + return make_nary_op(nary_op_kind, {std::move(a), std::move(b)}); } default: return nary_op_mutator_t::_mutate(obj); } @@ -747,7 +747,7 @@ class nary_op_flattener_t : public nary_op_mutator_t { if (nary && nary->op_kind == obj.op_kind) { args.insert(args.end(), nary->args.begin(), nary->args.end()); } else { - args.push_back(new_a); + args.emplace_back(new_a); } } return make_nary_op(obj.op_kind, args); @@ -788,7 +788,7 @@ class mul_nary_op_expander_t : public nary_op_flattener_t { for (auto &b : i_args) next_args.push_back(cvt_mul_to_nary_op(a, b)); - new_args = next_args; + new_args = std::move(next_args); } return make_nary_op(op_kind_t::_add, new_args); } @@ -868,10 +868,10 @@ bool is_nary_op_canonical(const expr_t &e) { class nary_op_back_transformer_t : public nary_op_mutator_t { public: - object_t _mutate(const nary_op_t &obj) { + object_t _mutate(const nary_op_t &obj) override { auto new_obj = nary_op_mutator_t::_mutate(obj); auto &nary = new_obj.as(); - ir_assert(nary.args.size() > 0) << new_obj; + ir_assert(!nary.args.empty()) << new_obj; if (nary.args.size() == 1) return nary.args[0]; @@ -1233,7 +1233,7 @@ class int_div_mod_expander_t : public nary_op_mutator_t { expr_t ret = reduce_v1(obj); if (!ret.is_empty()) return ret; ret = reduce_v2(obj); - return (!ret.is_empty()) ? ret : obj; + return (!ret.is_empty()) ? 
std::move(ret) : obj; } // Applies the following rules: @@ -1483,7 +1483,7 @@ class common_factor_simplifier_t : public nary_op_mutator_t { make_nary_op(op_kind_t::_add, {fi.expr(), fj.expr()})); auto &fi_add_fj = e_fi_add_fj.as(); args[i] = make_nary_op(op_kind_t::_mul, fi_add_fj.factors); - e_fi = e_fi_add_fj; + e_fi = std::move(e_fi_add_fj); args[j] = to_expr(0, args[j].type()); } } @@ -1572,7 +1572,7 @@ class stmt_simplifier_t : public ir_mutator_t { auto cset_old = cset_; cset_.add_constraint(cond); body = ir_mutator_t::mutate(body); - cset_ = cset_old; + cset_ = std::move(cset_old); } auto else_body = obj.else_body; @@ -1580,7 +1580,7 @@ class stmt_simplifier_t : public ir_mutator_t { auto cset_old = cset_; cset_.add_constraint(flip_condition(cond)); else_body = ir_mutator_t::mutate(else_body); - cset_ = cset_old; + cset_ = std::move(cset_old); } return if_t::make(cond, body, else_body); @@ -1615,7 +1615,7 @@ class stmt_simplifier_t : public ir_mutator_t { auto cset_old = cset_; cset_.add_constraint(obj.var == value); auto body = mutate(obj.body); - cset_ = cset_old; + cset_ = std::move(cset_old); return let_t::make(obj.var, value, body); } @@ -1627,13 +1627,13 @@ class stmt_simplifier_t : public ir_mutator_t { if (is_one(new_bound) && is_zero(new_init)) { auto body = substitute(obj.body, obj.var, expr_t(0)); body = mutate(body); - new_obj = body; + new_obj = std::move(body); } else { auto cset_old = cset_; cset_.add_constraint(obj.var >= obj.init); cset_.add_constraint(obj.var < obj.bound); new_obj = ir_mutator_t::_mutate(obj); - cset_ = cset_old; + cset_ = std::move(cset_old); } return new_obj; @@ -1896,6 +1896,7 @@ expr_t const_fold_unary(op_kind_t op_kind, const expr_t &a) { if (!a.type().is_scalar()) { int elems = a.type().elems(); std::vector ret; + ret.reserve(elems); for (int i = 0; i < elems; i++) { ret.push_back(const_fold_unary(op_kind, a[i])); } @@ -1924,6 +1925,7 @@ expr_t const_fold_binary(const type_t &compute_type, op_kind_t op_kind, int elems = compute_type.elems(); auto scalar_type = compute_type.scalar(); std::vector ret; + ret.reserve(elems); for (int i = 0; i < elems; i++) { ret.push_back(const_fold_binary(scalar_type, op_kind, a[i], b[i])); } diff --git a/src/gpu/intel/jit/pass/slm.cpp b/src/gpu/intel/jit/pass/slm.cpp index 28eee3aca08..3319f021f27 100644 --- a/src/gpu/intel/jit/pass/slm.cpp +++ b/src/gpu/intel/jit/pass/slm.cpp @@ -128,7 +128,9 @@ class slm_reorder_injector_t : public ir_mutator_t { auto d = dst.map(dst_it.tile()); if (s.is_dense() && d.is_dense() && src_it.outer_layout() == dst_it.outer_layout()) { - if (is_slm_reorder_ok(s, d)) { max_tile = src_tile; } + if (is_slm_reorder_ok(s, d)) { + max_tile = std::move(src_tile); + } } if (!src_it.has_next() || !dst_it.has_next()) break; ++src_it; diff --git a/src/gpu/intel/jit/pooling/ir_builder.cpp b/src/gpu/intel/jit/pooling/ir_builder.cpp index eb2aa02cfb7..6582cc033f6 100644 --- a/src/gpu/intel/jit/pooling/ir_builder.cpp +++ b/src/gpu/intel/jit/pooling/ir_builder.cpp @@ -283,13 +283,13 @@ stmt_t pooling_ir_builder_t::try_build(pooling_ir_builder_t &pb, schedule.split(s0, s0_full, s0_split, s0_ktlg, ps0 + "_split", ps0 + "_ktlg"); s1_fuse.emplace_back(s0_split); - s0 = s0_ktlg; + s0 = std::move(s0_ktlg); } else if (dims[s0_idx] <= utils::div_up(s0_full, 2)) { expr_t s1_split, s1_ktlg; // part of kg[s1] is in kg[s0] const int s1_ext = utils::div_up(s0_full, dims[s0_idx]); schedule.split(s1_fuse[0], s1_ext, s1_ktlg, s1_split, ps1 + "_ktlg", ps1 + "_split"); - s1_fuse[0] = s1_ktlg; + s1_fuse[0] = 
std::move(s1_ktlg); s0_fuse.emplace_back(s1_split); } @@ -446,7 +446,7 @@ stmt_t pooling_ir_builder_t::try_build(pooling_ir_builder_t &pb, if (is_identity) { allocs.emplace_back(read_alloc[0]); write_stmt = substitute(write_stmt, acc_buf, read_buf); - acc_buf = read_buf; + acc_buf = std::move(read_buf); acc_type = read_type; stmt = (check_idhw) ? gen_zero_out(simd, is_neg, acc_buf, dst_tile, write_layout) @@ -478,8 +478,9 @@ stmt_t pooling_ir_builder_t::try_build(pooling_ir_builder_t &pb, = compute_stmt.append(store_t::make(acc_buf, off_a, op)); }); - stmt = stmt.append(schedule.create_loop_nest( - (check_idhw) ? fill_stmt.append(compute_stmt) : compute_stmt)); + stmt = stmt.append(schedule.create_loop_nest((check_idhw) + ? fill_stmt.append(compute_stmt) + : std::move(compute_stmt))); if (!cfg.is_max()) { expr_t filter(prb.kd * prb.kh * prb.kw); @@ -552,7 +553,7 @@ stmt_t pooling_ir_builder_t::try_build(pooling_ir_builder_t &pb, ir_trace() << "Pooling kernel body:\n" << stmt << std::endl; ir_trace() << "Pooling cfg (~" << regs << " regs):\n" << cfg << std::endl; - return (regs > exec.regs()) ? stmt_t() : stmt; + return (regs > exec.regs()) ? stmt_t() : std::move(stmt); } } // namespace jit diff --git a/src/gpu/intel/jit/reorder/ir_builder.cpp b/src/gpu/intel/jit/reorder/ir_builder.cpp index 9aeb08e32ae..7b1db533193 100644 --- a/src/gpu/intel/jit/reorder/ir_builder.cpp +++ b/src/gpu/intel/jit/reorder/ir_builder.cpp @@ -473,7 +473,7 @@ struct layout_normalization_t { last = s.curr; } blocks.push_back(last); - blocks_ = blocks; + blocks_ = std::move(blocks); } void reindex(int ndims, const std::vector &map) { @@ -627,7 +627,7 @@ void reorder_ir_builder_t::build() { compute_blocks(cfg_.exec_cfg(), src_layout_, dst_layout_, new_iter_blocks, loop_blocks, tg_blocks, cur_iter_bytes); if (!ir_utils::is_equal(new_iter_blocks, iter_blocks)) { - iter_blocks = new_iter_blocks; + iter_blocks = std::move(new_iter_blocks); break; } cur_iter_bytes /= 2; diff --git a/src/gpu/intel/jit/utils/range.hpp b/src/gpu/intel/jit/utils/range.hpp index 9698df73e8c..6726854911d 100644 --- a/src/gpu/intel/jit/utils/range.hpp +++ b/src/gpu/intel/jit/utils/range.hpp @@ -40,7 +40,7 @@ struct filter_range_t { template filter_range_t filter(Fn fn) { - return {fn}; + return {std::move(fn)}; } template diff --git a/src/gpu/intel/jit/utils/utils.cpp b/src/gpu/intel/jit/utils/utils.cpp index 5a90bba772a..598df2a6ee7 100644 --- a/src/gpu/intel/jit/utils/utils.cpp +++ b/src/gpu/intel/jit/utils/utils.cpp @@ -39,10 +39,10 @@ void stringify_to_cpp_file(const std::string &file_name, for (auto &l : lines) { out << " \"" << l << "\",\n"; } - out << " nullptr,\n"; + out << " nullptr,"; out << "\n };"; out << "\n return entries;"; - out << "\n};"; + out << "\n}"; out << "\n// clang-format on\n\n"; for (auto it = namespaces.rbegin(); it != namespaces.rend(); it++) out << "} // namespace " << *it << "\n"; diff --git a/src/gpu/intel/jit/utils/utils.hpp b/src/gpu/intel/jit/utils/utils.hpp index 99ae76da276..ada05e3b638 100644 --- a/src/gpu/intel/jit/utils/utils.hpp +++ b/src/gpu/intel/jit/utils/utils.hpp @@ -681,7 +681,8 @@ ValueT get_or_default(const MapContainerT &map, const KeyT &key, struct debug_profiler_t { #ifdef DNNL_DEV_MODE - debug_profiler_t(std::string profile_name) : profile(profile_name) {}; + debug_profiler_t(const std::string &profile_name) + : profile(profile_name) {}; void start() { profile.start(); }; void stamp(const char *name) { profile.stamp(name); }; void stop(const char *name) { profile.stop(name); }; @@ -692,7 +693,7 
@@ struct debug_profiler_t { private: profiler_t profile; #else - debug_profiler_t(std::string) {}; + debug_profiler_t(const std::string &) {}; void start() {}; void stamp(const char *name) {}; void stop(const char *name) {}; @@ -836,7 +837,7 @@ bool stream_try_parse(std::istream &in, T &t) { inline void stream_match(std::istream &in, const std::string &s) { in >> std::ws; for (auto &c : s) { - char next = in.get(); + auto next = in.get(); if (next != c || in.fail()) ir_error_not_expected() << "Cannot match " << s; } @@ -1163,11 +1164,13 @@ class parse_iface_t { } void print_help() const { + std::ios_base::fmtflags f(std::cout.flags()); for (auto &e : entries_) { std::cout << " "; std::cout << std::left << std::setw(22) << e.name; std::cout << e.help << std::endl; } + std::cout.flags(f); } private: diff --git a/src/gpu/intel/jit/v2/conv/bridge.hpp b/src/gpu/intel/jit/v2/conv/bridge.hpp index 16dfff21dbf..d7169468f06 100644 --- a/src/gpu/intel/jit/v2/conv/bridge.hpp +++ b/src/gpu/intel/jit/v2/conv/bridge.hpp @@ -129,6 +129,7 @@ inline problem_t to_problem( prb.set_wei_tag(wei); prb.set_dst_tag(dst); prb.set_shape(shape); + prb.normalize(); return prb; } diff --git a/src/gpu/intel/jit/v2/conv/gen_convolution.cpp b/src/gpu/intel/jit/v2/conv/gen_convolution.cpp index d407776e651..3e79884d9b0 100644 --- a/src/gpu/intel/jit/v2/conv/gen_convolution.cpp +++ b/src/gpu/intel/jit/v2/conv/gen_convolution.cpp @@ -136,21 +136,15 @@ class gen_convolution_t { kernel_params_t _params; if (plan_preset_t::instance().is_set()) { _desc = plan_preset_t::instance().get(); - _desc.hw = hw_t(engine); - _desc.specialize(prb); - { - ir_utils::ir_check_log_level_t check_level(ir_utils::LOG_FATAL); - auto plan = create_conv_plan_and_finalize_desc(_desc); - } } else { auto ®istry = const_plan_registry(); _desc = registry.find_best(prb); - _desc.specialize(prb); } if (_desc.is_empty()) return status::unimplemented; + ir_assert(ir_check_fatal(finalize_conv_desc(_desc, prb))); ir_assert(ir_check_fatal(_desc.fits(prb))); CHECK(init_layouts(_desc, pd)); - _params.prb = prb; + _params.prb = std::move(prb); desc = std::make_shared(_desc); params = std::make_shared(_params); return status::success; diff --git a/src/gpu/intel/jit/v2/conv/ir_builder.cpp b/src/gpu/intel/jit/v2/conv/ir_builder.cpp index 6bb965eb999..0d47756fe50 100644 --- a/src/gpu/intel/jit/v2/conv/ir_builder.cpp +++ b/src/gpu/intel/jit/v2/conv/ir_builder.cpp @@ -49,6 +49,7 @@ class loop_nest_t { const expr_t &size(int level) const { return loops_[level].size; } std::vector indices() const { std::vector ret; + ret.reserve(nloops()); for (int i = 0; i < nloops(); i++) { ret.push_back(index(i)); } @@ -251,7 +252,7 @@ class send_mask_t { expr_t ret; for (auto &e : entries_) { auto cmp = (e.off.load() < e.off.make_broadcast(e.bound)); - ret = (ret.is_empty() ? cmp : (ret & cmp)); + ret = (ret.is_empty() ? 
std::move(cmp) : (ret & cmp)); if (e.has_underflow) ret &= (e.off.load() >= e.off.make_broadcast(0)); } @@ -390,7 +391,7 @@ class offset_scope_t { ret.type = type; ret.base = base0 + _base_init; ret.shift = _shift; - ret.shift_vec = shift_vec; + ret.shift_vec = std::move(shift_vec); ret.esize = params.esize; expr_t comp_value = 0; @@ -559,7 +560,7 @@ class iterator_t { if (i + 1 < nloops()) stmt = stmt.append(if_t::make( loop_idxs_[i].var() >= loop_nest_.size(i), body)); - body = stmt; + body = std::move(stmt); } body = linear_idx_.inc_stmt(-1).append(body); return body; @@ -761,8 +762,8 @@ loop_nest_t make_loop_nest( const loop_desc_t &loop_desc, const coord_info_t &coord_info) { loop_nest_t ret; for (auto &e : loop_desc) { - auto index = coord_info.loop_index(e.dim); - auto size = coord_info.loop_size(e.dim); + const auto &index = coord_info.loop_index(e.dim); + const auto &size = coord_info.loop_size(e.dim); ret.add_loop(index, size); } return ret; @@ -847,8 +848,8 @@ class ir_builder_t { ret = ret.append(prefetch_it.inc_stmt(prefetch_off_ctx_)); } for (auto &e : loop_desc) { - auto var = coord_info.loop_index(e.dim); - auto bound = coord_info.loop_size(e.dim); + const auto &var = coord_info.loop_index(e.dim); + const auto &bound = coord_info.loop_size(e.dim); ret = ret.append(off_ctx_.inc_loop_stmt(e.idx)); ret = for_t::make(var, 0, bound, ret); } @@ -882,7 +883,7 @@ class ir_builder_t { auto &coord_info = plan_.coord_info; stmt_t ret = stmt; for (auto &d : conv_index_dims(plan_.desc.prop)) { - auto tg_idx = coord_info.tg_index(d); + const auto &tg_idx = coord_info.tg_index(d); if (is_const(tg_idx)) continue; auto base_tg_idx = tg_grid.index_var(d); if (base_tg_idx.is_empty()) continue; @@ -906,7 +907,7 @@ class ir_builder_t { auto base_idx = tg_grid.index_var(dim); if (base_idx.is_empty()) return expr_t(); - expr_t value = base_idx; + expr_t value = std::move(base_idx); auto &dims = tg_grid.dims(tg_grid.index(dim)); int ndims = (int)dims.size(); for (int i = 0; i < ndims; i++) { @@ -1024,13 +1025,13 @@ class ir_builder_t { auto src1 = a_buf[a_off]; auto src2 = b_buf[b_off]; if (fma.fma == fma_kind_t::dpas) std::swap(src1, src2); - stmt = stmt.append( - fma_func.call({dst, dst, src1, src2})); + stmt = stmt.append(fma_func.call( + {dst, dst, std::move(src1), std::move(src2)})); } } } } - stmt = inject_dpas_atomic(stmt); + stmt = inject_dpas_atomic(stmt, /*filter_by_label=*/false); x2r_mul_stmt_ = x2r_mul_stmt_.append(stmt); } @@ -1067,7 +1068,7 @@ class ir_builder_t { auto &bia_buf = buf_mgr_.find_buf("bia_buf"); auto bia_tile = epilogue.bia_store.reg_layout().int_dim_sizes(); auto epilogue_tile = bia_tile; - for (auto d : bia_tile) + for (auto &d : bia_tile) epilogue_tile[d] = epilogue.tile[d]; for_each(bia_tile, epilogue_tile, [&](const prb_coord_t &coord) { auto bia_payload_buf = bia_buf; @@ -1080,7 +1081,7 @@ class ir_builder_t { auto stmt = create_stmt( epilogue.bia_reorder, bia_buf + src_off, bia_tmp_buf); epilogue_stmt_ = epilogue_stmt_.append(stmt); - bia_payload_buf = bia_tmp_buf; + bia_payload_buf = std::move(bia_tmp_buf); bia_payload_layout = epilogue.bia_reorder.dst; payload_coord = prb_coord_t(); } @@ -1099,7 +1100,7 @@ class ir_builder_t { auto &c_buf = buf_mgr_.find_buf("c"); auto c_tmp_buf = buf_mgr_.get("c_reduce", slm_reduce.load.reg_layout().size()); - auto c_slm_buf = buf_mgr_.get("slm", slm_reduce.slm_size()); + auto c_slm_buf = buf_mgr_.get("slm", slm_reduce.slm_usage_bytes()); auto store_stmt = create_stmt( slm_reduce.store, c_slm_buf, c_buf, epilogue_off_ctx_); 
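
For reference, the loop emission in `make_loop_nest()` and `ir_builder_t` above follows the `loop_desc` convention of listing variables innermost first: each `for_t::make(var, 0, bound, ret)` call wraps the statement built so far, so the last entry ends up as the outermost loop. A minimal standalone sketch of that wrapping order, with hypothetical stand-in types (`Stmt`, `make_for`) in place of the real IR classes:

```cpp
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the IR statement type; the real code builds
// stmt_t objects via for_t::make(var, 0, bound, body).
struct Stmt {
    std::string text;
};

Stmt make_for(const std::string &var, int bound, const Stmt &body) {
    return {"for (" + var + " : 0.." + std::to_string(bound) + ") { "
            + body.text + " }"};
}

int main() {
    // loop_desc convention: variables ordered innermost to outermost,
    // e.g. kw,kh,kd,ic.
    std::vector<std::pair<std::string, int>> loop_desc
            = {{"kw", 3}, {"kh", 3}, {"kd", 1}, {"ic", 16}};
    Stmt ret {"c += a * b;"};
    // Each iteration wraps the statement built so far, so the last entry
    // of loop_desc becomes the outermost loop.
    for (auto &e : loop_desc)
        ret = make_for(e.first, e.second, ret);
    std::cout << ret.text << "\n"; // ic is outermost, kw innermost
    return 0;
}
```
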
auto load_stmt = create_stmt( @@ -1130,7 +1131,7 @@ class ir_builder_t { auto stmt = create_stmt( epilogue.reorder, c_buf + src_off, c_tmp_buf); epilogue_stmt_ = epilogue_stmt_.append(stmt); - payload_buf = c_tmp_buf; + payload_buf = std::move(c_tmp_buf); payload_layout = epilogue.reorder.dst; payload_coord = prb_coord_t(); } @@ -1164,7 +1165,7 @@ class ir_builder_t { break; default: ir_error_not_expected(); } - return kernel_info_.find_arg(name.c_str()); + return kernel_info_.find_arg(name); } expr_t a_mem_buf() const { return mem_buf(tensor_kind_t::a); } diff --git a/src/gpu/intel/jit/v2/conv/kernel_desc.cpp b/src/gpu/intel/jit/v2/conv/kernel_desc.cpp index cbe54295410..a781c317e99 100644 --- a/src/gpu/intel/jit/v2/conv/kernel_desc.cpp +++ b/src/gpu/intel/jit/v2/conv/kernel_desc.cpp @@ -76,6 +76,43 @@ void store_desc_t::parse(std::istream &in) { } } +std::string align_desc_t::align_t::str() const { + std::string s = (value == 0 ? "*" : std::to_string(value)); + if (in_bytes) s += "b"; + return s; +} + +void align_desc_t::align_t::parse(const std::string &_s) { + auto s = _s; + in_bytes = (!s.empty() && s.back() == 'b'); + if (in_bytes) s = s.substr(0, s.length() - 1); + value = (s == "*") ? 0 : std::stoi(s); +} + +std::string align_desc_t::str() const { + if (is_default()) return "x"; + std::vector parts; + parts.emplace_back(src.str()); + parts.emplace_back(wei.str()); + parts.emplace_back(dst.str()); + return gpu_utils::join(":", parts); +} + +void align_desc_t::parse(std::istream &in) { + operator=(align_desc_t()); + auto s = jit::parse(in); + if (s == "x") return; + auto parts = gpu_utils::split(s, ":"); + if (parts.size() == 1) { + parts.push_back(parts[0]); + parts.push_back(parts[0]); + } + ir_assert(parts.size() == 3); + src.parse(parts[0]); + wei.parse(parts[1]); + dst.parse(parts[2]); +} + void prefetch_desc_t::parse(std::istream &in) { operator=(prefetch_desc_t()); std::string s; @@ -378,9 +415,9 @@ void kernel_desc_t::set_defaults() { } } -void kernel_desc_t::finalize(const plan_t &plan) { +void kernel_desc_t::finalize(const prb_reqs_t &final_reqs) { is_finalized = true; - reqs.add(plan.reqs()); + reqs.add(final_reqs); } std::string kernel_desc_t::cmd_str() const { @@ -407,8 +444,11 @@ std::string kernel_desc_t::str() const { oss << "Thread group tile: " << thread_group_tile << std::endl; oss << "Loop desc: " << loop_desc << std::endl; oss << "Load: " << load.str() << std::endl; - oss << "Prefetch: " << prefetch.str() << std::endl; oss << "Store: " << store.str() << std::endl; + oss << "Use block 2D access: " << ir_utils::to_string(use_2d_access) + << std::endl; + oss << "Align: " << align.str() << std::endl; + oss << "Prefetch: " << prefetch.str() << std::endl; if (reqs) oss << ir_utils::add_tag("Reqs", reqs.str()) << std::endl; oss << "Command: " << cmd_str(); return ir_utils::add_tag("Desc", oss.str()); @@ -444,21 +484,27 @@ void kernel_desc_t::init_parse_iface(parse_iface_t *iface) { "tg", "Threadgroup tile (e.g. ow4oc4).", /*required=*/true); iface->add("loop_desc", "Loop description, variables ordered from innermost to outermost " - "(e.g. kw,kh,kd,ic).", - /*required=*/true); + "(e.g. kw,kh,kd,ic)."); iface->add("load", "Load type (block, scattered [default], 2d) for A and B, e.g. " "a:2d,b:block."); iface->add("store", "Store type (block, scattered [default], 2d) for C, e.g. 
c:2d."); + iface->add( + "2d", "Whether to use block 2D messages for access."); + iface->add("align", + "Alignments in bytes/elements for the innermost dimension in " + "source, weights and destination. Examples: 8b:8b:8b (in bytes), " + "2:2:2 (in elements), *:*:* (for optimal values determined during " + "kernel plan generation)."); iface->add("prefetch", "Prefetch description specifying distance and whether A/B are " "prefetched. Examples: x3 (distance is 3, both A/B are " "prefetched), x2.a (distance is 2, only A is prefetched), x0 (no " "prefetch, default)."); iface->add("spec_strategy", - "Specialization strategy for problem dimensions (e.g. 1d for 1D " - "convolution)."); + "Specialization strategy for problem dimensions (e.g. min_dims to " + "eliminate unused spatial dimensions)."); iface->add("reqs", "Dimension requirements, colon-separated (e.g. kd=1:mb>=16)."); #undef PACK diff --git a/src/gpu/intel/jit/v2/conv/kernel_desc.hpp b/src/gpu/intel/jit/v2/conv/kernel_desc.hpp index 4476f9a225c..17560fd56cb 100644 --- a/src/gpu/intel/jit/v2/conv/kernel_desc.hpp +++ b/src/gpu/intel/jit/v2/conv/kernel_desc.hpp @@ -45,7 +45,7 @@ namespace v2 { namespace conv { struct hw_desc_t { - ngen::HW hw; + ngen::HW hw = ngen::HW::Unknown; void stringify(std::ostream &out) const { jit::stringify(out, hw); } void parse(std::istream &in) { jit::parse(in, hw); } @@ -162,6 +162,43 @@ class loop_desc_t { std::vector entries_; }; +enum class access_mode_t { + // Rely on explicit load/store settings. + direct, + // Rely on alignment/2D settings + alignment, +}; + +struct align_desc_t { + struct align_t { + int value = 0; + // If true, then value in bytes, otherwise in elements. + bool in_bytes = false; + + bool is_default() const { return value == 0; } + std::string str() const; + void parse(const std::string &s); + }; + align_t src; + align_t wei; + align_t dst; + + bool is_default() const { + return src.is_default() && wei.is_default() && dst.is_default(); + } + + std::string str() const; + + IR_DEFINE_DUMP() + +#if __cplusplus >= 202002L + bool operator==(const align_desc_t &other) const = default; +#endif + + void stringify(std::ostream &out) const { out << str(); } + void parse(std::istream &in); +}; + struct load_desc_t { send_kind_t a = send_kind_t::undef; send_kind_t b = send_kind_t::undef; @@ -256,8 +293,16 @@ class kernel_desc_t : public kernel_desc_base_t { prb_tile_t iter_outer_tile; prb_tile_t thread_group_tile; loop_desc_t loop_desc; + + access_mode_t access_mode = access_mode_t::direct; + // For direct mode. load_desc_t load; store_desc_t store; + + // For alignment-based/2D mode. 
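
To make the new `align` option concrete: per the `str()`/`parse()` pair above, each component is an element count, a byte count with a trailing `b`, or `*` (stored as value 0) to let plan generation pick an optimal alignment, and a single component such as `8b` is broadcast to all three tensors. A simplified, self-contained sketch of the per-component parsing (the struct here is a stand-in, not the real `align_desc_t::align_t`):

```cpp
#include <cassert>
#include <string>

// Simplified stand-in for align_desc_t::align_t parsing: a trailing 'b'
// marks a byte (rather than element) alignment, and "*" (stored as 0)
// requests an optimal value chosen during kernel plan generation.
struct align_t {
    int value = 0; // 0 means "*": pick optimal
    bool in_bytes = false;

    void parse(std::string s) {
        in_bytes = (!s.empty() && s.back() == 'b');
        if (in_bytes) s.pop_back();
        value = (s == "*") ? 0 : std::stoi(s);
    }
};

int main() {
    align_t a;
    a.parse("8b"); // 8-byte alignment
    assert(a.value == 8 && a.in_bytes);
    a.parse("2"); // 2-element alignment
    assert(a.value == 2 && !a.in_bytes);
    a.parse("*"); // optimal, determined during plan generation
    assert(a.value == 0);
    return 0;
}
```
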
+ bool use_2d_access = false; + align_desc_t align; + prefetch_desc_t prefetch; prb_reqs_t reqs; @@ -268,7 +313,7 @@ class kernel_desc_t : public kernel_desc_base_t { bool is_supported() const; void set(const std::string &s); void set_defaults(); - void finalize(const plan_t &plan); + void finalize(const prb_reqs_t &final_reqs); bool fits(const problem_t &prb, bool check_tags = true) const { ir_check(prb.prop() == prop) << "Propagation kind does not match"; @@ -311,24 +356,29 @@ class kernel_desc_t : public kernel_desc_base_t { } send_kind_t access_kind(send_op_t op, tensor_kind_t tensor) const { - switch (op) { - case send_op_t::load: - switch (tensor) { - case tensor_kind_t::a: return load.a; - case tensor_kind_t::b: return load.b; - default: ir_error_not_expected(); - } - break; - case send_op_t::store: - switch (tensor) { - case tensor_kind_t::c: return store.c; - case tensor_kind_t::bia: return send_kind_t::undef; - default: ir_error_not_expected(); - } - break; - default: ir_error_not_expected(); + if (access_mode == access_mode_t::direct) { + switch (op) { + case send_op_t::load: + switch (tensor) { + case tensor_kind_t::a: return load.a; + case tensor_kind_t::b: return load.b; + default: ir_error_not_expected(); + } + break; + case send_op_t::store: + switch (tensor) { + case tensor_kind_t::c: return store.c; + case tensor_kind_t::bia: return send_kind_t::undef; + default: ir_error_not_expected(); + } + break; + default: ir_error_not_expected(); + } + return send_kind_t::undef; + } else { + if (use_2d_access) return send_kind_t::_2d; + return send_kind_t::undef; } - return send_kind_t::undef; } std::string kernel_name() const override { return "gen_conv_v2"; } @@ -354,9 +404,16 @@ class kernel_desc_t : public kernel_desc_base_t { void specialize(const problem_t &prb) { if (!has_spec_strategy()) return; switch (spec_strategy) { - case spec_strategy_t::max: reqs.add(prb.shape()); break; - case spec_strategy_t::min_dims: reqs.add(min_dims_tile(prb)); break; - default: break; + case spec_strategy_t::max: + reqs.add(prb.shape()); + reqs.simplify(); + break; + case spec_strategy_t::min_dims: + reqs.add(min_dims_tile(prb)); + reqs.simplify(); + break; + case spec_strategy_t::none: break; + default: ir_error_not_expected(); } spec_strategy = spec_strategy_t::none; } @@ -480,6 +537,8 @@ struct trivial_key_validator_t { && (t.thread_group_tile == tmp.thread_group_tile) && (t.loop_desc == tmp.loop_desc) && (t.load == tmp.load) && (t.prefetch == tmp.prefetch) && (t.store == tmp.store) + && (t.align == tmp.align) + && (t.use_2d_access == tmp.use_2d_access) && (t.is_finalized == tmp.is_finalized); } }; diff --git a/src/gpu/intel/jit/v2/conv/model.cpp b/src/gpu/intel/jit/v2/conv/model.cpp index daff3efefa4..c5f4dd6f5fd 100644 --- a/src/gpu/intel/jit/v2/conv/model.cpp +++ b/src/gpu/intel/jit/v2/conv/model.cpp @@ -118,7 +118,7 @@ struct sample_t { uint64_t time_ns = 0) : prb(prb), kernel_desc(kernel_desc), time_ns(time_ns) { hw_cfg = hw_config_t( - kernel_desc.hw, kernel_desc.fma, kernel_desc.src_tag.type()); + prb.hw(), kernel_desc.fma, kernel_desc.src_tag.type()); prb_tile_t padded_shape = prb.shape(); pad_eff = 1; for (auto &d : padded_shape) { @@ -229,9 +229,9 @@ void model_t::parse(std::istream &in) { std::vector data; auto s_data = stream_parse(in); for (size_t i = 0; i < s_data.size(); i += 2) { - data.push_back(std::stoi(s_data.substr(i, 2), 0, 16)); + data.push_back(std::stoi(s_data.substr(i, 2), nullptr, 16)); } - auto s = serialized_t::from_data(data); + auto s = 
serialized_t::from_data(std::move(data)); deserializer_t d(s); ml_model_ = ml_model_t::deserialize(d); } diff --git a/src/gpu/intel/jit/v2/conv/plan.cpp b/src/gpu/intel/jit/v2/conv/plan.cpp index 171606bb2a4..e7ef3540a6a 100644 --- a/src/gpu/intel/jit/v2/conv/plan.cpp +++ b/src/gpu/intel/jit/v2/conv/plan.cpp @@ -347,7 +347,7 @@ class multiply_info_t { case fma_kind_t::dpas: { auto a_tile = a_inner_.int_dim_sizes(); auto b_tile = b_inner_.int_dim_sizes(); - ret = a_tile; + ret = std::move(a_tile); for (auto &d : b_tile) { if (ret.has(d)) ir_assert(ret[d] == b_tile[d]); ret[d] = b_tile[d]; @@ -481,7 +481,7 @@ class multiply_info_t { block.dim = d; block.size = simd_; block.stride = expr_t(1); - b_inner_ = layout_t(b_desc, b_type_, 0, {block}); + b_inner_ = layout_t(b_desc, b_type_, 0, {std::move(block)}); break; } } @@ -566,7 +566,12 @@ class multiply_info_t { class plan_builder_t { public: plan_builder_t() = default; - plan_builder_t(const kernel_desc_t &desc) : desc_(desc) {} + plan_builder_t(const kernel_desc_t &desc) : desc_(desc) { + reqs_ = desc_.reqs; + desc_.reqs = prb_reqs_t(); + } + + const prb_reqs_t &reqs() const { return reqs_; } plan_t build() { init_dim_mapper_manager(); @@ -590,8 +595,21 @@ class plan_builder_t { return plan; } + void add_align_req(const prb_dim_t &dim, const type_t &type, + const align_desc_t::align_t &align) { + if (align.value == 0) { + reqs_.set_any_mod(dim); + } else { + int align_bytes = (align.in_bytes ? align.value + : align.value * type.size()); + reqs_.add( + size_var(dim) % ir_utils::safe_div(align_bytes, type.size()) + == 0); + } + } + void init_dim_mapper_manager() { - dim_mapper_manager_ = dim_mapper_manager_t(desc_.prop, desc_.reqs); + dim_mapper_manager_ = dim_mapper_manager_t(desc_.prop, reqs_); } void init_tiles() { @@ -604,23 +622,32 @@ class plan_builder_t { int iter_tile = desc_.iter_tile.get(d, 1); auto thr_idx = thr_grid_.index_var(d); coord_info_.add_dim(d, is_loop, is_global_loop, tg_tile, thr_idx, - iter_tile, desc_.reqs); + iter_tile, reqs_); } } void init_layouts() { auto src_layout = make_conv_layout( - tensor_kind_t::src, desc_.src_tag, desc_.is_dw, desc_.reqs); + tensor_kind_t::src, desc_.src_tag, desc_.is_dw, reqs_); auto wei_layout = make_conv_layout( - tensor_kind_t::wei, desc_.wei_tag, desc_.is_dw, desc_.reqs); + tensor_kind_t::wei, desc_.wei_tag, desc_.is_dw, reqs_); auto dst_layout = make_conv_layout( - tensor_kind_t::dst, desc_.dst_tag, desc_.is_dw, desc_.reqs); + tensor_kind_t::dst, desc_.dst_tag, desc_.is_dw, reqs_); a_layout_ = pick_a(desc_.prop, src_layout, wei_layout, dst_layout); b_layout_ = pick_b(desc_.prop, src_layout, wei_layout, dst_layout); c_layout_ = pick_c(desc_.prop, src_layout, wei_layout, dst_layout); if (desc_.prop == prop_kind::backward_weights && desc_.with_bias) bia_layout_ = make_conv_layout( - tensor_kind_t::bia, desc_.bia_tag, desc_.is_dw, desc_.reqs); + tensor_kind_t::bia, desc_.bia_tag, desc_.is_dw, reqs_); + if (desc_.access_mode == access_mode_t::alignment) { + auto &align = desc_.align; + add_align_req( + src_layout.blocks()[0].dim, src_layout.type(), align.src); + add_align_req( + wei_layout.blocks()[0].dim, wei_layout.type(), align.wei); + add_align_req( + dst_layout.blocks()[0].dim, dst_layout.type(), align.dst); + } } dim_map_t to_bmnk_map() const { @@ -650,7 +677,7 @@ class plan_builder_t { if (!try_init_plan(plan)) return plan_t(); if (!check_plan(plan)) return plan_t(); - reqs_ = plan.reqs(); + reqs_.add(plan.reqs()); plan = plan_t(desc_.hw); if (!try_init_plan(plan) || 
!check_plan(plan)) { ir_error_not_expected(); @@ -668,7 +695,8 @@ class plan_builder_t { ir_check(init_x2r_fma_plan(plan.x2r_fma)); ir_check(init_prefetch_plan( plan.x2r_fma, plan.virt_grid, plan.prefetch)); - ir_check(init_epilogue_plan(plan.x2r_fma.c_layout, plan.epilogue)); + ir_check(init_epilogue_plan( + plan.x2r_fma.c_layout, plan.virt_grid, plan.epilogue)); if (desc_.prop == prop_kind::backward_weights && desc_.with_bias) ir_check(init_epilogue_bia(plan.x2r_fma.bia_layout, plan.epilogue)); return true; @@ -728,19 +756,19 @@ class plan_builder_t { if (mul_info_.is_compatible(abc, load.reg_layout())) { reg_layout = load.reg_layout(); } else { - auto src = load.reg_layout(); + auto &src = load.reg_layout(); auto dst = mul_info_.to_compatible_layout(abc, load.reg_layout()); reorder = reorder_plan_t(desc_.hw, src, dst); reg_layout = reorder.dst; } plan = x2r_plan_t(desc_.hw); plan.tensor_kind = abc; - plan.load = load; - plan.reorder = reorder; - plan.layout = reg_layout; + plan.load = std::move(load); + plan.reorder = std::move(reorder); + plan.layout = std::move(reg_layout); if (abc == tensor_kind_t::b) { auto bia_layout = mul_info_.bia_layout(plan.layout, bia_layout_); - plan.bia_layout = bia_layout; + plan.bia_layout = std::move(bia_layout); } return true; } @@ -755,8 +783,8 @@ class plan_builder_t { plan.fma = desc_.fma; plan.a_layout = a; plan.b_layout = b; - plan.c_layout = acc_layout; - plan.inst_tile = inst_tile; + plan.c_layout = std::move(acc_layout); + plan.inst_tile = std::move(inst_tile); return true; } @@ -843,7 +871,7 @@ class plan_builder_t { reduce_cond = reduce_cond & (coord_info_.iter_coord()[dim] == 0); } - plan.reduce_cond = reduce_cond; + plan.reduce_cond = std::move(reduce_cond); auto bia_params = get_send_params( tensor_kind_t::bia, send_op_t::store, bia_iter_view); auto bia_store = create_send_plan(bia_params, bia_iter_view); @@ -854,15 +882,15 @@ class plan_builder_t { auto store_layout = bia_store.reg_layout().map(tile); if (fma_layout != store_layout) { plan.bia_reorder = reorder_plan_t(desc_.hw); - plan.bia_reorder.src = fma_layout; - plan.bia_reorder.dst = store_layout; + plan.bia_reorder.src = std::move(fma_layout); + plan.bia_reorder.dst = std::move(store_layout); } } return true; } - bool init_slm_reduce_plan( - const layout_t &c_layout, slm_reduce_plan_t &plan) const { + bool init_slm_reduce_plan(const layout_t &c_layout, virt_grid_t &virt_grid, + slm_reduce_plan_t &plan) const { prb_dim_t k_dim; for (auto &d : desc_.thread_group_tile) { if (to_gemm(d, desc_.prop) == prb_dim_kind_t::k) { @@ -906,39 +934,43 @@ class plan_builder_t { grid_splitter.add(thr_grid_.index_var(k_dim), k_tg); auto split_view = view_t::split( mapper, c_layout, prb_coord_t(), c_tile, grid_splitter); - ir_assert(grid_splitter.virt_grid_idxs().empty()); + for (auto &kv : grid_splitter.virt_grid_idxs()) { + virt_grid.add(kv.first, kv.second); + } - auto load_coord = split_view.coord(); + auto &load_coord = split_view.coord(); auto tile_with_k = split_view.tile(); tile_with_k[k_dim] = k_tg; // Load partial sums and do the final reduction. 
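
Before the load/reduce code below, a schematic host-side model of what the SLM k-slice reduction computes: each of the `k_tg` thread-group slices stores its partial C tile to shared local memory, and after a barrier the slices are reloaded and summed. Flat arrays stand in for the real tile layouts and send plans; this is an illustration of the data flow, not the device code:

```cpp
#include <cassert>
#include <vector>

// Each k-slice writes its partial C tile to SLM at offset k * tile_elems;
// after a (device-side) barrier, all slices are reloaded and reduced.
std::vector<float> slm_reduce(
        const std::vector<std::vector<float>> &partial_c, size_t tile_elems) {
    std::vector<float> slm(partial_c.size() * tile_elems, 0.f);
    // Store phase: slice k writes its tile.
    for (size_t k = 0; k < partial_c.size(); k++)
        for (size_t i = 0; i < tile_elems; i++)
            slm[k * tile_elems + i] = partial_c[k][i];
    // (A barrier would go here on the device.)
    // Load + reduce phase: sum across the k dimension.
    std::vector<float> c(tile_elems, 0.f);
    for (size_t k = 0; k < partial_c.size(); k++)
        for (size_t i = 0; i < tile_elems; i++)
            c[i] += slm[k * tile_elems + i];
    return c;
}

int main() {
    auto c = slm_reduce({{1, 2}, {3, 4}, {5, 6}}, 2);
    assert(c[0] == 9.f && c[1] == 12.f);
    return 0;
}
```
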
- auto load_view = view_t(mapper, slm_layout, load_coord, tile_with_k); + auto load_view = view_t(mapper, slm_layout, load_coord, tile_with_k, + grid_splitter.var_range_info()); auto load_params = get_send_params(tensor_kind_t::c, send_op_t::load, load_view, send_kind_t::block, send_address_t::slm); load_params.skip_mask.push_back(k_dim); auto load = try_create_send_plan(__func__, load_params, load_view); if (!load) return false; - auto load_layout = load.reg_layout(); + auto &load_layout = load.reg_layout(); auto reduced_layout = load_layout.map(split_view.tile()); auto reduce = reduce_plan_t(desc_.hw, load_layout, reduced_layout); - auto c_post_layout = reduced_layout; + auto c_post_layout = std::move(reduced_layout); c_post_layout.remove(k_dim); plan = slm_reduce_plan_t(desc_.hw); - plan.store = store; - plan.load = load; - plan.reduce = reduce; - plan.c_layout = c_post_layout; + plan.store = std::move(store); + plan.load = std::move(load); + plan.reduce = std::move(reduce); + plan.c_layout = std::move(c_post_layout); plan.c_coord = coord_info_.iter_coord() + load_coord; return true; } - bool init_epilogue_plan( - const layout_t &c_fma_layout, epilogue_plan_t &plan) const { - ir_check(init_slm_reduce_plan(c_fma_layout, plan.slm_reduce)); + bool init_epilogue_plan(const layout_t &c_fma_layout, + virt_grid_t &virt_grid, epilogue_plan_t &plan) const { + ir_check( + init_slm_reduce_plan(c_fma_layout, virt_grid, plan.slm_reduce)); auto &c_mapper = dim_mapper_manager_.mapper(tensor_kind_t::c); auto c_reg_layout = (plan.slm_reduce ? plan.slm_reduce.c_layout : c_fma_layout); @@ -960,24 +992,32 @@ class plan_builder_t { } auto params = get_send_params( tensor_kind_t::c, send_op_t::store, c_mem_view); - auto c_store = create_send_plan(params, c_mem_view); - auto tile = c_store.entry_tile(); + // TODO: Implement fallback from 2D to block/scattered messages to + // allow partial use of 2D messages when possible. 
+ auto c_store = try_create_send_plan(__func__, params, c_mem_view); + if (!c_store) return false; + auto &tile = c_store.entry_tile(); plan.tile = tile; plan.c_store = c_store; auto c_reg_tile_layout = c_reg_layout.map(tile); auto store_layout = c_store.reg_layout().map(tile); if (c_reg_tile_layout != store_layout) { plan.reorder = reorder_plan_t(desc_.hw); - plan.reorder.src = c_reg_tile_layout; - plan.reorder.dst = store_layout; + plan.reorder.src = std::move(c_reg_tile_layout); + plan.reorder.dst = std::move(store_layout); } return true; } bool check_plan(const plan_t &plan) const { - int bound = desc_.hw.grf_size() * desc_.regs; - int usage_bytes = plan.grf_usage_bytes(); - ir_check(usage_bytes <= bound) << "check_plan: out of registers"; + int grf_bound = desc_.hw.grf_size() * desc_.regs; + int grf_bytes = plan.grf_usage_bytes(); + ir_check(grf_bytes <= grf_bound) << "check_plan: out of registers"; + int slm_bound = compute::device_info_t::max_slm_size_per_tg( + convert_ngen_arch_to_dnnl(desc_.hw.to_ngen()), + desc_.thread_group_tile.elems(), desc_.regs > 128); + int slm_bytes = plan.slm_usage_bytes(); + ir_check(slm_bytes <= slm_bound) << "check_plan: out of SLM"; return true; } @@ -995,6 +1035,7 @@ class plan_builder_t { params.hint_2d = send_2d_hint_t(view, op, mul_info_.hint(abc)); params.skip_mask = skip_mask(view); params.init_max_entry_reg_size(); + params.external_reqs = &reqs_; return params; } @@ -1049,7 +1090,6 @@ class plan_builder_t { prb_reqs_t plan_t::reqs() const { prb_reqs_t ret; - ret.add(desc.reqs); ret.add(prefetch.reqs()); ret.add(x2r_fma.reqs()); ret.add(epilogue.c_store.reqs()); @@ -1057,20 +1097,35 @@ prb_reqs_t plan_t::reqs() const { return ret; } -plan_t create_conv_plan(const kernel_desc_t &desc) { +template +plan_t create_conv_plan_impl(KernelDescT &desc, bool finalize) { if (!desc.is_supported()) return plan_t(); ir_assert(!desc.has_spec_strategy()) << "Kernel descriptor strategies are required to be specialized " "before plan creation"; plan_builder_t builder(desc); auto plan = builder.build(); + if (plan) { + if (finalize) { + const_cast(desc).finalize(builder.reqs()); + } else { + ir_assert(desc.reqs.implies(builder.reqs())); + } + } return plan; } -plan_t create_conv_plan_and_finalize_desc(kernel_desc_t &desc) { - auto plan = create_conv_plan(desc); - if (plan) desc.finalize(plan); - return plan; +plan_t create_conv_plan(const kernel_desc_t &desc) { + return create_conv_plan_impl(desc, /*finalize=*/false); +} + +bool finalize_conv_desc(kernel_desc_t &desc, const problem_t &prb) { + ir_assert(desc.hw_desc.hw == prb.hw().to_ngen()); + desc.specialize(prb); + desc.hw = prb.hw(); + if (desc.is_finalized) return true; + auto plan = create_conv_plan_impl(desc, /*finalize=*/true); + return (bool)plan; } } // namespace conv diff --git a/src/gpu/intel/jit/v2/conv/plan.hpp b/src/gpu/intel/jit/v2/conv/plan.hpp index a2ba8e0ac75..0b50b2d3942 100644 --- a/src/gpu/intel/jit/v2/conv/plan.hpp +++ b/src/gpu/intel/jit/v2/conv/plan.hpp @@ -245,7 +245,7 @@ struct prefetch_plan_t : public base_plan_t { }; struct x2r_plan_t : public base_plan_t { - tensor_kind_t tensor_kind; + tensor_kind_t tensor_kind = tensor_kind_t::undef; send_plan_t load; reorder_plan_t reorder; layout_t layout; @@ -399,7 +399,7 @@ struct slm_reduce_plan_t : public base_plan_t { return ret; } - int slm_size() const { + int slm_usage_bytes() const { if (!*this) return 0; int k_local = ir_utils::safe_div(reduce.src.elems(), reduce.dst.elems()); @@ -431,7 +431,7 @@ struct epilogue_plan_t : public 
base_plan_t { using base_plan_t::base_plan_t; int grf_usage_bytes() const { return 0; } - int slm_size() const { return slm_reduce.slm_size(); } + int slm_usage_bytes() const { return slm_reduce.slm_usage_bytes(); } std::string str() const { if (!*this) return "(empty)"; @@ -474,9 +474,9 @@ struct plan_t : public base_plan_t { return ret; } - int slm_size() const { + int slm_usage_bytes() const { int ret = 0; - ret += epilogue.slm_size(); + ret += epilogue.slm_usage_bytes(); return ret; } @@ -495,7 +495,7 @@ struct plan_t : public base_plan_t { }; plan_t create_conv_plan(const kernel_desc_t &desc); -plan_t create_conv_plan_and_finalize_desc(kernel_desc_t &desc); +bool finalize_conv_desc(kernel_desc_t &desc, const problem_t &prb); } // namespace conv } // namespace v2 diff --git a/src/gpu/intel/jit/v2/conv/plan_registry.cpp b/src/gpu/intel/jit/v2/conv/plan_registry.cpp index 7ef16b60de4..9f05dcf1d13 100644 --- a/src/gpu/intel/jit/v2/conv/plan_registry.cpp +++ b/src/gpu/intel/jit/v2/conv/plan_registry.cpp @@ -73,17 +73,18 @@ kernel_desc_t plan_registry_t::find_best(const problem_t &prb) const { if (eff > best_eff) { best_eff = eff; best = e.desc; - best.set_defaults(); } } - best.hw = prb.hw(); + best.spec_strategy = spec_strategy_t::min_dims; return best; } void plan_registry_t::stringify(std::ostream &out) const { + bool is_first = true; for (auto &e : entries_) { + if (!is_first) out << "\n"; e.stringify(out); - out << "\n"; + is_first = false; } } @@ -115,17 +116,20 @@ struct plan_registry_instance_t { } plan_registry_instance_t() { - registry = plan_registry_t(get_plan_registry_entries()); #ifdef DNNL_DEV_MODE registry_path = getenv_string_user(env_registry_path_name); if (!registry_path.empty()) { std::ifstream in(registry_path); - if (!in.good()) return; - plan_registry_t file_registry; - file_registry.parse(in); - registry.merge(file_registry); + if (in.good()) { + registry.parse(in); + ir_info() << "Loaded kernel registry from " << registry_path + << " with " << registry.size() << " entries" + << std::endl; + return; + } } #endif + registry = plan_registry_t(get_plan_registry_entries()); } void dump() const { diff --git a/src/gpu/intel/jit/v2/conv/plan_registry.hpp b/src/gpu/intel/jit/v2/conv/plan_registry.hpp index 6966c12e6c9..de1d5848289 100644 --- a/src/gpu/intel/jit/v2/conv/plan_registry.hpp +++ b/src/gpu/intel/jit/v2/conv/plan_registry.hpp @@ -48,6 +48,7 @@ class plan_registry_t { void set(const kernel_desc_t &desc, const model_t &model) { entries_.emplace_back(desc, model); } + int size() const { return (int)entries_.size(); } void merge(const plan_registry_t &other); kernel_desc_t find_best(const problem_t &prb) const; void stringify(std::ostream &out) const; diff --git a/src/gpu/intel/jit/v2/conv/planner/search.cpp b/src/gpu/intel/jit/v2/conv/planner/search.cpp index daa9b8d4a11..acabe4c811b 100644 --- a/src/gpu/intel/jit/v2/conv/planner/search.cpp +++ b/src/gpu/intel/jit/v2/conv/planner/search.cpp @@ -438,7 +438,7 @@ void auto_search(const bench_manager_t &bench_mger) { }; // clang-format on for (const char *_r : recipes) { - auto r = std::string(_r) + "--iter x --tg x"; + auto r = std::string(_r) + " --iter x --tg x"; kernel_desc_t desc; desc.set(r); desc.hw = hw_t(bench_mger.get_engine().get()); diff --git a/src/gpu/intel/jit/v2/conv/problem.cpp b/src/gpu/intel/jit/v2/conv/problem.cpp index 73ac71153fb..107b3f3c347 100644 --- a/src/gpu/intel/jit/v2/conv/problem.cpp +++ b/src/gpu/intel/jit/v2/conv/problem.cpp @@ -59,7 +59,20 @@ void problem_t::set_shape(const std::string 
&s) { if (s_tile.has(d)) continue; s_tile.set(d, default_shape()[d]); } - shape_ = s_tile; + shape_ = std::move(s_tile); +} + +double problem_t::ops() const { + return ops(prop_, shape_); +} + +void problem_t::normalize() { +#define GET(name) shape_[prb_dims::name] + normalize_conv_shape(GET(od), GET(id), GET(kd), GET(sd), GET(dd), GET(pd), + GET(oh), GET(ih), GET(kh), GET(sh), GET(dh), GET(ph), GET(ow), + GET(iw), GET(kw), GET(sw), GET(dw), GET(pw), + /*can_flatten_spatial=*/true, dhw_map_); +#undef GET } std::string problem_t::desc_str() const { @@ -173,6 +186,20 @@ prb_tile_t problem_t::default_shape() { return _default_shape; } +double problem_t::ops(prop_kind_t prop, const prb_tile_t &shape) { +#define GET(name) shape[prb_dims::name] + double ret = 2.0; + ret *= (double)GET(g) * GET(mb) * GET(oc) * GET(ic); + ret *= GET(kd) * GET(kh) * GET(kw); + if (prop == prop_kind::backward_data) { + ret *= GET(id) * GET(ih) * GET(iw); + } else { + ret *= GET(od) * GET(oh) * GET(ow); + } +#undef GET + return ret; +} + class arg_helper_t { public: arg_helper_t(prop_kind_t prop, bool with_bias) diff --git a/src/gpu/intel/jit/v2/conv/problem.hpp b/src/gpu/intel/jit/v2/conv/problem.hpp index d52d27ad94b..78cdffdbb29 100644 --- a/src/gpu/intel/jit/v2/conv/problem.hpp +++ b/src/gpu/intel/jit/v2/conv/problem.hpp @@ -58,8 +58,10 @@ class problem_t { void set_shape(const prb_tile_t &shape) { shape_ = shape; } void set_bias(bool with_bias) { with_bias_ = with_bias; } bool with_bias() const { return with_bias_; } + double ops() const; void set_shape(const std::string &s); + void normalize(); std::string desc_str() const; std::string str() const; std::string csv_str() const; @@ -67,6 +69,7 @@ class problem_t { IR_DEFINE_DUMP() static prb_tile_t default_shape(); + static double ops(prop_kind_t prop, const prb_tile_t &shape); private: hw_t hw_; @@ -76,6 +79,7 @@ class problem_t { layout_tag_t wei_tag_; layout_tag_t dst_tag_; prb_tile_t shape_; + std::array dhw_map_; }; tensor_config_t get_tensor_config(prop_kind_t prop, bool with_bias); diff --git a/src/gpu/intel/jit/v2/ir/reqs.cpp b/src/gpu/intel/jit/v2/ir/reqs.cpp index f0d7ce4f99f..162cdbdddf3 100644 --- a/src/gpu/intel/jit/v2/ir/reqs.cpp +++ b/src/gpu/intel/jit/v2/ir/reqs.cpp @@ -207,7 +207,7 @@ class req_impl_t { } req_kind_t kind() const { return kind_; } - dim_product_t lhs() const { return lhs_; } + const dim_product_t &lhs() const { return lhs_; } int rhs() const { return rhs_; } void substitute(const prb_tile_t &dim_sizes) { @@ -247,12 +247,14 @@ class req_impl_t { // requirement. 
bool can_prove(const req_impl_t &to_prove) const { if (*this == to_prove) return true; - if (kind_ != req_kind_t::mod_eq_0 - && to_prove.kind_ != req_kind_t::mod_eq_0) - return false; + if (kind_ != to_prove.kind_) return false; if (lhs_ != to_prove.lhs_) return false; - if (rhs_ % to_prove.rhs_ == 0) return true; - return false; + switch (kind_) { + case req_kind_t::ge: return rhs_ >= to_prove.rhs_; + case req_kind_t::le: return rhs_ <= to_prove.rhs_; + case req_kind_t::mod_eq_0: return rhs_ % to_prove.rhs_ == 0; + default: return false; + } } void stringify(std::ostream &out) const { stringify_impl(out); } @@ -367,6 +369,10 @@ void prb_reqs_t::set(const prb_dim_t &dim, int value) { add(size_var(dim) == value); } +void prb_reqs_t::set_any_mod(const prb_dim_t &dim) { + any_mods_.push_back(dim); +} + void prb_reqs_t::add_if_not_found(const req_impl_t &new_req) { for (auto &r : reqs_) { if (r.impl() == new_req) return; @@ -374,9 +380,8 @@ void prb_reqs_t::add_if_not_found(const req_impl_t &new_req) { reqs_.emplace_back(new_req); } -prover_t prb_reqs_t::prover(bool enable) { - if (!enable) return prover_t(); - return prover_t(this); +prover_t prb_reqs_t::prover(const prb_reqs_t &parent, bool can_update) { + return prover_t(&parent, this, can_update); } bool prb_reqs_t::fits(const prb_tile_t &sizes) const { @@ -562,10 +567,23 @@ bool prb_reqs_t::can_prove(const expr_t &to_prove) const { return can_prove(req_impl_t(e)); } -bool prb_reqs_t::can_prove(const req_impl_t &to_prove) const { +bool prb_reqs_t::can_prove(const req_impl_t &to_prove, bool use_any_mod) const { for (auto &r : reqs_) { if (r.impl().can_prove(to_prove)) return true; } + if (to_prove.kind() == req_kind_t::mod_eq_0) { + int mod = 1; + for (int i = 0; i < to_prove.lhs().size(); i++) { + auto &dim = to_prove.lhs()[i]; + if (use_any_mod) { + for (auto &d : any_mods_) { + if (d == dim) return true; + } + } + mod *= max_factor(dim); + } + if (mod % to_prove.rhs() == 0) return true; + } return false; } @@ -581,6 +599,17 @@ bool prb_reqs_t::get_value(const prb_dim_t &dim, int &value) const { return false; } +int prb_reqs_t::max_factor(const prb_dim_t &dim) const { + int ret = 1; + for (auto &r : reqs_) { + auto &ri = r.impl(); + if (ri.kind() == req_kind_t::mod_eq_0 && ri.lhs() == dim) { + ret = std::max(ret, ri.rhs()); + } + } + return ret; +} + bool prb_reqs_t::is_equal(const prb_dim_t &dim, int value) const { int dim_value; return get_value(dim, dim_value) && dim_value == value; @@ -588,7 +617,7 @@ bool prb_reqs_t::is_equal(const prb_dim_t &dim, int value) const { bool prb_reqs_t::implies(const prb_reqs_t &other) const { for (auto &req : other.reqs_) { - if (!can_prove(req.impl())) return false; + ir_check(can_prove(req.impl())) << "Cannot prove: " << req.impl(); } return true; } @@ -608,8 +637,10 @@ bool prover_t::require(const expr_t &_e) const { auto e = simplify_expr(_e); if (auto *imm = e.as_ptr()) return imm->value; - if (!parent_) return false; - parent_->add_if_not_found(req_impl_t(e)); + req_impl_t ri(e); + bool is_true = (parent_ && parent_->can_prove(ri, /*use_any_mod=*/true)); + if (!is_true && !can_update_) return false; + reqs_->add_if_not_found(ri); return true; } diff --git a/src/gpu/intel/jit/v2/ir/reqs.hpp b/src/gpu/intel/jit/v2/ir/reqs.hpp index ef82df7db5e..2051908f148 100644 --- a/src/gpu/intel/jit/v2/ir/reqs.hpp +++ b/src/gpu/intel/jit/v2/ir/reqs.hpp @@ -37,12 +37,22 @@ class prover_t { public: static const prover_t &instance(); prover_t() = default; - prover_t(prb_reqs_t *parent) : parent_(parent) {} + 
prover_t(const prb_reqs_t *parent, prb_reqs_t *reqs, bool can_update) + : parent_(parent), reqs_(reqs), can_update_(can_update) {} + prover_t(prover_t &other, bool can_update) + : parent_(other.parent_), reqs_(other.reqs_), can_update_(can_update) {} + // TODO: Change to non-const. bool require(const expr_t &e) const; - explicit operator bool() const { return parent_; } + const prb_reqs_t &reqs() const { + ir_assert(reqs_); + return *reqs_; + } + explicit operator bool() const { return reqs_; } private: - prb_reqs_t *parent_ = nullptr; + const prb_reqs_t *parent_ = nullptr; + prb_reqs_t *reqs_ = nullptr; + bool can_update_ = false; }; class req_impl_t; @@ -57,7 +67,10 @@ class prb_reqs_t { void add(const prb_reqs_t &other); void add(const prb_tile_t &tile); void set(const prb_dim_t &dim, int value); - prover_t prover(bool enable = true); + // Mark the dimension as being divisible by any number - this changes + // behavior of methods like can_prove() and max_factor(). + void set_any_mod(const prb_dim_t &dim); + prover_t prover(const prb_reqs_t &parent, bool can_update = true); explicit operator bool() const { return !reqs_.empty(); } // Checks if the requirements are satisfied for the given problem sizes . @@ -68,8 +81,9 @@ class prb_reqs_t { // For example: prb_reqs_t(oc % 64 == 0) implies (oc % 16) == 0 so the // latter can be proven from the original requirements. bool can_prove(const expr_t &to_prove) const; - bool can_prove(const req_impl_t &to_prove) const; + bool can_prove(const req_impl_t &to_prove, bool use_any_mod = false) const; bool get_value(const prb_dim_t &dim, int &value) const; + int max_factor(const prb_dim_t &dim) const; bool is_equal(const prb_dim_t &dim, int value) const; // Checks if other prb_reqs_t object is fully implied from the requirements // of this object. @@ -97,6 +111,9 @@ class prb_reqs_t { void add_if_not_found(const req_impl_t &new_req); std::vector reqs_; + // List of dimensions that are treated as having any arbitrary factors + // during proving. + std::vector any_mods_; }; } // namespace v2 diff --git a/src/gpu/intel/jit/v2/ir/send.hpp b/src/gpu/intel/jit/v2/ir/send.hpp index 967d245d6d7..6b365bc64df 100644 --- a/src/gpu/intel/jit/v2/ir/send.hpp +++ b/src/gpu/intel/jit/v2/ir/send.hpp @@ -18,6 +18,7 @@ #define GPU_INTEL_JIT_V2_IR_SEND_HPP #include "gpu/intel/jit/ir/block_2d_utils.hpp" +#include "gpu/intel/jit/ir/fma.hpp" #include "gpu/intel/jit/v2/ir/plan_utils.hpp" #include "gpu/intel/jit/v2/ir/reqs.hpp" #include "gpu/intel/jit/v2/ir/tensor.hpp" @@ -323,6 +324,7 @@ struct send_params_t { send_2d_hint_t hint_2d; // For register payload. int max_entry_reg_size = 0; + const prb_reqs_t *external_reqs = nullptr; std::vector skip_mask; void init_max_entry_reg_size() { @@ -361,13 +363,13 @@ struct send_1d_desc_t { explicit operator bool() const { return op != send_op_t::undef; } - bool base_alignment_ok(const expr_t &off, const prover_t &prover) { + bool base_alignment_ok(const expr_t &off, const prover_t &prover) const { int align = (type_size >= 16 ? 
8 : 1); if (!prover.require(off % align == 0)) return false; return true; } - bool base_alignment_ok(const addr_t &addr, const prover_t &prover) { + bool base_alignment_ok(const addr_t &addr, const prover_t &prover) const { if (!base_alignment_ok(addr.base, prover)) return false; for (auto &inc : addr.slot_incs) { if (!base_alignment_ok(inc, prover)) return false; @@ -427,16 +429,16 @@ struct send_1d_plan_t : public base_plan_t { if (!desc.base_alignment_ok(addr_inc, prover)) return false; std::vector mask_incs(nmasks()); auto coord = it.coord(); + ir_assert(reg_layout.offset_in_bytes(coord) == reg_off); for (int i = 0; i < nmasks(); i++) { mask_incs[i] = mask_desc[i].to_expr(coord, /*with_const=*/false); } entries.emplace_back(); auto &e = entries.back(); - e.addr_inc = addr_inc; - e.mask_incs = mask_incs; + e.addr_inc = std::move(addr_inc); + e.mask_incs = std::move(mask_incs); e.reg_off = reg_off; - e.coord = coord; - ir_assert(reg_layout.offset_in_bytes(coord) == reg_off); + e.coord = std::move(coord); return true; } @@ -734,28 +736,30 @@ class send_plan_builder_t { send_plan_t build() const { send_params_t params = init_params_; + prb_reqs_t reqs; + auto prover = reqs.prover(*params.external_reqs, + /*can_update=*/params.kind != send_kind_t::undef); switch (params.kind) { - case send_kind_t::_2d: return try_build_2d(params, init_view_); + case send_kind_t::_2d: + return try_build_2d(params, init_view_, prover); case send_kind_t::compressed_prefetch: { - prb_reqs_t reqs; int cache_line_size = params.hw.cache_line_size(); - auto view - = init_view_.scatterize(cache_line_size, reqs.prover()); + auto view = init_view_.scatterize(cache_line_size, prover); if (view.is_empty()) return send_plan_t(); params.kind = send_kind_t::scattered; - return try_build_1d(params, view, reqs); + return try_build_1d(params, view, prover); } - default: return try_build_1d(params, init_view_); + default: return try_build_1d(params, init_view_, prover); } } private: send_plan_t try_build_1d(const send_params_t ¶ms, const view_t &view, - prb_reqs_t reqs = prb_reqs_t()) const { + prover_t &prover) const { send_plan_t plan(params.hw); auto &layout = view.layout(); auto &mask_desc = view.mask_desc(); - auto inner_last = find_inner_last(params, view, mask_desc, reqs); + auto inner_last = find_inner_last(params, view, mask_desc, prover); int type_size = layout.type().size(); int inner_elems = inner_last.elems(); int inner_bytes = type_size * inner_elems; @@ -777,7 +781,7 @@ class send_plan_builder_t { int slot_stride = std::max(4, slot_size); auto inner_end = inner_last + 1; - auto middle_last = inner_last; + auto middle_last = std::move(inner_last); auto outer_begin = end(layout); if (is_scattered) { // Add blocks to fill up slots in the scattered message. 
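
For context, the strengthened `req_impl_t::can_prove()` earlier in this patch now compares like-kinded requirements directly: a stored `>=` proves any weaker `>=`, a stored `<=` proves any weaker `<=`, and divisibility by a multiple proves divisibility by its factors. A self-contained sketch of those three rules with simplified stand-in types (the real `req_impl_t` stores a dimension product on the left-hand side):

```cpp
#include <cassert>
#include <string>

// Simplified stand-in for req_impl_t: kind, lhs dimension name, rhs value.
enum class req_kind { ge, le, mod_eq_0 };
struct req {
    req_kind kind;
    std::string lhs;
    int rhs;
};

// Mirrors the proving rules in req_impl_t::can_prove(): same kind and lhs,
// and the known rhs must be at least as strong as the one to prove.
bool can_prove(const req &known, const req &to_prove) {
    if (known.kind != to_prove.kind || known.lhs != to_prove.lhs) return false;
    switch (known.kind) {
        case req_kind::ge: return known.rhs >= to_prove.rhs; // x >= 32 proves x >= 16
        case req_kind::le: return known.rhs <= to_prove.rhs; // x <= 8 proves x <= 16
        case req_kind::mod_eq_0:
            return known.rhs % to_prove.rhs == 0; // x % 64 == 0 proves x % 16 == 0
    }
    return false;
}

int main() {
    assert(can_prove({req_kind::mod_eq_0, "oc", 64}, {req_kind::mod_eq_0, "oc", 16}));
    assert(!can_prove({req_kind::mod_eq_0, "oc", 16}, {req_kind::mod_eq_0, "oc", 64}));
    assert(can_prove({req_kind::ge, "mb", 32}, {req_kind::ge, "mb", 16}));
    assert(!can_prove({req_kind::ge, "mb", 16}, {req_kind::ge, "mb", 32}));
    return 0;
}
```
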
@@ -803,7 +807,7 @@ class send_plan_builder_t { desc.slots = slots; addr_t addr(layout, slots, elems_per_slot); - if (!desc.base_alignment_ok(addr, reqs.prover())) return send_plan_t(); + if (!desc.base_alignment_ok(addr, prover)) return send_plan_t(); int elem_stride = 1; if (slot_stride > slot_size) { @@ -821,33 +825,32 @@ class send_plan_builder_t { auto &plan_1d = plan.get_1d(); plan_1d = send_1d_plan_t(plan.hw); plan_1d.desc = desc; - plan_1d.addr = addr; + plan_1d.addr = std::move(addr); plan_1d.mask = mask_t(mask_desc, layout, slots, elems_per_slot); - plan_1d.reg_layout = reg_layout; - plan_1d.entry_tile = entry_tile; + plan_1d.reg_layout = std::move(reg_layout); + plan_1d.entry_tile = std::move(entry_tile); for (auto &d : params.skip_mask) plan_1d.mask.clear(d); int step_elems = slots * elems_per_slot; layout_iterator_t it(layout); int reg_off = 0; - plan_1d.add_entry(it, mask_desc, reg_off, reqs.prover()); + plan_1d.add_entry(it, mask_desc, reg_off, prover); while (it.has_next(step_elems)) { it.next(step_elems); reg_off += slots * slot_stride; reg_off = utils::rnd_up(reg_off, grf_size); - if (!plan_1d.add_entry(it, mask_desc, reg_off, reqs.prover())) + if (!plan_1d.add_entry(it, mask_desc, reg_off, prover)) return send_plan_t(); } - plan_1d.reqs = reqs; - plan_1d.reqs.simplify(); + plan_1d.reqs = prover.reqs(); return plan; } send_plan_t try_build_2d(const send_params_t ¶ms, const view_t &view, - prb_reqs_t reqs = prb_reqs_t()) const { + prover_t &prover) const { send_plan_t plan(params.hw); - send_2d_desc_t desc(view, params, reqs.prover()); + send_2d_desc_t desc(view, params, prover); if (!desc) return send_plan_t(); auto &plane = view.plane(); @@ -872,8 +875,8 @@ class send_plan_builder_t { plan_2d.mask.clear(plane.y_dim); for (auto &d : params.skip_mask) plan_2d.mask.clear(d); - plan_2d.reg_layout = reg_layout; - plan_2d.entry_tile = entry_tile; + plan_2d.reg_layout = std::move(reg_layout); + plan_2d.entry_tile = std::move(entry_tile); int reg_off = 0; for (int h = 0; h < plane.h; h += desc.h) { @@ -881,19 +884,18 @@ class send_plan_builder_t { prb_coord_t coord; coord[plane.w_dim] = w; coord[plane.h_dim] = h; - if (!plan_2d.add_entry(coord, reg_off, reqs.prover())) + if (!plan_2d.add_entry(coord, reg_off, prover)) return send_plan_t(); reg_off += entry_reg_size; } } - plan_2d.reqs = reqs; - plan_2d.reqs.simplify(); + plan_2d.reqs = prover.reqs(); return plan; } block_iterator_t find_inner_last(const send_params_t ¶ms, const view_t &view, const mask_desc_t &mask_desc, - prb_reqs_t &reqs) const { + prover_t &prover) const { auto &layout = view.layout(); auto inner_last = begin(layout); int type_size = layout.type().size(); @@ -903,8 +905,8 @@ class send_plan_builder_t { return type_size * inner_last.elems() >= grf_size; }; for (auto it = begin(layout); it != end(layout); ++it) { - auto prover = reqs.prover(!ok_to_return()); - if (!mask_desc.is_uniform(it, prover)) break; + auto _prover = prover_t(prover, /*can_update=*/!ok_to_return()); + if (!mask_desc.is_uniform(it, _prover)) break; if (!it.is_dense()) break; if (type_size * it.elems() > params.max_entry_reg_size) break; inner_last = it; diff --git a/src/gpu/intel/jit/v2/ir/tensor.cpp b/src/gpu/intel/jit/v2/ir/tensor.cpp index fb53d54543d..e7afa558385 100644 --- a/src/gpu/intel/jit/v2/ir/tensor.cpp +++ b/src/gpu/intel/jit/v2/ir/tensor.cpp @@ -183,7 +183,7 @@ void layout_raw_tag_t::add_dim(char letter, int pos) { if (new_letter >= letter) new_letter++; new_entries.emplace_back(new_letter, e.block, e.is_blocked); } - entries_ 
= new_entries; + entries_ = std::move(new_entries); } void layout_raw_tag_t::remove_dim(char letter) { @@ -195,7 +195,7 @@ void layout_raw_tag_t::remove_dim(char letter) { if (e.letter > letter) new_letter--; new_entries.emplace_back(new_letter, e.block, e.is_blocked); } - entries_ = new_entries; + entries_ = std::move(new_entries); } bool layout_raw_tag_t::is_blocked(char letter) const { @@ -277,7 +277,7 @@ void layout_raw_tag_t::expand_x(int ndims) { new_entries.push_back(e); } } - entries_ = new_entries; + entries_ = std::move(new_entries); } layout_raw_tag_t layout_raw_tag_t::collapse_x() const { @@ -664,18 +664,54 @@ layout_t layout_t::split_block( } template -struct div_helper_t { - static T call(const T &a, int b) { return a / b; } +struct try_div_mod { + static bool call(const T &a, int b, const var_range_info_t &range_info, + T &div, T &mod) { + if (a % b != 0) return false; + div = a / b; + mod = a % b; + return true; + } }; template <> -struct div_helper_t { - static expr_t call(const expr_t &a, int b) { return linear_div(a, b); } +struct try_div_mod { + static bool call(const expr_t &a, int b, const var_range_info_t &range_info, + expr_t &div, expr_t &mod) { + int factor = linear_max_pow2_divisor(a); + if (factor % b == 0) { + div = linear_div(a, b); + mod = expr_t(0); + return true; + } + auto _linear = to_linear(a); + auto &linear = _linear.as(); + int c_factor = linear_max_pow2_divisor(linear.c); + if (c_factor % b != 0) return false; + expr_t a_div = linear_div(linear.c, b); + expr_t a_mod; + for (int i = 0; i < linear.nargs(); i++) { + auto &u = linear.u_vec[i]; + auto &v = linear.v_vec[i]; + int u_factor = linear_max_pow2_divisor(u); + if (u_factor % b == 0) { + a_div += linear_div(u, b) * v; + continue; + } + if (range_info.bound(v) > b) return false; + if (!a_mod.is_empty()) return false; + a_mod = v; + } + div = a_div; + mod = a_mod; + return true; + } }; template layout_t layout_t::map(const dim_mapper_t &dim_mapper, - const prb_coord_t &coord, const prb_tile_t &tile) const { + const prb_coord_t &coord, const prb_tile_t &tile, + const var_range_info_t &var_range_info) const { auto idxs = coord; auto rem_sizes = tile; idxs.fill_missing(0); @@ -700,7 +736,7 @@ layout_t layout_t::map(const dim_mapper_t &dim_mapper, int inner = cur_size; int outer = b_size / cur_size; return split_block(&b, inner, outer) - .map(dim_mapper, coord, tile); + .map(dim_mapper, coord, tile, var_range_info); } return layout_t(); } @@ -713,10 +749,14 @@ layout_t layout_t::map(const dim_mapper_t &dim_mapper, } bool is_outer = true; if (b.has_const_size()) { + ir_assert(is_zero(off)); ir_assert(!seen_outer.has(dim)); - int factor = linear_max_pow2_divisor(idxs[dim]); - if (factor % b.int_size() == 0) { - idxs[dim] = div_helper_t::call(idxs[dim], b.int_size()); + T div = T(); + T mod = T(); + if (try_div_mod::call(idxs[dim], b.int_size(), + var_range_info, div, mod)) { + idxs[dim] = div; + off = mod; is_outer = false; } } @@ -732,9 +772,11 @@ layout_t layout_t::map(const dim_mapper_t &dim_mapper, } template layout_t layout_t::map(const dim_mapper_t &dim_mapper, - const prb_coord_t &coord, const prb_tile_t &tile) const; + const prb_coord_t &coord, const prb_tile_t &tile, + const var_range_info_t &var_range_info) const; template layout_t layout_t::map(const dim_mapper_t &dim_mapper, - const prb_coord_t &coord, const prb_tile_t &tile) const; + const prb_coord_t &coord, const prb_tile_t &tile, + const var_range_info_t &var_range_info) const; prb_coord_t layout_t::to_coord(const std::vector &block_idx) 
const { ir_assert((int)block_idx.size() == nblocks()); @@ -1192,8 +1234,9 @@ bool grid_splitter_t::is_empty() const { return true; } -expr_t grid_splitter_t::pop(int size) { +expr_t grid_splitter_t::pop(int _size) { expr_t cur = 0; + int size = _size; for (auto &idx : idxs_) { if (idx.size == 1) continue; if (size == 1) break; @@ -1201,7 +1244,7 @@ expr_t grid_splitter_t::pop(int size) { cur += idx.pop(size); } ir_assert(size == 1); - return register_index(simplify_rewrite(cur)); + return register_index(simplify_rewrite(cur), _size); } expr_t grid_splitter_t::index_t::pop(int &n) { @@ -1222,23 +1265,25 @@ expr_t grid_splitter_t::index_t::pop(int &n) { return ret; } -expr_t grid_splitter_t::register_index(const expr_t &expr) { +expr_t grid_splitter_t::register_index(const expr_t &expr, int size) { if (expr.is()) return expr; int idx = (int)virt_grid_idxs_.size(); auto var = var_t::make(type_t::s32(), "virt_grid_idx" + std::to_string(idx)); virt_grid_idxs_.emplace(var, expr); + var_range_info_.set_bound(var, size); return var; } view_t::view_t(const dim_mapper_t &dim_mapper, const layout_t &base_layout, - const prb_coord_t &coord, const prb_tile_t &tile) + const prb_coord_t &coord, const prb_tile_t &tile, + const var_range_info_t &var_range_info) : dim_mapper_(dim_mapper) , base_layout_(base_layout) , coord_(coord) , tile_(tile) { mask_desc_t base_mask_desc(dim_mapper, base_layout); - layout_ = base_layout.map(dim_mapper, coord, tile); + layout_ = base_layout.map(dim_mapper, coord, tile, var_range_info); mask_desc_ = base_mask_desc.map(coord); plane_ = plane_t(layout_, mask_desc_); } @@ -1369,7 +1414,8 @@ view_t view_t::split(const dim_mapper_t &dim_mapper, inner_dims[b.dim] *= b.int_size(); } ir_assert(grid_splitter.is_empty()); - return view_t(dim_mapper, base_layout, split_coord, split_tile); + return view_t(dim_mapper, base_layout, split_coord, split_tile, + grid_splitter.var_range_info()); } } // namespace v2 diff --git a/src/gpu/intel/jit/v2/ir/tensor.hpp b/src/gpu/intel/jit/v2/ir/tensor.hpp index b3ddce60f8b..643b66ef427 100644 --- a/src/gpu/intel/jit/v2/ir/tensor.hpp +++ b/src/gpu/intel/jit/v2/ir/tensor.hpp @@ -33,6 +33,41 @@ namespace intel { namespace jit { namespace v2 { +// Stores upper bounds for variables. 
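
The `var_range_info_t` class introduced just below records the upper bounds that `try_div_mod` consults above: splitting a linear index `a = c*x + v` by a block size `b` without emitting a real division is valid when every term is divisible by `b`, or when the single non-divisible term `v` is known to stay below `b`, in which case `a / b = (c/b)*x` and `a % b = v`. A numeric sketch of why the recorded bound makes the split exact:

```cpp
#include <cassert>

int main() {
    const int b = 16; // block size to split by
    const int c = 16; // linear coefficient, divisible by b
    // v plays the role of a virtual grid index with recorded bound b.
    for (int x = 0; x < 4; x++)
        for (int v = 0; v < b; v++) {
            int a = c * x + v;
            assert(a / b == (c / b) * x); // div drops the bounded term
            assert(a % b == v); // mod is exactly the bounded term
        }
    return 0;
}
```

Without the bound (`v >= b` possible), `a % b` would no longer equal `v`, which is why `try_div_mod` bails out when `range_info.bound(v) > b`.
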
+class var_range_info_t { +public: + void set_bound(const expr_t &var, int bound) { + for (auto &e : entries_) { + if (e.var.is_equal(var)) { + e.bound = std::min(e.bound, bound); + return; + } + } + entries_.emplace_back(entry_t {var, bound}); + } + + int bound(const expr_t &var) const { + for (auto &e : entries_) { + if (e.var.is_equal(var)) return e.bound; + } + return default_bound; + } + +private: + static const int default_bound = std::numeric_limits::max(); + + struct entry_t { + expr_t var; + int bound = default_bound; + + entry_t() = default; + entry_t(const expr_t &var, int bound) : var(var), bound(bound) {} + }; + +private: + std::vector entries_; +}; + struct block_t { block_t() = default; block_t(const prb_dim_t &dim, const expr_t &size, @@ -355,7 +390,8 @@ class layout_t { template layout_t map(const dim_mapper_t &dim_mapper, const prb_coord_t &coord, - const prb_tile_t &tile) const; + const prb_tile_t &tile, + const var_range_info_t &var_range_info = {}) const; template layout_t map(const prb_coord_t &coord, const prb_tile_t &tile) const { @@ -572,6 +608,21 @@ class grid_splitter_t { return virt_grid_idxs_; } + const var_range_info_t &var_range_info() const { return var_range_info_; } + + std::string str() const { + std::ostringstream oss; + bool is_first = true; + for (auto &kv : virt_grid_idxs_) { + if (!is_first) oss << "\n"; + oss << kv.first << " -> " << kv.second; + is_first = false; + } + return oss.str(); + } + + IR_DEFINE_DUMP() + private: struct index_t { expr_t expr; @@ -581,17 +632,19 @@ class grid_splitter_t { expr_t pop(int &n); }; - expr_t register_index(const expr_t &expr); + expr_t register_index(const expr_t &expr, int size); std::vector idxs_; object_map_t virt_grid_idxs_; + var_range_info_t var_range_info_; }; class view_t { public: view_t() = default; view_t(const dim_mapper_t &dim_mapper, const layout_t &base_layout, - const prb_coord_t &coord, const prb_tile_t &tile); + const prb_coord_t &coord, const prb_tile_t &tile, + const var_range_info_t &var_range_info = {}); bool is_empty() const { return base_layout_.is_empty(); } const dim_mapper_t &dim_mapper() const { return dim_mapper_; } const layout_t &base_layout() const { return base_layout_; } diff --git a/src/gpu/intel/microkernels/package.hpp b/src/gpu/intel/microkernels/package.hpp index f96c171526b..acca8342868 100644 --- a/src/gpu/intel/microkernels/package.hpp +++ b/src/gpu/intel/microkernels/package.hpp @@ -69,10 +69,10 @@ struct Package { // Contiguous span of register space. 
struct RegisterRange { - uint32_t boffset; // Byte offset into GRF - uint32_t blen; // Length of range in bytes + uint32_t boffset = 0; // Byte offset into GRF + uint32_t blen = 0; // Length of range in bytes - RegisterRange() {} + RegisterRange() = default; RegisterRange(uint32_t boffset_, uint32_t blen_) : boffset(boffset_), blen(blen_) {} }; diff --git a/src/gpu/intel/ocl/bnorm/bnorm_model.cpp b/src/gpu/intel/ocl/bnorm/bnorm_model.cpp index be982cd29eb..9cececbaf9a 100644 --- a/src/gpu/intel/ocl/bnorm/bnorm_model.cpp +++ b/src/gpu/intel/ocl/bnorm/bnorm_model.cpp @@ -430,10 +430,8 @@ float get_ss_utilization_factor(float util, data_type_t dt, bool is_reusable) { if (is_reusable) { if (dt == data_type::f16 || dt == data_type::bf16) { return get_pow_ratio(util, 1.0f, 2.0f, -0.8f); - } else { - return get_pow_ratio(util, 1.0f, 5.3f, -0.7f); } - return 1.f; + return get_pow_ratio(util, 1.0f, 5.3f, -0.7f); } else { return 1.f / std::min(util, 1.f); } diff --git a/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cl b/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cl index 76e526b4b02..f8181ea8990 100644 --- a/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cl +++ b/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cl @@ -79,7 +79,8 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src, __global BIA_DATA_T *bias, __global DST_DATA_T *dst POST_OP_ARGS, __global SPAD_DATA_T *scratchpad, global float *a_scales, global WEI_SCALES_DATA_T *b_scales, - global float *c_scales, int scale_stride, global int *dst_zp) { + global DST_SCALES_DATA_T *c_scales, int scale_stride, + global int *dst_zp) { const uint d0 = GWS_GET_D0(); const uint d1 = GWS_GET_D1(); const uint d2 = GWS_GET_D2(); @@ -148,7 +149,7 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src, __global BIA_DATA_T *bias, #endif #if C_SCALES - accumulator /= c_scales[0]; + accumulator /= DST_SCALES_TO_REF(c_scales[0]); #endif #if DST_ZERO_POINT accumulator += dst_zp[0]; diff --git a/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp b/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp index 4c54a4ffec3..2562d54f869 100644 --- a/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp +++ b/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp @@ -185,6 +185,8 @@ status_t gemm_with_post_ops_t::pd_t::init_kernel_ctx( kernel_ctx.define_int("C_SCALES", with_dst_scales); def_data_type(kernel_ctx, attr_scales.get(DNNL_ARG_WEIGHTS).data_type_, "WEI_SCALES"); + def_data_type( + kernel_ctx, attr_scales.get(DNNL_ARG_DST).data_type_, "DST_SCALES"); int dst_zp_mask; attr()->zero_points_.get(DNNL_ARG_DST, &dst_zp_mask); kernel_ctx.define_int("DST_ZERO_POINT", diff --git a/src/gpu/intel/ocl/gemm_matmul.hpp b/src/gpu/intel/ocl/gemm_matmul.hpp index df35883c334..c67c9294be6 100644 --- a/src/gpu/intel/ocl/gemm_matmul.hpp +++ b/src/gpu/intel/ocl/gemm_matmul.hpp @@ -236,7 +236,7 @@ struct gemm_matmul_t : public gpu_primitive_t { CHECK(map_gemm_zp(DNNL_ARG_WEIGHTS, DNNL_ARG_A, true, orig_dims - reshape_size)); post_ops = tmp_post_ops; - gemm_attr.scales_ = scales; + gemm_attr.scales_ = std::move(scales); a_md = &a_md_reshaped; b_md = &b_md_reshaped; c_md = &c_md_reshaped; diff --git a/src/gpu/intel/ocl/micro_sdpa.cl b/src/gpu/intel/ocl/micro_sdpa.cl index 208185c7ba3..206b28d5617 100644 --- a/src/gpu/intel/ocl/micro_sdpa.cl +++ b/src/gpu/intel/ocl/micro_sdpa.cl @@ -198,9 +198,9 @@ micro_sdpa(const global half *K, const global half *Q, const global half *V, const bool need_sum_barrier = (ugemm_vs_barrier_count == 0); /* Locate K/Q/V/A matrices within batch */ - K += KEY_OFF(b1, b0 % KEY_D1, 0, 0); + K += 
KEY_OFF(b1, b0 / KV_GROUP_SIZE, 0, 0); Q += QRY_OFF(b1, b0, 0, 0); - V += VAL_OFF(b1, b0 % VAL_D1, 0, 0); + V += VAL_OFF(b1, b0 / KV_GROUP_SIZE, 0, 0); A += DST_OFF(b1, b0, 0, 0, 0); msk += MSK_OFF(b1 % MSK_D0, b0 % MSK_D1, 0, 0); diff --git a/src/gpu/intel/ocl/micro_sdpa.cpp b/src/gpu/intel/ocl/micro_sdpa.cpp index 88ba2882563..20b4a7309a6 100644 --- a/src/gpu/intel/ocl/micro_sdpa.cpp +++ b/src/gpu/intel/ocl/micro_sdpa.cpp @@ -300,6 +300,9 @@ status_t micro_sdpa_t::init(impl::engine_t *engine) { def_offsets(msk_off, kernel_ctx, "MSK", ndims); kernel_ctx.define_int("NDIMS", ndims); + auto Q_num_heads_dim = qry_mdw.dims()[1]; + kernel_ctx.define_int("KV_GROUP_SIZE", Q_num_heads_dim / d->kv_head_number); + auto ldq = gemm_desc_t::get_ld(*pd()->qry_md()) * qry_mdw.data_type_size(); auto ldk = gemm_desc_t::get_ld(*pd()->key_md()) * key_mdw.data_type_size(); auto ldv = gemm_desc_t::get_ld(*pd()->val_md()) * val_mdw.data_type_size(); diff --git a/src/gpu/intel/ocl/micro_sdpa.hpp b/src/gpu/intel/ocl/micro_sdpa.hpp index dc06d76c40b..a210ba994c1 100644 --- a/src/gpu/intel/ocl/micro_sdpa.hpp +++ b/src/gpu/intel/ocl/micro_sdpa.hpp @@ -118,7 +118,7 @@ struct micro_sdpa_t : public gpu_primitive_t { private: micro::Package gemm_kq_, gemm_vs_; int sg_size_ = 0; - compute::gpu_arch_t arch_; + compute::gpu_arch_t arch_ = compute::gpu_arch_t::unknown; status_t init_microkernels(impl::engine_t *engine); }; diff --git a/src/gpu/intel/ocl/ocl_math_utils.h b/src/gpu/intel/ocl/ocl_math_utils.h index cebc6f6bfa6..bfbe7d41778 100644 --- a/src/gpu/intel/ocl/ocl_math_utils.h +++ b/src/gpu/intel/ocl/ocl_math_utils.h @@ -68,6 +68,11 @@ ushort16 __builtin_IB_simd_block_read_16_global_h(const __global ushort *); void __builtin_IB_simd_block_write_8_global_l(__global ulong *, ulong8); void __builtin_IB_simd_block_write_16_global_h(__global ushort *, ushort16); +float __attribute__((overloadable)) cvt_e8m0_to_f32(uchar f) { + if (f == (uchar)0xff) return as_float(0xffc00000); + uint bits = f << 23; + return as_float(bits); +} #if MATH_UTILS_DECLARE_HF8 // Emulation functions for f8_e4m3 <-> f16 conversion. 
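The `cvt_e8m0_to_f32` helper added above leans on the f32 bit layout: an e8m0 value is just an 8-bit exponent, so shifting the byte into bits 23..30 yields 2^(e - 127), with 0xff reserved for NaN. A host-side sketch of the same decoding (the helper name here is hypothetical, for illustration only):

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Host-side analogue of cvt_e8m0_to_f32: e8m0 encodes only an exponent,
// so the decoded value is 2^(e - 127); 0xff is reserved for NaN.
static float e8m0_to_f32(uint8_t e) {
    uint32_t bits = (e == 0xffu) ? 0xffc00000u : (uint32_t(e) << 23);
    float f;
    std::memcpy(&f, &bits, sizeof(f)); // bit-cast, like as_float() in OpenCL
    return f;
}

int main() {
    assert(e8m0_to_f32(126) == 0.5f); // 2^(126 - 127)
    assert(e8m0_to_f32(127) == 1.0f); // exponent bias is 127
    assert(e8m0_to_f32(128) == 2.0f);
    assert(std::isnan(e8m0_to_f32(0xff)));
}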
diff --git a/src/gpu/intel/ocl/ocl_stream.cpp b/src/gpu/intel/ocl/ocl_stream.cpp index 3bf87bb8958..83f1389d775 100644 --- a/src/gpu/intel/ocl/ocl_stream.cpp +++ b/src/gpu/intel/ocl/ocl_stream.cpp @@ -130,6 +130,10 @@ status_t ocl_stream_t::fill(const memory_storage_t &dst, uint8_t pattern, this, dst, pattern, size, deps, out_dep, profiler_.get()); } +status_t ocl_stream_t::barrier() { + return impl()->barrier(); +} + } // namespace ocl } // namespace intel } // namespace gpu diff --git a/src/gpu/intel/ocl/ocl_stream.hpp b/src/gpu/intel/ocl/ocl_stream.hpp index 6eb05c18888..a08f9daef81 100644 --- a/src/gpu/intel/ocl/ocl_stream.hpp +++ b/src/gpu/intel/ocl/ocl_stream.hpp @@ -89,6 +89,8 @@ struct ocl_stream_t : public compute::compute_stream_t { status_t fill(const memory_storage_t &dst, uint8_t pattern, size_t size, const xpu::event_t &deps, xpu::event_t &out_dep) override; + status_t barrier() override; + ~ocl_stream_t() override = default; const xpu::ocl::context_t &ocl_ctx() const { return impl()->ocl_ctx(); } diff --git a/src/gpu/intel/ocl/ocl_types.h b/src/gpu/intel/ocl/ocl_types.h index 8496e5ca369..1760ea62944 100644 --- a/src/gpu/intel/ocl/ocl_types.h +++ b/src/gpu/intel/ocl/ocl_types.h @@ -1460,6 +1460,18 @@ #endif #endif +#ifdef DST_SCALES_DATA_T +#if DST_SCALES_DT_F16 +#define DST_SCALES_TO_REF(x) convert_float(x) +#elif DST_SCALES_DT_BF16 +#define DST_SCALES_TO_REF(x) cvt_bf16_to_f32(x) +#elif DST_SCALES_DT_E8M0 +#define DST_SCALES_TO_REF(x) cvt_e8m0_to_f32(x) +#else +#define DST_SCALES_TO_REF(x) (x) +#endif +#endif + #ifdef WEI_SCALES_DATA_T #if WEI_SCALES_DT_F16 #define WEI_SCALES_TO_REF(x) convert_float(x) diff --git a/src/gpu/intel/ocl/ref_matmul.cl b/src/gpu/intel/ocl/ref_matmul.cl index a62e0e77a0e..ad2fe34f06f 100644 --- a/src/gpu/intel/ocl/ref_matmul.cl +++ b/src/gpu/intel/ocl/ref_matmul.cl @@ -42,14 +42,15 @@ __kernel void ref_matmul(__global SRC_DATA_T *A, __global WEI_DATA_T *B, long src_scale_stride_m, long src_scale_group_k, __global WEI_SCALES_DATA_T *wei_scales, long wei_scale_stride_n, long wei_scale_stride_k, long wei_scale_group_k, - __global float *dst_scales, long group_K, long K, long N, long M, - long D0, long D1, long D2, long bia_stride_d3, long bia_stride_d2, - long bia_stride_d1, long bia_stride_d0, long bia_stride_m, - long bia_stride_n, long a_stride_d3, long a_stride_d2, long a_stride_d1, - long a_stride_d0, long a_stride_m, long a_stride_k, long b_stride_d3, - long b_stride_d2, long b_stride_d1, long b_stride_d0, long b_stride_k, - long b_stride_n, long c_stride_d3, long c_stride_d2, long c_stride_d1, - long c_stride_d0, long c_stride_m, long c_stride_n + __global DST_SCALES_DATA_T *dst_scales, long group_K, long K, long N, + long M, long D0, long D1, long D2, long bia_stride_d3, + long bia_stride_d2, long bia_stride_d1, long bia_stride_d0, + long bia_stride_m, long bia_stride_n, long a_stride_d3, + long a_stride_d2, long a_stride_d1, long a_stride_d0, long a_stride_m, + long a_stride_k, long b_stride_d3, long b_stride_d2, long b_stride_d1, + long b_stride_d0, long b_stride_k, long b_stride_n, long c_stride_d3, + long c_stride_d2, long c_stride_d1, long c_stride_d0, long c_stride_m, + long c_stride_n #if WITH_DROPOUT , __global uchar *dropout_mask_buf, __global uint *dropout_seed_buf, @@ -228,7 +229,7 @@ __kernel void ref_matmul(__global SRC_DATA_T *A, __global WEI_DATA_T *B, d1, 1, d0, 1, m, 1, n, 1); #if WITH_DST_SCALES - po_acc /= dst_scales[0]; + po_acc /= DST_SCALES_TO_REF(dst_scales[0]); #endif po_acc += dst_zp; C[dst_off] = TO_DST(po_acc); 
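For context on the dst-scale plumbing above: `def_data_type(..., "DST_SCALES")` emits `-DDST_SCALES_DATA_T=<type> -DDST_SCALES_DT_<DT>` at kernel-build time, and `ocl_types.h` maps `DST_SCALES_TO_REF` onto the matching to-f32 conversion, so the kernels can divide the accumulator by a scale stored as f16, bf16, e8m0, or f32. A rough host-side mirror of that dispatch (names are illustrative; f16 is omitted since the kernels handle it with the built-in `convert_float`):

#include <cstdint>
#include <cstring>

enum class scale_dt_t { f32, bf16, e8m0 };

static float bits_to_f32(uint32_t bits) {
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// Widen one raw dst-scale element to f32 according to its data type,
// as DST_SCALES_TO_REF does before the division by the scale.
static float load_scale_as_f32(const void *p, scale_dt_t dt) {
    switch (dt) {
        case scale_dt_t::f32: {
            float f;
            std::memcpy(&f, p, sizeof(f));
            return f;
        }
        case scale_dt_t::bf16: { // zero-fill the low 16 mantissa bits
            uint16_t u;
            std::memcpy(&u, p, sizeof(u));
            return bits_to_f32(uint32_t(u) << 16);
        }
        case scale_dt_t::e8m0: // exponent-only: 2^(e - 127), 0xff -> NaN
        default: {
            uint8_t e;
            std::memcpy(&e, p, sizeof(e));
            return bits_to_f32(e == 0xffu ? 0xffc00000u : uint32_t(e) << 23);
        }
    }
}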
diff --git a/src/gpu/intel/ocl/ref_matmul.hpp b/src/gpu/intel/ocl/ref_matmul.hpp index ce8ea24b783..7c86287b3d7 100644 --- a/src/gpu/intel/ocl/ref_matmul.hpp +++ b/src/gpu/intel/ocl/ref_matmul.hpp @@ -198,6 +198,9 @@ struct ref_matmul_t : public gpu_primitive_t { def_data_type(kernel_ctx, pd()->attr()->scales_.get(DNNL_ARG_SRC).data_type_, "SRC_SCALES"); + def_data_type(kernel_ctx, + pd()->attr()->scales_.get(DNNL_ARG_DST).data_type_, + "DST_SCALES"); CHECK(create_kernel(engine, &kernel_, "ref_matmul", kernel_ctx)); if (!kernel_) return status::runtime_error; return status::success; diff --git a/src/gpu/intel/ocl/reusable_softmax.cpp b/src/gpu/intel/ocl/reusable_softmax.cpp index 3060c9c75e8..371dc765bf0 100644 --- a/src/gpu/intel/ocl/reusable_softmax.cpp +++ b/src/gpu/intel/ocl/reusable_softmax.cpp @@ -97,7 +97,7 @@ status_t reusable_softmax_fwd_t::pd_t::init_dispatch_default_reusable( auto *compute_engine = utils::downcast(engine); compute::reusable_dispatch_config_t dispatch_config( - compute_engine, dispatch_dim_ids); + compute_engine, std::move(dispatch_dim_ids)); CHECK(dispatch_config.register_buffer(src_buf)); CHECK(dispatch_config.register_buffer(dst_buf)); @@ -163,12 +163,12 @@ status_t reusable_softmax_fwd_t::pd_t::init_dispatch_workgroup_per_reduction( compute::named_buffer_t dst_buf("DST", src_buf); // dispatch: all dims except reduction dimension plus workers dimension - std::vector dispatch_dims = dims_ids; + std::vector dispatch_dims = std::move(dims_ids); dispatch_dims[softmax_axis] = softmax_dims_t::workers; auto *compute_engine = utils::downcast(engine); compute::reusable_dispatch_config_t dispatch_config( - compute_engine, dispatch_dims); + compute_engine, std::move(dispatch_dims)); CHECK(dispatch_config.register_buffer(src_buf)); CHECK(dispatch_config.register_buffer(dst_buf)); CHECK(dispatch_config.register_buffer(ori_buf)); diff --git a/src/gpu/intel/ocl/reusable_vectorized_lnorm.cpp b/src/gpu/intel/ocl/reusable_vectorized_lnorm.cpp index 1cff4c4f0e4..78a5bf6c07e 100644 --- a/src/gpu/intel/ocl/reusable_vectorized_lnorm.cpp +++ b/src/gpu/intel/ocl/reusable_vectorized_lnorm.cpp @@ -166,7 +166,8 @@ static status_t init_conf_common(const layer_normalization_pd_t *pd, auto lws_strategy = single_subgroup_lws_strategy_t( compute_engine, gpu_attr, conf->sg_size); - compute::reusable_dispatch_config_t dispatch_config(compute_engine, dims); + compute::reusable_dispatch_config_t dispatch_config( + compute_engine, std::move(dims)); CHECK(dispatch_config.register_buffer(input_buf)); CHECK(dispatch_config.register_buffer(output_buf)); CHECK(dispatch_config.register_buffer(stat_buf)); diff --git a/src/gpu/intel/primitive_conf.hpp b/src/gpu/intel/primitive_conf.hpp index 16aee7f6b88..53ce391c8ff 100644 --- a/src/gpu/intel/primitive_conf.hpp +++ b/src/gpu/intel/primitive_conf.hpp @@ -1008,6 +1008,7 @@ inline void def_data_type(compute::kernel_ctx_t &kernel_ctx, data_type_t dt, const char *bf16_name = with_punning ? "ushort" : "bf16"; const char *bf8_name = with_punning ? "uchar" : "f8_e5m2"; const char *hf8_name = with_punning ? "uchar" : "f8_e4m3"; + const char *e8m0_name = with_punning ? "uchar" : "e8m0"; const char *u4_name = with_punning ? "uchar" : "u4"; const char *s4_name = with_punning ? 
"uchar" : "s4"; @@ -1048,6 +1049,10 @@ inline void def_data_type(compute::kernel_ctx_t &kernel_ctx, data_type_t dt, kernel_ctx.add_option(utils::format( "-D%s_DATA_T=%s -D%s_DT_BF8", str, bf8_name, str)); break; + case data_type::e8m0: + kernel_ctx.add_option(utils::format( + "-D%s_DATA_T=%s -D%s_DT_E8M0", str, e8m0_name, str)); + break; case data_type::s4: kernel_ctx.add_option(utils::format( "-D%s_DATA_T=%s -D%s_DT_S4", str, s4_name, str)); diff --git a/src/gpu/intel/sycl/stream.cpp b/src/gpu/intel/sycl/stream.cpp index f100ad6e142..2171df4da15 100644 --- a/src/gpu/intel/sycl/stream.cpp +++ b/src/gpu/intel/sycl/stream.cpp @@ -96,6 +96,84 @@ void stream_t::after_exec_hook() { if (is_profiling_enabled()) profiler_->stop_profiling(); } +// The following code needs sycl::queue::ext_oneapi_get_graph(), but it may +// not be defined. Some SFINAE is needed to avoid compile errors in this case. +namespace syclex = ::sycl::ext::oneapi::experimental; +template +static auto get_graph_internal(const Q &q, bool &success, int) + -> decltype(q.ext_oneapi_get_graph()) { + success = true; + return q.ext_oneapi_get_graph(); +} + +template +static syclex::command_graph +get_graph_internal(const Q &q, bool &success, long) { + success = false; + return syclex::command_graph( + q.get_context(), q.get_device()); +} + +static syclex::command_graph get_graph( + const ::sycl::queue *q, bool &success) { + return get_graph_internal(*q, success, 0); +} + +bool stream_t::recording() const { + return impl()->queue()->ext_oneapi_get_state() + == syclex::queue_state::recording; +} + +stream_t::weak_graph_t stream_t::get_current_graph_weak() const { + bool success; + stream_t::weak_graph_t result = get_graph(impl()->queue(), success); + if (!success) result.reset(); + return result; +} + +status_t stream_t::enter_immediate_mode() { + std::lock_guard lock(immediate_mode_mutex_); + if (!immediate_mode_level_++) pause_recording(); + return status::success; +} + +status_t stream_t::exit_immediate_mode() { + std::lock_guard lock(immediate_mode_mutex_); + if (immediate_mode_level_ > 0) { + if (!--immediate_mode_level_) resume_recording(); + } else { + assert(!"exit_immediate_mode called without enter"); + return status::runtime_error; + } + return status::success; +} + +status_t stream_t::pause_recording() { + using graph_t = syclex::command_graph; + if (recording()) { + bool success; + assert(!paused_graph_); + paused_graph_.reset(new graph_t(get_graph(impl()->queue(), success))); + if (!success) return status::runtime_error; + paused_graph_->end_recording(); + auto &cur_dep = xpu::sycl::event_t::from(ctx().get_deps()); + paused_dep_ = xpu::sycl::event_t {}; + std::swap(paused_dep_, cur_dep); + } + return status::success; +} + +status_t stream_t::resume_recording() { + if (paused_graph_) { + paused_graph_->begin_recording(*impl()->queue()); + paused_graph_.reset(); + auto &cur_dep = xpu::sycl::event_t::from(ctx().get_deps()); + std::swap(paused_dep_, cur_dep); + paused_dep_ = xpu::sycl::event_t {}; + } + return status::success; +} + } // namespace sycl } // namespace intel } // namespace gpu diff --git a/src/gpu/intel/sycl/stream.hpp b/src/gpu/intel/sycl/stream.hpp index 92e4630baa0..1d7c7857899 100644 --- a/src/gpu/intel/sycl/stream.hpp +++ b/src/gpu/intel/sycl/stream.hpp @@ -100,6 +100,8 @@ struct stream_t : public gpu::intel::compute::compute_stream_t { return impl()->fill(dst, pattern, size, deps, out_dep, profiler_.get()); } + status_t barrier() override { return impl()->barrier(); } + const xpu::sycl::context_t 
&sycl_ctx() const { return impl()->sycl_ctx(); } xpu::sycl::context_t &sycl_ctx() { return impl()->sycl_ctx(); } @@ -114,6 +116,15 @@ struct stream_t : public gpu::intel::compute::compute_stream_t { return impl()->register_deps(cgh); } + bool recording() const; + using weak_graph_t = ::sycl::ext::oneapi::weak_object< + ::sycl::ext::oneapi::experimental::command_graph<::sycl::ext:: + oneapi::experimental::graph_state::modifiable>>; + weak_graph_t get_current_graph_weak() const; + + status_t enter_immediate_mode() override; + status_t exit_immediate_mode() override; + protected: xpu::sycl::stream_impl_t *impl() const { return (xpu::sycl::stream_impl_t *)impl::stream_t::impl_.get(); @@ -124,6 +135,16 @@ struct stream_t : public gpu::intel::compute::compute_stream_t { private: status_t init(); + + status_t pause_recording(); + status_t resume_recording(); + + std::mutex immediate_mode_mutex_; + int immediate_mode_level_ = 0; + std::unique_ptr<::sycl::ext::oneapi::experimental::command_graph< + ::sycl::ext::oneapi::experimental::graph_state::modifiable>> + paused_graph_; + xpu::sycl::event_t paused_dep_; }; } // namespace sycl diff --git a/src/gpu/intel/sycl/sycl_interop_gpu_kernel.cpp b/src/gpu/intel/sycl/sycl_interop_gpu_kernel.cpp index 119469ebb15..f5ff7ed800b 100644 --- a/src/gpu/intel/sycl/sycl_interop_gpu_kernel.cpp +++ b/src/gpu/intel/sycl/sycl_interop_gpu_kernel.cpp @@ -155,7 +155,7 @@ status_t sycl_interop_gpu_kernel_t::parallel_for(impl::stream_t &stream, } else if (arg.is_local()) { auto acc = xpu::sycl::compat::local_accessor( ::sycl::range<1>(arg.size()), cgh); - cgh.set_arg((int)i, acc); + cgh.set_arg((int)i, std::move(acc)); } else { set_scalar_arg(cgh, (int)i, arg.scalar_type(), arg.value()); } @@ -179,7 +179,7 @@ status_t sycl_interop_gpu_kernel_t::parallel_for(impl::stream_t &stream, gpu_stream->profiler().register_event(std::move(sycl_event)); } - xpu::sycl::event_t::from(out_dep).events = {event}; + xpu::sycl::event_t::from(out_dep).events = {std::move(event)}; return status::success; } diff --git a/src/gpu/intel/sycl/utils.cpp b/src/gpu/intel/sycl/utils.cpp index fcf309bf239..626deb8e477 100644 --- a/src/gpu/intel/sycl/utils.cpp +++ b/src/gpu/intel/sycl/utils.cpp @@ -106,20 +106,20 @@ status_t sycl_dev2ocl_dev(cl_device_id *ocl_dev, const ::sycl::device &dev) { std::vector ocl_devices; std::vector> ocl_sub_devices; - auto st = xpu::ocl::get_devices( + auto status = xpu::ocl::get_devices( &ocl_devices, &ocl_sub_devices, CL_DEVICE_TYPE_GPU); - assert(st == status::success); - MAYBE_UNUSED(st); + assert(status == status::success); + MAYBE_UNUSED(status); const auto register_ocl_dev = [&uuid2ocl_dev_tmp]( const xpu::ocl::wrapper_t &d) { xpu::device_uuid_t ocl_dev_uuid; - auto st = xpu::ocl::get_device_uuid(ocl_dev_uuid, d); - assert(st == status::success); - st = uuid2ocl_dev_tmp.add(ocl_dev_uuid, d); - assert(st == status::success); - MAYBE_UNUSED(st); + auto status = xpu::ocl::get_device_uuid(ocl_dev_uuid, d); + assert(status == status::success); + status = uuid2ocl_dev_tmp.add(std::move(ocl_dev_uuid), d); + assert(status == status::success); + MAYBE_UNUSED(status); }; for (cl_device_id d : ocl_devices) { diff --git a/src/graph/backend/dnnl/dnnl_backend.cpp b/src/graph/backend/dnnl/dnnl_backend.cpp index f4bfeba3053..e5ab0ea51e7 100644 --- a/src/graph/backend/dnnl/dnnl_backend.cpp +++ b/src/graph/backend/dnnl/dnnl_backend.cpp @@ -108,9 +108,10 @@ graph::utils::optional_t dnnl_backend_t::get_mem_desc( } // namespace dnnl_impl // This function should be called by 
backend_registry_t -void register_dnnl_backend() { - backend_registry_t::get_singleton().register_backend( +status_t register_dnnl_backend() { + const status_t ret = backend_registry_t::get_singleton().register_backend( &dnnl_impl::dnnl_backend_t::get_singleton()); + return ret; } } // namespace graph diff --git a/src/graph/backend/dnnl/kernels/matmul.cpp b/src/graph/backend/dnnl/kernels/matmul.cpp index 3de0c661b08..a6ce6011cc5 100644 --- a/src/graph/backend/dnnl/kernels/matmul.cpp +++ b/src/graph/backend/dnnl/kernels/matmul.cpp @@ -40,8 +40,8 @@ status_t matmul_t::compile_impl(const dnnl_partition_impl_t *part, g_alloc_ = reinterpret_cast(g_engine->get_allocator()); - subgraph_ = std::make_shared(part->get_ops(), p_engine_, - part->get_fpmath_mode(), part->get_use_blocked_layout(), true); + subgraph_ = std::make_shared( + part->get_ops(), p_engine_, part->get_fpmath_mode(), true, true); BACKEND_DNNL_CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs)); subgraph_visualizer_t vis(part->id(), [this](const value_t *val) { diff --git a/src/graph/backend/dnnl/kernels/sdp_decomp.cpp b/src/graph/backend/dnnl/kernels/sdp_decomp.cpp index 96787873331..225c9925f83 100644 --- a/src/graph/backend/dnnl/kernels/sdp_decomp.cpp +++ b/src/graph/backend/dnnl/kernels/sdp_decomp.cpp @@ -46,8 +46,8 @@ status_t sdp_decomp_kernel_t::compile_impl( = reinterpret_cast(g_engine->get_allocator()); // get subgraph from the deep copied partition - subgraph_ = std::make_shared(part->get_ops(), p_engine_, - part->get_fpmath_mode(), part->get_use_blocked_layout(), true); + subgraph_ = std::make_shared( + part->get_ops(), p_engine_, part->get_fpmath_mode(), false, true); BACKEND_DNNL_CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs)); // Check if it's supported by decomposition kernel @@ -60,6 +60,7 @@ status_t sdp_decomp_kernel_t::compile_impl( pass_pipeline_t pipeline = pass_pipeline_t(vis); pass_pipeline_t select_pipeline = pass_pipeline_t(vis); BACKEND_DNNL_ADD_PASS(pipeline, lower_down); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_reshape_for_gqa); // Fusion and canonicalization passes begin if (quantized) { BACKEND_DNNL_ADD_PASS(pipeline, lift_up_typecast); @@ -291,7 +292,7 @@ status_t sdp_decomp_kernel_t::execute_impl( auto mask_strides = ltw(mask_input.get_logical_tensor()).vstrides(); sub_mm1_post_add_tid.set_data_handle( static_cast(mask_input.get_data_handle()) - + bo * mask_strides[1] + + bo * mask_strides[0] * get_mem_dt_size(sub_mm1_post_add_tid)); } if (sdp_cfg_.has_select) { diff --git a/src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp b/src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp index 0a343276b44..4a8250f61fb 100644 --- a/src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp +++ b/src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp @@ -27,7 +27,7 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr &sg, // to record the input offset in a certain order of ops. 
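Before the decomposition checks below, a brief worked example of the grouping arithmetic the GQA changes in this patch rely on: with `num_head_q` query heads sharing `num_head_kv` key/value heads, consecutive query heads form groups of `num_head_q / num_head_kv`, which is exactly what the `b0 / KV_GROUP_SIZE` indexing in micro_sdpa computes. A tiny sketch (hypothetical helper name):

#include <cassert>

// Which KV head serves a given query head, mirroring micro_sdpa.cl:
// kv_head = q_head / (num_head_q / num_head_kv).
static int kv_head_for(int q_head, int num_head_q, int num_head_kv) {
    const int kv_group_size = num_head_q / num_head_kv; // KV_GROUP_SIZE
    return q_head / kv_group_size;
}

int main() {
    // 16 query heads grouped over 2 KV heads: q heads 0..7 share KV head 0,
    // q heads 8..15 share KV head 1. (The previous b0 % D1 indexing would
    // have interleaved them instead.)
    assert(kv_head_for(3, 16, 2) == 0);
    assert(kv_head_for(11, 16, 2) == 1);
}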
auto op_status = record_input_offset(sg, inputs); if (op_status != status::success) return false; - memory::dims src1_user_dims = ltw(inputs[graph_inport[0]]).vdims(); + dims src1_user_dims = ltw(inputs[graph_inport[0]]).vdims(); if (src1_user_dims.size() != 4) return false; // Initialize SDP input dimension according to the src of mm1 @@ -36,6 +36,9 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr &sg, seq_len_q = src1_user_dims[2]; size_per_head = src1_user_dims[3]; + dims wei1_user_dims = ltw(inputs[graph_inport[1]]).vdims(); + num_head_kv = wei1_user_dims[1]; + #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_OMP // RATIO is an empirical value used to determine the numerical relationship // between batch_size, num_head_q and thread number to determine whether to use @@ -67,10 +70,8 @@ impl::status_t sdp_decomp_config_t::construct_params( // Update SDPA input params. Sequence length for query and key/value are // NOT always same. - memory::dim seq_len_kv; const auto <_wei = sdp_op[1]->get_input_value(1)->get_logical_tensor(); const ltw ltw_wei(lt_wei); - num_head_kv = ltw_wei.vdims()[1]; seq_len_kv = ltw_wei.vdims()[3]; // Acquire the data type from input param for later primitive creation. @@ -102,7 +103,7 @@ impl::status_t sdp_decomp_config_t::construct_params( sub_reorder0_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); // per-head: reorder src1 to dense, for first matmul - memory::dims sub_src1_dims = {1, 1, seq_len_q, size_per_head}; + dims sub_src1_dims = {1, 1, seq_len_q, size_per_head}; src1_strides = ltw(inputs[graph_inport[0]]).vstrides(); sub_src1_md = memory::desc(sub_src1_dims, dt_src_user, {1, 1, src1_strides[2], src1_strides[3]}); @@ -117,7 +118,7 @@ impl::status_t sdp_decomp_config_t::construct_params( // create reorder1 primitive attr dnnl::primitive_attr sub_reorder1_attr = make_primitive_attr(sdp_op[0], mgr); - memory::dims sub_wei1_dims = {1, 1, size_per_head, seq_len_kv}; + dims sub_wei1_dims = {1, 1, size_per_head, seq_len_kv}; auto wei_md = make_dnnl_memory_desc( sdp_op[1]->get_input_value(1)->get_logical_tensor()); wei1_strides = wei_md.get_strides(); @@ -132,9 +133,9 @@ impl::status_t sdp_decomp_config_t::construct_params( // first matmul // create first matmul primitive attr dnnl::primitive_attr sub_matmul1_attr = make_primitive_attr(sdp_op[1], mgr); - memory::dims sub_mm1_src_dims = {1, 1, seq_len_q, size_per_head}; - memory::dims sub_mm1_wei_dims = {1, 1, size_per_head, seq_len_kv}; - memory::dims sub_mm1_dst_dims = {1, 1, seq_len_q, seq_len_kv}; + dims sub_mm1_src_dims = {1, 1, seq_len_q, size_per_head}; + dims sub_mm1_wei_dims = {1, 1, size_per_head, seq_len_kv}; + dims sub_mm1_dst_dims = {1, 1, seq_len_q, seq_len_kv}; sub_mm1_src_md = memory::desc(sub_mm1_src_dims, dt_src_user, tag::abcd); sub_mm1_wei_md = memory::desc(sub_mm1_wei_dims, dt_wei, tag::abdc); @@ -149,8 +150,7 @@ impl::status_t sdp_decomp_config_t::construct_params( auto post_shape = ori_desc.dims; auto post_stride = ori_desc.format_desc.blocking.strides; auto post_dt = static_cast(ori_desc.data_type); - memory::dims post_stride_dims - = memory::dims(post_stride, post_stride + ori_desc.ndims); + dims post_stride_dims = dims(post_stride, post_stride + ori_desc.ndims); auto new_sub_md = memory::desc({1, 1, post_shape[2], post_shape[3]}, post_dt, post_stride_dims); sub_mm1_post_md.emplace_back(new_sub_md); @@ -175,7 +175,7 @@ impl::status_t sdp_decomp_config_t::construct_params( // create reorder2 primitive attr dnnl::primitive_attr sub_reorder2_attr = make_primitive_attr(sdp_op[3], 
mgr); - memory::dims sub_wei2_dims = {1, 1, seq_len_kv, size_per_head}; + dims sub_wei2_dims = {1, 1, seq_len_kv, size_per_head}; wei2_strides = ltw(inputs[graph_inport[4]]).vstrides(); sub_wei2_user_md = memory::desc(sub_wei2_dims, dt_wei_user, {1, 1, wei2_strides[2], wei2_strides[3]}); @@ -188,9 +188,9 @@ impl::status_t sdp_decomp_config_t::construct_params( // second matmul // create second matmul primitive attr dnnl::primitive_attr sub_matmul2_attr = make_primitive_attr(sdp_op[4], mgr); - memory::dims sub_mm2_src_dims = {1, 1, seq_len_q, seq_len_kv}; - memory::dims sub_mm2_wei_dims = {1, 1, seq_len_kv, size_per_head}; - memory::dims sub_mm2_dst_dims = {1, 1, seq_len_q, size_per_head}; + dims sub_mm2_src_dims = {1, 1, seq_len_q, seq_len_kv}; + dims sub_mm2_wei_dims = {1, 1, seq_len_kv, size_per_head}; + dims sub_mm2_dst_dims = {1, 1, seq_len_q, size_per_head}; auto sub_mm2_src_md = memory::desc(sub_mm2_src_dims, dt_src_user, tag::abcd); sub_mm2_wei_md = memory::desc(sub_mm2_wei_dims, dt_wei, tag::abcd); @@ -202,7 +202,7 @@ impl::status_t sdp_decomp_config_t::construct_params( // per-head: reorder dst2 from dense to strided primitive_attr sub_reorder3_attr; sub_reorder3_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); - memory::dims sub_dst_dims = {1, 1, seq_len_q, size_per_head}; + dims sub_dst_dims = {1, 1, seq_len_q, size_per_head}; auto out_lt = sdp_op[4]->get_output_value(0)->get_logical_tensor(); dst_strides = ltw(out_lt).vstrides(); sub_dst_md = memory::desc(sub_dst_dims, dt_src_user, tag::abcd); @@ -401,7 +401,10 @@ impl::status_t sdp_decomp_config_t::record_input_offset( auto find_graph_inport = [&](std::shared_ptr val) { // for quantized matmul, it has producer such as add_zp,sub_zp,mul_scale. if (val->get_consumers()[0].get_op().get_kind() - == graph::op_kind::MatMul) { + == graph::op_kind::MatMul + || (val->has_producer() + && val->get_producer().get_kind() + == graph::op_kind::StaticReshape)) { while (val->has_producer()) { val = val->get_producer().get_input_value(0); } @@ -444,11 +447,11 @@ impl::status_t sdp_decomp_config_t::record_input_offset( if (post_op) { // find mask if (post_op->get_kind() == graph::op_kind::Add) { - add = post_op; + add = std::move(post_op); has_attention_mask = true; } else if (post_op->get_kind() == graph::op_kind::Select) { // mm1 -> scale -> select -> ... 
- select = post_op; + select = std::move(post_op); has_select = true; } } diff --git a/src/graph/backend/dnnl/kernels/sdp_decomp_config.hpp b/src/graph/backend/dnnl/kernels/sdp_decomp_config.hpp index 866e4fe1cf3..ae16348da88 100644 --- a/src/graph/backend/dnnl/kernels/sdp_decomp_config.hpp +++ b/src/graph/backend/dnnl/kernels/sdp_decomp_config.hpp @@ -74,10 +74,11 @@ struct sdp_decomp_config_t { sdp_decomp_config_t() = default; // SDP input dimension - memory::dim batch_size, num_head_q, num_head_kv, seq_len_q, size_per_head; + dim_t batch_size, num_head_q, num_head_kv, seq_len_q, seq_len_kv, + size_per_head; // SDP input and output strides - memory::dims src1_strides, wei1_strides, wei2_strides, dst_strides, + dims src1_strides, wei1_strides, wei2_strides, dst_strides, post_add_strides; // Thread nums during the workflow diff --git a/src/graph/backend/dnnl/kernels/sdp_primitive.cpp b/src/graph/backend/dnnl/kernels/sdp_primitive.cpp index 170e6fab98e..50e977994e1 100644 --- a/src/graph/backend/dnnl/kernels/sdp_primitive.cpp +++ b/src/graph/backend/dnnl/kernels/sdp_primitive.cpp @@ -53,12 +53,15 @@ status_t sdp_primitive_kernel_t::compile_impl(const dnnl_partition_impl_t *part, p_engine_, part->get_fpmath_mode(), false, true); CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs)); + CHECK(cfg_.initial_check(subgraph_, inputs)); + subgraph_visualizer_t vis(part->id(), [this](const value_t *val) { return this->memory_planner_.get_memory_info(val); }); pass_pipeline_t pipeline = pass_pipeline_t(vis); BACKEND_DNNL_ADD_PASS(pipeline, lower_down); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_reshape_for_gqa); BACKEND_DNNL_ADD_PASS(pipeline, binary_canonicalization); BACKEND_DNNL_ADD_PASS(pipeline, insert_permute_for_matmul); @@ -100,14 +103,6 @@ status_t sdp_primitive_kernel_t::compile_impl(const dnnl_partition_impl_t *part, CHECK(modify_subgraph()); CHECK(cfg_.init(subgraph_, p_engine_, inputs, outputs)); - // Successfully created the primitive. Rerun the passes again, modifying - // the original ops. - subgraph_ = std::make_shared( - part->get_ops(), p_engine_, part->get_fpmath_mode(), false, true); - CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs)); - CHECK(modify_subgraph()); - CHECK(cfg_.locate_io(subgraph_, inputs, outputs)); - return status::success; } diff --git a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp index dfc1480b908..9dba68fa14c 100644 --- a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp +++ b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp @@ -91,6 +91,12 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr &sg, q_ = mm1->get_input_value(0); k_ = mm1->get_input_value(1); v_ = mm2->get_input_value(1); + + auto k_follow = follow_back(k_); + for (auto &t : inputs) + if (k_follow->get_logical_tensor().id == t.id) { + kv_head_number_ = t.dims[1]; + } dst_ = (final_op->get_kind() == op_kind::dnnl_transpose) ? 
final_op->get_input_value(0)
                    : final_op->get_output_value(
@@ -110,6 +116,81 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
     return status::success;
 }
 
+status_t sdp_primitive_config_t::initial_check(
+        const std::shared_ptr<subgraph_t> &sg,
+        const std::vector<logical_tensor_t> &inputs) {
+    // At least 3 inputs: Q, K, V
+    if (inputs.size() < 3) return status::invalid_arguments;
+
+    // step1 (pattern check): SDPA variants with select as the mask are not
+    // supported. We already have a pattern matcher to ensure that the SDPA
+    // patterns dispatched here are known ones, and we have a quantization
+    // check in the SDPA base kernel, so here we only check specific variants
+    // based on the support matrix.
+    const std::unordered_set<graph::op_kind_t> mm1_post_op_kind
+            = {graph::op_kind::Divide, graph::op_kind::Multiply,
+                    graph::op_kind::Add, graph::op_kind::Select,
+                    graph::op_kind::SoftMax};
+    op_ptr mm1 = nullptr, mm2 = nullptr;
+    for (const auto &cur_op : sg->get_ops()) {
+        if (cur_op->get_kind() != graph::op_kind::MatMul) continue;
+        auto post_op = get_post_op(cur_op);
+        if (post_op && mm1_post_op_kind.count(post_op->get_kind())) {
+            mm1 = cur_op;
+            // Select between mm1 and scale (optional) is not supported.
+            // GPT-J:[mm1] --> [select] --> [scale]* --> [mask]* --> ...
+            if (post_op->get_kind() == graph::op_kind::Select) {
+                return status::unimplemented;
+            }
+            // scale
+            if (post_op->get_kind() == graph::op_kind::Divide
+                    || post_op->get_kind() == graph::op_kind::Multiply) {
+                // Scale exists, update post_op and traverse to next op
+                post_op = get_post_op(post_op);
+            }
+            // mask
+            if (post_op->get_kind() == graph::op_kind::Add) {
+                // Mask exists, update post_op and traverse to next op
+                post_op = get_post_op(post_op);
+            }
+
+            // Select after scale (optional) and mask (optional) is not
+            // supported.
+            // Distill-Bert:[mm1] --> [scale]* --> [mask]* --> [select] --> ...
+            if (post_op->get_kind() == graph::op_kind::Select) {
+                return status::unimplemented;
+            }
+        } else {
+            mm2 = cur_op;
+        }
+    }
+
+    // step2 (data type check): only fp16 is supported for now.
+    auto in_lt = inputs[0];
+    if (in_lt.data_type != dnnl_data_type_t::dnnl_f16)
+        return status::unimplemented;
+
+    auto find_graph_inport = [&inputs](const std::shared_ptr<value_t> &val) {
+        for (int i = 0; i < (int)inputs.size(); i++) {
+            if (val->get_logical_tensor().id == inputs[i].id) { return i; }
+        }
+        // If the corresponding input is not found, return an invalid value
+        return -1;
+    };
+
+    // step3 (dims check): only 4-D tensors are supported for now.
+    int q_id = find_graph_inport(mm1->get_input_value(0));
+    int k_id = find_graph_inport(mm1->get_input_value(1));
+    int v_id = find_graph_inport(mm2->get_input_value(1));
+
+    bool ok = true;
+    ok = ok && (q_id != -1) && (k_id != -1) && (v_id != -1);
+    if (!ok) return status::unimplemented;
+    ok = ok && ltw(inputs[q_id]).vdims().size() == 4
+            && ltw(inputs[k_id]).vdims().size() == 4
+            && ltw(inputs[v_id]).vdims().size() == 4;
+
+    return ok ? status::success : status::unimplemented;
+}
+
 status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,
         const dnnl::engine &p_engine,
         const std::vector<logical_tensor_t> &inputs,
@@ -138,7 +219,7 @@ status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,
 
     CHECK(create_sdpa_pd(sdpa_pd_, p_engine.get(), md_q.get(), md_k.get(),
             md_v.get(), md_dst.get(), md_mask.get(), scale_dt, invert_scale_,
-            attr.get()));
+            attr.get(), kv_head_number_));
 
     auto status = sdpa_pd_->create_primitive(sdpa_prim_, p_engine.get());
diff --git a/src/graph/backend/dnnl/kernels/sdp_primitive_config.hpp b/src/graph/backend/dnnl/kernels/sdp_primitive_config.hpp
index dae05c3758e..3a0c2ffaa0e 100644
--- a/src/graph/backend/dnnl/kernels/sdp_primitive_config.hpp
+++ b/src/graph/backend/dnnl/kernels/sdp_primitive_config.hpp
@@ -37,6 +37,7 @@ namespace impl {
 namespace graph {
 namespace dnnl_impl {
 
 using op_ptr = std::shared_ptr<op_t>;
+using ltw = logical_tensor_wrapper_t;
 
 struct sdp_primitive_config_t {
 public:
@@ -49,6 +50,7 @@ struct sdp_primitive_config_t {
     std::shared_ptr<value_t> scale_ = nullptr;
     std::shared_ptr<value_t> attn_mask_ = nullptr;
     bool invert_scale_ = false;
+    dim_t kv_head_number_;
 
     // SDP pd and primitive.
     std::shared_ptr sdpa_pd_;
@@ -62,6 +64,14 @@ struct sdp_primitive_config_t {
             const std::vector<logical_tensor_t> &inputs,
             const std::vector<logical_tensor_t> &outputs);
 
+    // Checks whether the given SDP configuration is supported by the current
+    // micro-kernel implementation. Current limitations:
+    // 1. only a limited set of patterns is supported; variants with a select
+    //    op are not
+    // 2. only the fp16 data type is supported
+    // 3. only 4-D tensors are supported
+    status_t initial_check(const std::shared_ptr<subgraph_t> &sg,
+            const std::vector<logical_tensor_t> &inputs);
+
     // Initialize parameters and primitive.
     status_t init(std::shared_ptr<subgraph_t> &sg, const dnnl::engine &p_engine,
             const std::vector<logical_tensor_t> &inputs,
diff --git a/src/graph/backend/dnnl/kernels/softmax.cpp b/src/graph/backend/dnnl/kernels/softmax.cpp
index 64fa5a2ef9d..fca5ea99ef8 100644
--- a/src/graph/backend/dnnl/kernels/softmax.cpp
+++ b/src/graph/backend/dnnl/kernels/softmax.cpp
@@ -49,11 +49,16 @@ status_t softmax_fwd_t::compile_impl(const dnnl_partition_impl_t *part,
     pass_pipeline_t pipeline(vis);
 
     BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
-    BACKEND_DNNL_ADD_PASS(pipeline, fuse_post_ops);
     BACKEND_DNNL_ADD_PASS(pipeline, fuse_post_typecast_to_predecessor);
     BACKEND_DNNL_ADD_PASS(pipeline, remove_quant_data_with_no_effect);
+    BACKEND_DNNL_ADD_PASS(pipeline, replace_quant_data_with_binary_post_op);
+    BACKEND_DNNL_ADD_PASS(pipeline, binary_canonicalization);
+    BACKEND_DNNL_ADD_PASS(pipeline, binary_broadcast_swap);
+    BACKEND_DNNL_ADD_PASS(pipeline, fuse_post_ops);
     BACKEND_DNNL_ADD_PASS(pipeline, convert_to_runtime_dst_scales);
     BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_scales);
+    BACKEND_DNNL_ADD_PASS(pipeline, infer_shape);
+
     pipeline.reset_visualize_arg(true, false);
 
     if (enabled_constant_cache()) {
diff --git a/src/graph/backend/dnnl/op_executable.cpp b/src/graph/backend/dnnl/op_executable.cpp
index 106aba7f42e..73fc14f39f4 100644
--- a/src/graph/backend/dnnl/op_executable.cpp
+++ b/src/graph/backend/dnnl/op_executable.cpp
@@ -328,9 +328,7 @@ matmul_executable_t::desc_t matmul_executable_t::create_desc(
                 pd_cache.at(op.get()));
         return {pd, true};
     }
-    bool can_use_blocked_layout = true;
-    if (p_engine.get_kind() == dnnl::engine::kind::gpu)
-        can_use_blocked_layout = mgr.get_use_blocked_layout();
+    const bool can_use_blocked_layout = mgr.get_use_blocked_layout();
 
     dnnl::primitive_attr prm_attr;
     if (op->has_attr(op_attr::fusion_info_key) &&
op->get_attr(op_attr::fusion_info_key) != -1) { @@ -377,12 +375,11 @@ matmul_executable_t::desc_t matmul_executable_t::create_desc( op->get_input_value(1)->get_logical_tensor()) .is_constant() && is_constant_cache_enabled(p_engine); - const bool use_strided_wei = !const_weight - && (wei.get_ndims() == 4 - && (is_format(wei, dnnl::memory::format_tag::adbc) - || is_format(wei, dnnl::memory::format_tag::abdc) - || is_format(wei, dnnl::memory::format_tag::acbd))); - if (can_use_blocked_layout && !use_strided_wei) { + const bool use_strided_wei = wei.get_ndims() == 4 + && (is_format(wei, dnnl::memory::format_tag::adbc) + || is_format(wei, dnnl::memory::format_tag::abdc) + || is_format(wei, dnnl::memory::format_tag::acbd)); + if (const_weight || (can_use_blocked_layout && !use_strided_wei)) { wei = to_format_any(wei); } auto dst = make_dnnl_memory_desc( diff --git a/src/graph/backend/dnnl/op_executable.hpp b/src/graph/backend/dnnl/op_executable.hpp index 9b9b0fb920f..4f4ef1847ae 100644 --- a/src/graph/backend/dnnl/op_executable.hpp +++ b/src/graph/backend/dnnl/op_executable.hpp @@ -223,12 +223,15 @@ struct memory_reparser_t : public dummy_impl_t { void execute(const stream &stream, const std::unordered_map &args) const override { - if (args.find(DNNL_ARG_FROM)->second.get_data_handle() - == args.find(DNNL_ARG_TO)->second.get_data_handle()) + auto from = args.find(DNNL_ARG_FROM); + auto to = args.find(DNNL_ARG_TO); + if (from == args.end() || to == args.end()) return; + + if (from->second.get_data_handle() == to->second.get_data_handle()) dummy_impl_t::execute(stream, args); else { - const memory &dst_mem = args.find(DNNL_ARG_TO)->second; - const memory &src_mem = args.find(DNNL_ARG_FROM)->second; + const memory &dst_mem = to->second; + const memory &src_mem = from->second; const memory temp_mem = make_dnnl_memory(dst_mem.get_desc(), src_mem.get_engine(), src_mem.get_data_handle()); dnnl::reorder(temp_mem, dst_mem) @@ -241,12 +244,15 @@ struct memory_reparser_t : public dummy_impl_t { ::sycl::event execute_sycl(const stream &stream, const std::unordered_map &args, const std::vector<::sycl::event> &deps = {}) const override { - if (args.find(DNNL_ARG_FROM)->second.get_data_handle() - == args.find(DNNL_ARG_TO)->second.get_data_handle()) + auto from = args.find(DNNL_ARG_FROM); + auto to = args.find(DNNL_ARG_TO); + if (from == args.end() || to == args.end()) return {}; + + if (from->second.get_data_handle() == to->second.get_data_handle()) return dummy_impl_t::execute_sycl(stream, args, deps); else { - const memory &src_mem = args.find(DNNL_ARG_FROM)->second; - const memory &dst_mem = args.find(DNNL_ARG_TO)->second; + const memory &src_mem = from->second; + const memory &dst_mem = to->second; auto sycl_queue = dnnl::sycl_interop::get_queue(stream); auto e = sycl_queue.memcpy(dst_mem.get_data_handle(), src_mem.get_data_handle(), dst_mem.get_desc().get_size()); @@ -259,12 +265,15 @@ struct memory_reparser_t : public dummy_impl_t { cl_event execute_ocl(const stream &stream, const std::unordered_map &args, const std::vector &deps = {}) const override { - if (args.find(DNNL_ARG_FROM)->second.get_data_handle() - == args.find(DNNL_ARG_TO)->second.get_data_handle()) + auto from = args.find(DNNL_ARG_FROM); + auto to = args.find(DNNL_ARG_TO); + if (from == args.end() || to == args.end()) return {}; + + if (from->second.get_data_handle() == to->second.get_data_handle()) return dummy_impl_t::execute_ocl(stream, args, deps); else { - const memory &src_mem = args.find(DNNL_ARG_FROM)->second; - const memory 
&dst_mem = args.find(DNNL_ARG_TO)->second;
+            const memory &src_mem = from->second;
+            const memory &dst_mem = to->second;
             assert(deps.size() <= 1);
             // Passing the empty event to memcpy below causes failure.
             const bool empty = deps.size() == 0 || deps[0] == 0;
diff --git a/src/graph/backend/dnnl/passes/transform.cpp b/src/graph/backend/dnnl/passes/transform.cpp
index 1856403eaf5..986049945ca 100644
--- a/src/graph/backend/dnnl/passes/transform.cpp
+++ b/src/graph/backend/dnnl/passes/transform.cpp
@@ -815,7 +815,8 @@ status_t fuse_post_ops(std::shared_ptr<subgraph_t> &sg) {
         bool not_fusible
                 = (!pops_fusible_map.at(base_op_kind).count(post_op_kind))
                 || (post_op_kind == op_kind::dnnl_binary
-                        && !post_binary_fusible(op, &post_op))
+                        && !post_binary_fusible(
+                                op, &post_op, sg->get_engine_kind()))
                 || (post_op_kind == op_kind::dnnl_convolution
                         && !post_depthwise_conv_fusible(op, &post_op));
         if (not_fusible) { return status::success; }
@@ -3666,6 +3667,46 @@ impl::status_t fuse_dst_transpose_to_matmul(std::shared_ptr<subgraph_t> &sg) {
     return impl::status::success;
 }
 
+impl::status_t fuse_reshape_for_gqa(std::shared_ptr<subgraph_t> &sg) {
+    std::vector<op_ptr> reshape_ops;
+    dnnl_dim_t head_num = 0;
+    for (auto &cur_op : sg->get_ops()) {
+        auto in = cur_op->get_input_value(0)->get_logical_tensor();
+        auto out = cur_op->get_output_value(0)->get_logical_tensor();
+        if (cur_op->get_kind() == op_kind::dnnl_reshape) {
+            if (ltw(in).ndims() == 5 || ltw(out).ndims() == 5) {
+                reshape_ops.emplace_back(cur_op);
+                if (ltw(in).ndims() == 5) head_num = ltw(out).vdims()[1];
+            }
+        }
+    }
+
+    subgraph_rewriter_t rewriter(sg);
+    for (auto &reshape_op : reshape_ops) {
+        auto in = reshape_op->get_input_value(0)->get_logical_tensor();
+        auto out = reshape_op->get_output_value(0)->get_logical_tensor();
+        if (ltw(in).ndims() == 5)
+            rewriter.fuse_op_to_predecessor(reshape_op->shared_from_this());
+        if (ltw(out).ndims() == 5) {
+            auto in_dims = ltw(in).vdims();
+            // set the dim to 1 so that shape inference passes, e.g.
+            // [32,16,384,64]*[32,2,384,64] -> [32,16,384,64]*[32,1,384,64]
+            if (in_dims[1] != head_num) in_dims[1] = 1;
+            reshape_op->get_input_value(0)->set_dims(in_dims);
+            rewriter.fuse_op_to_successor(reshape_op->shared_from_this());
+        }
+    }
+    rewriter.run();
+
+    // rewrite the shapes of the subgraph's internal logical tensors
+    for (auto &cur_op : sg->get_ops()) {
+        auto out_val = cur_op->get_output_value(0);
+        // the subgraph output logical tensors don't change shape.
+        if (!out_val->get_consumers().empty()) out_val->set_ndims(-1);
+    }
+    return infer_shape(sg);
+}
+
 impl::status_t swap_relu_mul_scales(std::shared_ptr<subgraph_t> &sg) {
     while (true) {
         std::vector<std::pair<op_ptr, op_ptr>> to_be_swapped;
diff --git a/src/graph/backend/dnnl/passes/transform.hpp b/src/graph/backend/dnnl/passes/transform.hpp
index 36a1383d3a6..4378c527e20 100644
--- a/src/graph/backend/dnnl/passes/transform.hpp
+++ b/src/graph/backend/dnnl/passes/transform.hpp
@@ -200,6 +200,9 @@ impl::status_t fuse_src_transpose_to_matmul(std::shared_ptr<subgraph_t> &sg);
 // the operator after transpose needs a dense layout
 impl::status_t fuse_dst_transpose_to_matmul(std::shared_ptr<subgraph_t> &sg);
 
+// This pass fuses the reshape ops into their neighboring ops for GQA; see the
+// shape sketch below.
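Concretely, the rewrite in `fuse_reshape_for_gqa` makes a grouped K/V tensor broadcast-compatible with the query's head dimension. A worked check of the shape arithmetic from the in-code example (illustrative values only):

#include <array>
#include <cassert>

int main() {
    // Values from the comment in fuse_reshape_for_gqa:
    // [32,16,384,64] * [32,2,384,64] -> [32,16,384,64] * [32,1,384,64].
    std::array<long, 4> q {32, 16, 384, 64}; // 16 query heads
    std::array<long, 4> kv {32, 2, 384, 64}; // 2 shared KV heads
    const long head_num = q[1];
    // The dim is set to 1 so that broadcast shape inference accepts MatMul.
    if (kv[1] != head_num) kv[1] = 1;
    for (size_t i = 0; i < q.size(); ++i)
        assert(kv[i] == 1 || kv[i] == q[i]); // now broadcastable
}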
+impl::status_t fuse_reshape_for_gqa(std::shared_ptr<subgraph_t> &sg);
+
 // This pass will fold add_zps into the previous sub_zps with new_zps = sub_zps
 // - add_zps
 impl::status_t fold_sub_zps_add_zps(std::shared_ptr<subgraph_t> &sg);
diff --git a/src/graph/backend/dnnl/passes/utils.cpp b/src/graph/backend/dnnl/passes/utils.cpp
index a29eae6a99a..19a56954e9a 100644
--- a/src/graph/backend/dnnl/passes/utils.cpp
+++ b/src/graph/backend/dnnl/passes/utils.cpp
@@ -267,9 +267,11 @@ bool binary_doable(
     return true;
 }
 
+// TODO: ekind can be removed once the CPU optimizes 5D tensor MatMul with a
+// broadcasted post-op
 static bool post_binary_fusible_impl(const op_t *base_op,
         const std::vector<dim_t> &fused_shape,
-        const std::vector<dim_t> &other_shape) {
+        const std::vector<dim_t> &other_shape, engine_kind_t ekind) {
     assertm(fused_shape.size() == other_shape.size(),
             "must have same ndims, pls run binary_canonicalization pass first");
     // full tensor and per tensor broadcasted
@@ -278,9 +280,14 @@ static bool post_binary_fusible_impl(const op_t *base_op,
                 [](dim_t i) { return i == 1; }))
         return true;
 
-    // any broadcasted for 4d tensor MatMul
     int32_t output_ndims = static_cast<int32_t>(fused_shape.size());
-    if (base_op->get_kind() == op_kind::dnnl_matmul && output_ndims == 4) {
+    // 5D tensor MatMul with a broadcasted post-op is not optimized on CPU
+    if (ekind == dnnl_cpu && base_op->get_kind() == op_kind::dnnl_matmul
+            && output_ndims == 5)
+        return false;
+    // any broadcast for 4D or 5D tensor MatMul
+    if (base_op->get_kind() == op_kind::dnnl_matmul
+            && (output_ndims == 4 || output_ndims == 5)) {
         for (int32_t i = output_ndims - 1; i >= 0; i--) {
             if (other_shape[i] == 1) continue;
             if (fused_shape[i] != other_shape[i]) { return false; }
@@ -367,7 +374,8 @@ std::pair> shuffle_fusible(
     return {true, {c_over_g_pos, groups}};
 }
 
-bool post_binary_fusible(const op_t *base_op, const op_t *bin_op) {
+bool post_binary_fusible(
+        const op_t *base_op, const op_t *bin_op, graph::engine_kind_t ekind) {
     auto fused_out = base_op->get_output_values()[0];
     auto consumers = fused_out->get_consumers();
     if (consumers.size() != 1) return false;
@@ -395,7 +403,7 @@ bool post_binary_fusible(const op_t *base_op, const op_t *bin_op) {
     }
 
     return post_binary_fusible_impl(
-            base_op, ltw(fused_in).vdims(), ltw(other_in).vdims());
+            base_op, ltw(fused_in).vdims(), ltw(other_in).vdims(), ekind);
 }
 
 bool post_depthwise_conv_fusible(
diff --git a/src/graph/backend/dnnl/passes/utils.hpp b/src/graph/backend/dnnl/passes/utils.hpp
index 9b5cd3bfb1c..d768c735543 100644
--- a/src/graph/backend/dnnl/passes/utils.hpp
+++ b/src/graph/backend/dnnl/passes/utils.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright 2021-2023 Intel Corporation
+* Copyright 2021-2024 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -307,7 +308,8 @@ std::pair> shuffle_fusible(
 // performance. So, we check the shape in this function and only make
 // per_tensor, per_channel, per_mb_w (MatMul), and full tensor broadcast
 // binaries able to be fused.
-bool post_binary_fusible(const op_t *base_op, const op_t *bin_op);
+bool post_binary_fusible(const op_t *base_op, const op_t *bin_op,
+        engine_kind_t ekind = engine_kind::cpu);
 
 // oneDNN supports post depthwise conv fusion. This function is used to check if
 // two conv ops can be fused as a conv + depthwise pattern.
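The fusibility rule above, in brief: a binary post-op can fuse when every dimension of its tensor is either 1 or equal to the corresponding output dimension, with the extra CPU-side carve-out for broadcasted 5-D MatMul. A simplified sketch of that check (not the exact oneDNN logic; the per-tensor special case is folded into the general loop):

#include <cstdint>
#include <vector>

// Simplified shape check in the spirit of post_binary_fusible_impl.
static bool binary_post_op_fusible(const std::vector<int64_t> &out_shape,
        const std::vector<int64_t> &other_shape, bool is_cpu) {
    // run binary_canonicalization first so the ranks match
    if (out_shape.size() != other_shape.size()) return false;
    bool broadcast = false;
    for (size_t i = 0; i < out_shape.size(); ++i) {
        if (other_shape[i] == 1 && out_shape[i] != 1) broadcast = true;
        if (other_shape[i] != 1 && other_shape[i] != out_shape[i])
            return false; // neither broadcast nor exact match
    }
    // Broadcasted post-op on a 5-D MatMul is not optimized on CPU yet.
    if (is_cpu && broadcast && out_shape.size() == 5) return false;
    return true;
}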
diff --git a/src/graph/backend/dnnl/patterns/sdp.cpp b/src/graph/backend/dnnl/patterns/sdp.cpp index 39cd04ee11f..f2968d8b2db 100644 --- a/src/graph/backend/dnnl/patterns/sdp.cpp +++ b/src/graph/backend/dnnl/patterns/sdp.cpp @@ -140,6 +140,50 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_fusion) return std::make_shared>(); }); +DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_gqa_fusion) + .set_priority(21.1f) + .set_kind(partition_kind_t::sdp) + .set_attr("FCreatePattern", + [](const std::shared_ptr &pgraph) -> void { + auto reshape1 + = pgraph->append_op(graph::op_kind::StaticReshape); + auto reshape2 + = pgraph->append_op(graph::op_kind::StaticReshape); + auto matmul_qk = pgraph->append_op(graph::op_kind::MatMul, + {{in_edge(0, reshape1, 0), + in_edge(1, reshape2, 0)}}); + auto fscore_scale = pgraph->append_alternation( + {graph::op_kind::Divide, graph::op_kind::Multiply}, + {in_edge(0, matmul_qk, 0)}); + auto optional_mask = std::make_shared(); + auto mask_reshape = optional_mask->append_op( + graph::op_kind::StaticReshape); + auto fscore_add = optional_mask->append_op( + graph::op_kind::Add, {in_edge(1, mask_reshape, 0)}); + optional_mask->create_input_port(0, fscore_add, 0); + optional_mask->create_output_port(0, fscore_add, 0); + auto mask = pgraph->append_optional( + optional_mask, {in_edge(0, fscore_scale, 0)}); + + // Optional select for distilbert + auto p_select2 = optional_select(pgraph, mask, 2); + auto softmax = pgraph->append_op(graph::op_kind::SoftMax, + {in_edge(0, p_select2, 0)}); + auto reshape3 + = pgraph->append_op(graph::op_kind::StaticReshape); + auto matmul_v = pgraph->append_op(graph::op_kind::MatMul, + {in_edge(0, softmax, 0), in_edge(1, reshape3, 0)}); + auto reshape4 + = pgraph->append_op(graph::op_kind::StaticReshape, + {in_edge(0, matmul_v, 0)}); + + // Optional transpose + reshape/reorder + optional_transpose_reshape(pgraph, reshape4, 0); + }) + .set_attr("FCreateKernel", []() -> kernel_ptr { + return std::make_shared>(); + }); + DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_jax_fusion) .set_priority(21.0f) .set_kind(partition_kind_t::sdp) diff --git a/src/graph/backend/dnnl/patterns/softmax_post_ops.cpp b/src/graph/backend/dnnl/patterns/softmax_post_ops.cpp index 3a3a312e2e0..6bea2c04ed9 100644 --- a/src/graph/backend/dnnl/patterns/softmax_post_ops.cpp +++ b/src/graph/backend/dnnl/patterns/softmax_post_ops.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2022-2023 Intel Corporation +* Copyright 2022-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -76,7 +76,6 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, fp_softmax_post_ops) auto q_graph = std::make_shared(); pm::pb_op_t *pquantize = q_graph->append_op(graph::op_kind::Quantize); - pquantize->append_decision_function(check_zps_values<0>); q_graph->create_input_port(0, pquantize, 0); q_graph->create_output_port(0, pquantize, 0); pgraph->append_optional( diff --git a/src/graph/backend/dnnl/thread_local_cache.hpp b/src/graph/backend/dnnl/thread_local_cache.hpp index c43920fd98b..a91a2ac1c5e 100644 --- a/src/graph/backend/dnnl/thread_local_cache.hpp +++ b/src/graph/backend/dnnl/thread_local_cache.hpp @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright 2021-2023 Intel Corporation + * Copyright 2021-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -87,39 +87,43 @@ class thread_local_cache_t { // Clear the cached values in current thread void clear() { - cache_type_t &cache = get_thread_local_cache(); - for (auto &it : cache.data()) { - std::shared_ptr value = it.second.lock(); - if (value) { - std::lock_guard lock( - global_cache_type_t::get_global_cache()->mutex()); - auto &data = global_cache_type_t::get_global_cache()->data(); - - auto ret = data.find(it.first); - if (ret != data.end()) { - std::vector> &thread_instances - = ret->second; - auto pos = std::find_if(thread_instances.begin(), - thread_instances.end(), - [&](std::shared_ptr &ins) -> bool { - return ins.get() == value.get(); - }); - assertm(pos != thread_instances.end(), - "expected value to exist in cache"); - thread_instances.erase(pos); + cache_type_t &lcache = get_thread_local_cache(); + global_cache_type_t *gcache = global_cache_type_t::get_global_cache(); + // for safety purpose. it should not be nullptr. + if (gcache) { + for (auto &it : lcache.data()) { + std::shared_ptr value = it.second.lock(); + if (value) { + std::lock_guard lock(gcache->mutex()); + auto &data = gcache->data(); + + auto ret = data.find(it.first); + if (ret != data.end()) { + std::vector> &thread_instances + = ret->second; + auto pos = std::find_if(thread_instances.begin(), + thread_instances.end(), + [&](std::shared_ptr &ins) -> bool { + return ins.get() == value.get(); + }); + assertm(pos != thread_instances.end(), + "expected value to exist in cache"); + thread_instances.erase(pos); + } } } } - cache.data().clear(); + lcache.data().clear(); } // Remove the cached values for the given key in ALL threads void remove_if_exist(const size_t &key) { - std::lock_guard lock( - global_cache_type_t::get_global_cache()->mutex()); - auto pos = global_cache_type_t::get_global_cache()->data().find(key); - if (pos != global_cache_type_t::get_global_cache()->data().end()) { - pos->second.clear(); + global_cache_type_t *gcache = global_cache_type_t::get_global_cache(); + // for safety purpose. it should not be nullptr. 
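The null checks threaded through the cache code here guard one specific failure mode: the lazily constructed global cache can fail to come up (the try/catch in `get_global_cache` returns nullptr on a throw), so every caller degrades to a no-op instead of dereferencing null. A distilled sketch of the pattern (names hypothetical):

#include <memory>

struct global_state_t { /* shared bookkeeping */ };

// Lazily constructed process-wide state; returns nullptr instead of
// propagating a construction failure, so callers must check the result.
static global_state_t *get_global_state() {
    try {
        static auto state = std::make_shared<global_state_t>();
        return state.get();
    } catch (...) {
        return nullptr;
    }
}

void touch_global_state() {
    if (auto *g = get_global_state()) {
        // use *g; silently skip when construction failed
        (void)g;
    }
}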
+ if (gcache) { + std::lock_guard lock(gcache->mutex()); + auto pos = gcache->data().find(key); + if (pos != gcache->data().end()) { pos->second.clear(); } } } @@ -136,17 +140,16 @@ class thread_local_cache_t { // be shared between threads std::shared_ptr ins = creator(); { - std::lock_guard lock( - global_cache_type_t::get_global_cache()->mutex()); - if (global_cache_type_t::get_global_cache()->data().count( - key)) { - global_cache_type_t::get_global_cache() - ->data() - .at(key) - .emplace_back(ins); - } else { - global_cache_type_t::get_global_cache()->data().emplace( - key, std::vector> {ins}); + auto *gcache = global_cache_type_t::get_global_cache(); + // for safety purpose. it should not be nullptr. + if (gcache) { + std::lock_guard lock(gcache->mutex()); + if (gcache->data().count(key)) { + gcache->data().at(key).emplace_back(ins); + } else { + gcache->data().emplace( + key, std::vector> {ins}); + } } } cache.data()[key] = ins; @@ -155,9 +158,15 @@ class thread_local_cache_t { } // This function increments the reference count - void retain() { global_cache_type_t::get_global_cache()->retain(); } + void retain() { + auto *gcache = global_cache_type_t::get_global_cache(); + if (gcache) gcache->retain(); + } - void release() { global_cache_type_t::get_global_cache()->release(); } + void release() { + auto *gcache = global_cache_type_t::get_global_cache(); + if (gcache) gcache->release(); + } private: class global_cache_type_t { @@ -172,10 +181,14 @@ class thread_local_cache_t { static global_cache_type_t *get_global_cache() { // A global table to store cached values in ALL threads. This global // table takes the ownership of cached values - static auto global_cache = std::shared_ptr( - new global_cache_type_t {}, - [](global_cache_type_t *ptr) { return ptr->release(); }); - return global_cache.get(); + try { + static auto global_cache = std::shared_ptr( + new global_cache_type_t {}, + [](global_cache_type_t *ptr) { + return ptr->release(); + }); + return global_cache.get(); + } catch (...) { return nullptr; } } // This function increments the reference count diff --git a/src/graph/backend/fake/fake_backend.cpp b/src/graph/backend/fake/fake_backend.cpp index 62b40a1a574..f68e85c8eb2 100644 --- a/src/graph/backend/fake/fake_backend.cpp +++ b/src/graph/backend/fake/fake_backend.cpp @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright 2021-2023 Intel Corporation + * Copyright 2021-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -37,9 +37,10 @@ graph::pass::pass_registry_t fake_backend_t::pass_registry_ } // namespace fake_impl // This function must be called by backend_registry_t -void register_fake_backend() { - backend_registry_t::get_singleton().register_backend( +status_t register_fake_backend() { + const status_t ret = backend_registry_t::get_singleton().register_backend( &fake_impl::fake_backend_t::get_singleton()); + return ret; } } // namespace graph diff --git a/src/graph/backend/graph_compiler/compiler_backend.cpp b/src/graph/backend/graph_compiler/compiler_backend.cpp index 70ace35277f..7da5c3f4e75 100644 --- a/src/graph/backend/graph_compiler/compiler_backend.cpp +++ b/src/graph/backend/graph_compiler/compiler_backend.cpp @@ -110,9 +110,10 @@ status_t compiler_backend_t::get_partitions( } // namespace compiler_impl -void register_compiler_backend() { - backend_registry_t::get_singleton().register_backend( +status_t register_compiler_backend() { + const status_t ret = backend_registry_t::get_singleton().register_backend( &compiler_impl::compiler_backend_t::get_singleton()); + return ret; } } // namespace graph diff --git a/src/graph/interface/backend.hpp b/src/graph/interface/backend.hpp index 5d3524d128b..3ae8705a842 100644 --- a/src/graph/interface/backend.hpp +++ b/src/graph/interface/backend.hpp @@ -38,11 +38,11 @@ namespace impl { namespace graph { // forward declaration -void register_dnnl_backend(); -void register_fake_backend(); +status_t register_dnnl_backend(); +status_t register_fake_backend(); #ifdef DNNL_ENABLE_COMPILER_BACKEND // register graph compiler backend -void register_compiler_backend(); +status_t register_compiler_backend(); #endif class backend_t { @@ -131,10 +131,11 @@ class backend_registry_t { return inst; } - backend_t *register_backend(const backend_t *abackend) { + status_t register_backend(const backend_t *abackend) { auto has_colliding_name = [&](const backend_t *backend) { return backend->get_name().compare(abackend->get_name()) == 0; }; + auto backend_already_registered = [&]() { return std::find_if(sorted_backends_.begin(), sorted_backends_.end(), has_colliding_name) @@ -145,10 +146,7 @@ class backend_registry_t { return l->get_priority() > r->get_priority(); }; - if (backend_already_registered()) { - throw std::runtime_error( - "backend name not unique: " + abackend->get_name()); - } + if (backend_already_registered()) { return status::runtime_error; } std::lock_guard lock(m_); @@ -156,7 +154,7 @@ class backend_registry_t { sorted_backends_.emplace_back(abackend); std::sort(sorted_backends_.begin(), sorted_backends_.end(), compare_priority); - return const_cast(abackend); + return status::success; } // This interface will firstly register all available backends and then diff --git a/src/xpu/ocl/stream_impl.cpp b/src/xpu/ocl/stream_impl.cpp index 54bc32df317..57c71405b3d 100644 --- a/src/xpu/ocl/stream_impl.cpp +++ b/src/xpu/ocl/stream_impl.cpp @@ -221,6 +221,12 @@ status_t stream_impl_t::fill(impl::stream_t *stream, return status::success; } +status_t stream_impl_t::barrier() { + cl_int err = clEnqueueMarkerWithWaitList(queue(), 0, nullptr, nullptr); + OCL_CHECK(err); + return status::success; +} + const xpu::ocl::context_t &stream_impl_t::ocl_ctx() const { static xpu::ocl::context_t empty_ctx {}; return ctx_.get(empty_ctx); diff --git a/src/xpu/ocl/stream_impl.hpp b/src/xpu/ocl/stream_impl.hpp index 6d28ee0c8f3..9c22212ad39 100644 --- a/src/xpu/ocl/stream_impl.hpp +++ b/src/xpu/ocl/stream_impl.hpp @@ -67,6 +67,8 @@ class stream_impl_t : public impl::stream_impl_t 
{
             xpu::event_t &out_dep,
             xpu::stream_profiler_t *stream_profiler = nullptr);
 
+    status_t barrier();
+
     const xpu::ocl::context_t &ocl_ctx() const;
     xpu::ocl::context_t &ocl_ctx();
     xpu::context_t &ctx();
diff --git a/src/xpu/ocl/stream_profiler.cpp b/src/xpu/ocl/stream_profiler.cpp
index c3b58b2b787..7e067715574 100644
--- a/src/xpu/ocl/stream_profiler.cpp
+++ b/src/xpu/ocl/stream_profiler.cpp
@@ -16,7 +16,7 @@
 #include
-#include <unordered_map>
+#include <map>
 #include
 
 #include "common/c_types_map.hpp"
@@ -43,7 +43,7 @@ status_t stream_profiler_t::get_info(profiling_data_kind_t data_kind,
         return status::success;
     }
 
-    std::unordered_map<uint64_t, entry_t> stamp2entry;
+    std::map<uint64_t, entry_t> stamp2entry;
     for (auto &ev : events_) {
         auto &entry = stamp2entry[ev.stamp];
         const xpu::ocl::event_t &ocl_event
diff --git a/src/xpu/stream_profiler.hpp b/src/xpu/stream_profiler.hpp
index 0fc8678046d..77356f5e887 100644
--- a/src/xpu/stream_profiler.hpp
+++ b/src/xpu/stream_profiler.hpp
@@ -19,9 +19,9 @@
 #include
 #include
+#include <map>
 #include
 #include
-#include <unordered_map>
 
 #include "common/c_types_map.hpp"
@@ -70,6 +70,11 @@ struct stream_profiler_t {
         m_.unlock();
     }
 
+    // The contract is that the profiler interfaces are called only between
+    // `start_profiling` and `stop_profiling`, which provide safe
+    // multi-threaded access through the lock. This allows stripping the lock
+    // from all other calls, e.g., `stamp` or `register_event` (except
+    // `reset`), to reduce the profiling overhead.
     void start_profiling() {
         m_.lock();
         stamp_++;
@@ -86,8 +91,7 @@
     }
 
 protected:
-    status_t get_info_impl(
-            const std::unordered_map<uint64_t, entry_t> &stamp2entry,
+    status_t get_info_impl(const std::map<uint64_t, entry_t> &stamp2entry,
             profiling_data_kind_t data_kind, uint64_t *data) const {
         int idx = 0;
         for (auto &kv : stamp2entry) {
diff --git a/src/xpu/sycl/stream_impl.cpp b/src/xpu/sycl/stream_impl.cpp
index 7db7eaa3a2d..69b5a84e062 100644
--- a/src/xpu/sycl/stream_impl.cpp
+++ b/src/xpu/sycl/stream_impl.cpp
@@ -175,6 +175,11 @@ status_t stream_impl_t::fill(const memory_storage_t &dst, uint8_t pattern,
     return status::success;
 }
 
+status_t stream_impl_t::barrier() {
+    queue()->ext_oneapi_submit_barrier();
+    return status::success;
+}
+
 const xpu::sycl::context_t &stream_impl_t::sycl_ctx() const {
     static xpu::sycl::context_t empty_ctx {};
     return ctx_.get(empty_ctx);
diff --git a/src/xpu/sycl/stream_impl.hpp b/src/xpu/sycl/stream_impl.hpp
index b9f29840333..81f7a531809 100644
--- a/src/xpu/sycl/stream_impl.hpp
+++ b/src/xpu/sycl/stream_impl.hpp
@@ -63,6 +63,8 @@ class stream_impl_t : public impl::stream_impl_t {
             const xpu::event_t &deps, xpu::event_t &out_dep,
             xpu::stream_profiler_t *stream_profiler = nullptr);
 
+    status_t barrier();
+
     const xpu::sycl::context_t &sycl_ctx() const;
 
     xpu::sycl::context_t &sycl_ctx();
diff --git a/src/xpu/sycl/stream_profiler.cpp b/src/xpu/sycl/stream_profiler.cpp
index a192cce6ed7..6fdc1db8607 100644
--- a/src/xpu/sycl/stream_profiler.cpp
+++ b/src/xpu/sycl/stream_profiler.cpp
@@ -14,7 +14,7 @@
  * limitations under the License.
 *******************************************************************************/
 
-#include <unordered_map>
+#include <map>
 #include 
 
 #include "common/c_types_map.hpp"
@@ -41,7 +41,7 @@ status_t stream_profiler_t::get_info(profiling_data_kind_t data_kind,
         return status::success;
     }
 
-    std::unordered_map<uint64_t, entry_t> stamp2entry;
+    std::map<uint64_t, entry_t> stamp2entry;
     for (auto &ev : events_) {
         const xpu::sycl::event_t &sycl_event
                 = *utils::downcast<xpu::sycl::event_t *>(ev.event.get());
diff --git a/tests/benchdnn/dnn_types.cpp b/tests/benchdnn/dnn_types.cpp
index fd199d98a12..1f12839baf1 100644
--- a/tests/benchdnn/dnn_types.cpp
+++ b/tests/benchdnn/dnn_types.cpp
@@ -214,7 +214,7 @@ int attr_t::policy2mask(int arg, policy_t policy,
             || policy == policy_t::COMMON)
         return attr_t::get_default_mask(policy);
 
-    if (ndims <= 0) SAFE_V(FAIL);
+    if (ndims < 2) SAFE_V(FAIL);
 
     switch (policy) {
         case PER_DIM_1:
         case PER_OC: return (1 << (ndims - 1));
@@ -227,7 +227,7 @@ int attr_t::policy2mask(int arg, policy_t policy,
 
         // PER_OC
         assert(policy == policy_t::PER_OC);
-        if (ndims <= 0) SAFE_V(FAIL);
+        if (ndims < 1) SAFE_V(FAIL);
         return 1 << (ndims - 1);
     } else {
         // Default case
@@ -779,7 +779,7 @@ std::ostream &operator<<(std::ostream &s, dnnl_accumulation_mode_t am) {
 
 std::ostream &operator<<(std::ostream &s, const attr_t::rounding_mode_t &rm) {
     std::string sep;
-    for (auto i : rm.rounding_modes_) {
+    for (const auto &i : rm.rounding_modes_) {
         s << sep << arg2str(i.first) << ":" << rounding_mode2str(i.second);
         if (rm.is_set_seed) s << ":" << rm.seed;
         sep = "+";
@@ -1122,7 +1122,7 @@ dnnl_primitive_attr_t create_dnnl_attr(
     }
 
     if (!attr.rounding_mode.is_def()) {
-        for (const auto e : attr.rounding_mode.rounding_modes_) {
+        for (const auto &e : attr.rounding_mode.rounding_modes_) {
             DNN_SAFE_V(dnnl_primitive_attr_set_rounding(
                     dnnl_attr, e.first, e.second));
         }
diff --git a/tests/benchdnn/graph/input_displacer.cpp b/tests/benchdnn/graph/input_displacer.cpp
index 88a68b6f71a..4c6f5876ce8 100644
--- a/tests/benchdnn/graph/input_displacer.cpp
+++ b/tests/benchdnn/graph/input_displacer.cpp
@@ -109,6 +109,25 @@ partition_data_displacer_t::partition_data_displacer_t(
                                     filling_type_t::quantization));
                     break;
                 }
+
+                if (parent_op->kind_ == "StaticReshape") {
+                    // StaticReshape is accepted when the pattern is
+                    // "StaticReshape + MatMul" and the reshape doesn't have
+                    // any predecessors in the partition.
+                    const auto &parent_op_in_lt = parent_op->in_lts_[0];
+                    const auto &prev_parent_op
+                            = dg_->get_op_by_out_lt(parent_op_in_lt.id_);
+                    if (prev_parent_op.empty()
+                            || op_ids_set_.find(prev_parent_op.id_)
+                                    == op_ids_set_.end()) {
+                        if (aop.kind_ == "MatMul") {
+                            quantize_displace_.emplace(parent_op_in_lt.id_,
+                                    std::make_tuple(aop, i, parent_op_in_lt,
+                                            filling_type_t::quantization));
+                        }
+                        break;
+                    }
+                }
             }
 
             // Continue only on allowed ops.
             if (go_through_op_kind.find(parent_op->kind_)
@@ -264,11 +283,22 @@ int partition_data_displacer_t::displace_input_data(
         is_grouped_conv = groups > 1;
     }
 
-    bool mds_ok = IMPLICATION(!mds_are_equal, mds_are_int8 || is_grouped_conv);
+    bool is_reshaped_dims = mem_replace.nelems() == mem.nelems()
+            && mem_replace.ndims() != mem.ndims();
+
+    bool mds_ok = IMPLICATION(!mds_are_equal,
+            mds_are_int8 || is_grouped_conv || is_reshaped_dims);
     SAFE(mds_ok ?
OK : FAIL, WARN); + + dnnl_memory_desc_t md = mem.md_; + if (is_reshaped_dims) { + DNN_SAFE_V(dnnl_memory_desc_create_with_strides( + &md, mem.ndims(), mem.dims(), mem_replace.dt(), mem.strides())); + } dnnl_memory_desc_destroy(mem_replace.md_); - dnnl_memory_desc_clone(&mem_replace.md_, mem.md_); + dnnl_memory_desc_clone(&mem_replace.md_, md); SAFE(mem.reorder(mem_replace), WARN); + if (is_reshaped_dims) dnnl_memory_desc_destroy(md); return OK; } @@ -279,6 +309,7 @@ int partition_data_displacer_t::gen_quantize_filling( ::graph::deserialized_op op = main_op; auto driver = opkind2driver(opstr2kind(op.kind_)); bool is_f8_quantization = (dt == "f8_e5m2" || dt == "f8_e4m3"); + bool is_f16 = dt == "f16"; op.in_lts_[0].data_type_ = dt; if (op.in_lts_.size() > 1) { @@ -314,9 +345,10 @@ int partition_data_displacer_t::gen_quantize_filling( } ::std::unordered_set empty_set; - // As f8 support status is limited now, use tset engine to ensure that - // primitive can be created and generate data - const auto &eng = is_f8_quantization ? get_test_engine() : get_cpu_engine(); + // As f8 and f16 support status is limited now, use test engine to ensure + // that primitive can be created and generate data + const auto &eng = is_f8_quantization || is_f16 ? get_test_engine() + : get_cpu_engine(); ref_primitive_t ref_prim(op); ref_prim.init_prb(res); diff --git a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all index 921000b02b9..9e5882f8cf2 100644 --- a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all +++ b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all @@ -30,6 +30,8 @@ --reset --case=complex_fusion/mha/sdpa-plain-wo-scale-bf16-bs1.json --reset --case=complex_fusion/mha/sdpa-plain-wo-scale-fp32-bs1.json --reset --case=complex_fusion/mha/sdpa-plain-wo-scale-int8-bs1.json +--reset --case=complex_fusion/mha/GQA-fp32.json +--reset --case=complex_fusion/mha/GQA-fp16.json # Rewrited graphs --reset --in-shapes=4:4x16x32x256+5:4x16x256x33+0:4x16x33x256+1:4x1x1x33+3:4x1x32x33 --case=complex_fusion/mha/MHA-GPT-inf-fp32-bs1.json diff --git a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci index 25cd8ccc996..f67f5b73971 100644 --- a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci +++ b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci @@ -21,3 +21,5 @@ --reset --case=complex_fusion/mha/sdpa-plain-wo-scale-bf16-bs1.json --reset --case=complex_fusion/mha/sdpa-plain-wo-scale-fp32-bs1.json --reset --case=complex_fusion/mha/sdpa-plain-wo-scale-int8-bs1.json +--reset --case=complex_fusion/mha/GQA-fp32.json +--reset --case=complex_fusion/mha/GQA-fp16.json diff --git a/tests/benchdnn/inputs/graph/complex_fusion/mha/GQA-fp16.json b/tests/benchdnn/inputs/graph/complex_fusion/mha/GQA-fp16.json new file mode 100644 index 00000000000..6c7472479a9 --- /dev/null +++ b/tests/benchdnn/inputs/graph/complex_fusion/mha/GQA-fp16.json @@ -0,0 +1,686 @@ +{ + "version": "3.6.0", + "engine_kind": "cpu", + "fpmath_mode": "strict", + "input_ports": [ + 0, + 2, + 8, + 11, + 18 + ], + "output_ports": [ + 23 + ], + "graph": [ + { + "id": 5, + "name": "reshape1", + "kind": "StaticReshape", + "attrs": { + "special_zero": { + "type": "bool", + "value": 0 + }, + "shape": { + "type": "s64[]", + "value": [ + 32, + 2, + 8, + 384, + 64 + ] + } + }, + "inputs": [ + { + "id": 0, + "dtype": "f16", + "shape": [ + 32, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + 
], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 1, + "dtype": "f16", + "shape": [ + 32, + 2, + 8, + 384, + 64 + ], + "stride": [ + 393216, + 196608, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 6, + "name": "reshape2", + "kind": "StaticReshape", + "attrs": { + "special_zero": { + "type": "bool", + "value": 0 + }, + "shape": { + "type": "s64[]", + "value": [ + 32, + 2, + 1, + 384, + 64 + ] + } + }, + "inputs": [ + { + "id": 2, + "dtype": "f16", + "shape": [ + 32, + 2, + 384, + 64 + ], + "stride": [ + 49152, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 3, + "dtype": "f16", + "shape": [ + 32, + 2, + 1, + 384, + 64 + ], + "stride": [ + 49152, + 24576, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 7, + "name": "bmm1", + "kind": "MatMul", + "attrs": { + "transpose_a": { + "type": "bool", + "value": 0 + }, + "transpose_b": { + "type": "bool", + "value": 1 + } + }, + "inputs": [ + { + "id": 1, + "dtype": "f16", + "shape": [ + 32, + 2, + 8, + 384, + 64 + ], + "stride": [ + 393216, + 196608, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 3, + "dtype": "f16", + "shape": [ + 32, + 2, + 1, + 384, + 64 + ], + "stride": [ + 49152, + 24576, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 4, + "dtype": "f16", + "shape": [ + 32, + 2, + 8, + 384, + 384 + ], + "stride": [ + 2359296, + 1179648, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 10, + "name": "scale_div", + "kind": "Divide", + "attrs": { + "auto_broadcast": { + "type": "string", + "value": "numpy" + } + }, + "inputs": [ + { + "id": 4, + "dtype": "f16", + "shape": [ + 32, + 2, + 8, + 384, + 384 + ], + "stride": [ + 2359296, + 1179648, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 8, + "dtype": "f16", + "shape": [ + 1 + ], + "stride": [ + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 9, + "dtype": "f16", + "shape": [ + 32, + 2, + 8, + 384, + 384 + ], + "stride": [ + 2359296, + 1179648, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 13, + "name": "reshape3", + "kind": "StaticReshape", + "attrs": { + "special_zero": { + "type": "bool", + "value": 0 + }, + "shape": { + "type": "s64[]", + "value": [ + 32, + 1, + 1, + 1, + 384 + ] + } + }, + "inputs": [ + { + "id": 11, + "dtype": "f16", + "shape": [ + 32, + 1, + 1, + 384 + ], + "stride": [ + 384, + 384, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 12, + "dtype": "f16", + "shape": [ + 32, + 1, + 1, + 1, + 384 + ], + "stride": [ + 384, + 384, + 384, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 15, + "name": "mask_add", + "kind": "Add", + "attrs": { + "auto_broadcast": { + "type": "string", + "value": "numpy" + } + }, + "inputs": [ + { + "id": 9, + "dtype": "f16", + "shape": [ + 32, + 2, + 8, + 384, + 384 + ], + "stride": [ + 2359296, + 1179648, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 12, + "dtype": "f16", + "shape": [ + 32, + 1, + 1, + 1, + 384 + ], + "stride": [ + 384, + 384, + 
384, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 14, + "dtype": "f16", + "shape": [ + 32, + 2, + 8, + 384, + 384 + ], + "stride": [ + 2359296, + 1179648, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 17, + "name": "softmax", + "kind": "SoftMax", + "attrs": { + "axis": { + "type": "s64", + "value": -1 + } + }, + "inputs": [ + { + "id": 14, + "dtype": "f16", + "shape": [ + 32, + 2, + 8, + 384, + 384 + ], + "stride": [ + 2359296, + 1179648, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 16, + "dtype": "f16", + "shape": [ + 32, + 2, + 8, + 384, + 384 + ], + "stride": [ + 2359296, + 1179648, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 21, + "name": "reshape3", + "kind": "StaticReshape", + "attrs": { + "special_zero": { + "type": "bool", + "value": 0 + }, + "shape": { + "type": "s64[]", + "value": [ + 32, + 2, + 1, + 384, + 64 + ] + } + }, + "inputs": [ + { + "id": 18, + "dtype": "f16", + "shape": [ + 32, + 2, + 384, + 64 + ], + "stride": [ + 49152, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 19, + "dtype": "f16", + "shape": [ + 32, + 2, + 1, + 384, + 64 + ], + "stride": [ + 49152, + 24576, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 22, + "name": "bmm2", + "kind": "MatMul", + "attrs": { + "transpose_a": { + "type": "bool", + "value": 0 + }, + "transpose_b": { + "type": "bool", + "value": 0 + } + }, + "inputs": [ + { + "id": 16, + "dtype": "f16", + "shape": [ + 32, + 2, + 8, + 384, + 384 + ], + "stride": [ + 2359296, + 1179648, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 19, + "dtype": "f16", + "shape": [ + 32, + 2, + 1, + 384, + 64 + ], + "stride": [ + 49152, + 24576, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 20, + "dtype": "f16", + "shape": [ + 32, + 2, + 8, + 384, + 64 + ], + "stride": [ + 393216, + 196608, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 24, + "name": "reshape4", + "kind": "StaticReshape", + "attrs": { + "special_zero": { + "type": "bool", + "value": 0 + }, + "shape": { + "type": "s64[]", + "value": [ + 32, + 16, + 384, + 64 + ] + } + }, + "inputs": [ + { + "id": 20, + "dtype": "f16", + "shape": [ + 32, + 2, + 8, + 384, + 64 + ], + "stride": [ + 393216, + 196608, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 23, + "dtype": "f16", + "shape": [ + 32, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + } + ] +} \ No newline at end of file diff --git a/tests/benchdnn/inputs/graph/complex_fusion/mha/GQA-fp32.json b/tests/benchdnn/inputs/graph/complex_fusion/mha/GQA-fp32.json new file mode 100644 index 00000000000..c4bfa43e540 --- /dev/null +++ b/tests/benchdnn/inputs/graph/complex_fusion/mha/GQA-fp32.json @@ -0,0 +1,686 @@ +{ + "version": "3.6.0", + "engine_kind": "cpu", + "fpmath_mode": "strict", + "input_ports": [ + 0, + 2, + 8, + 11, + 18 + ], + "output_ports": [ + 23 + ], + "graph": [ + { + "id": 5, + "name": "reshape1", + "kind": "StaticReshape", + "attrs": { + 
"special_zero": { + "type": "bool", + "value": 0 + }, + "shape": { + "type": "s64[]", + "value": [ + 32, + 2, + 8, + 384, + 64 + ] + } + }, + "inputs": [ + { + "id": 0, + "dtype": "f32", + "shape": [ + 32, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 1, + "dtype": "f32", + "shape": [ + 32, + 2, + 8, + 384, + 64 + ], + "stride": [ + 393216, + 196608, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 6, + "name": "reshape2", + "kind": "StaticReshape", + "attrs": { + "special_zero": { + "type": "bool", + "value": 0 + }, + "shape": { + "type": "s64[]", + "value": [ + 32, + 2, + 1, + 384, + 64 + ] + } + }, + "inputs": [ + { + "id": 2, + "dtype": "f32", + "shape": [ + 32, + 2, + 384, + 64 + ], + "stride": [ + 49152, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 3, + "dtype": "f32", + "shape": [ + 32, + 2, + 1, + 384, + 64 + ], + "stride": [ + 49152, + 24576, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 7, + "name": "bmm1", + "kind": "MatMul", + "attrs": { + "transpose_a": { + "type": "bool", + "value": 0 + }, + "transpose_b": { + "type": "bool", + "value": 1 + } + }, + "inputs": [ + { + "id": 1, + "dtype": "f32", + "shape": [ + 32, + 2, + 8, + 384, + 64 + ], + "stride": [ + 393216, + 196608, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 3, + "dtype": "f32", + "shape": [ + 32, + 2, + 1, + 384, + 64 + ], + "stride": [ + 49152, + 24576, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 4, + "dtype": "f32", + "shape": [ + 32, + 2, + 8, + 384, + 384 + ], + "stride": [ + 2359296, + 1179648, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 10, + "name": "scale_div", + "kind": "Divide", + "attrs": { + "auto_broadcast": { + "type": "string", + "value": "numpy" + } + }, + "inputs": [ + { + "id": 4, + "dtype": "f32", + "shape": [ + 32, + 2, + 8, + 384, + 384 + ], + "stride": [ + 2359296, + 1179648, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 8, + "dtype": "f32", + "shape": [ + 1 + ], + "stride": [ + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 9, + "dtype": "f32", + "shape": [ + 32, + 2, + 8, + 384, + 384 + ], + "stride": [ + 2359296, + 1179648, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 13, + "name": "reshape3", + "kind": "StaticReshape", + "attrs": { + "special_zero": { + "type": "bool", + "value": 0 + }, + "shape": { + "type": "s64[]", + "value": [ + 32, + 1, + 1, + 1, + 384 + ] + } + }, + "inputs": [ + { + "id": 11, + "dtype": "f32", + "shape": [ + 32, + 1, + 1, + 384 + ], + "stride": [ + 384, + 384, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 12, + "dtype": "f32", + "shape": [ + 32, + 1, + 1, + 1, + 384 + ], + "stride": [ + 384, + 384, + 384, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 15, + "name": "mask_add", + "kind": "Add", + "attrs": { + "auto_broadcast": { + "type": "string", + "value": "numpy" + } + }, + "inputs": [ + { + "id": 9, + "dtype": "f32", + 
"shape": [ + 32, + 2, + 8, + 384, + 384 + ], + "stride": [ + 2359296, + 1179648, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 12, + "dtype": "f32", + "shape": [ + 32, + 1, + 1, + 1, + 384 + ], + "stride": [ + 384, + 384, + 384, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 14, + "dtype": "f32", + "shape": [ + 32, + 2, + 8, + 384, + 384 + ], + "stride": [ + 2359296, + 1179648, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 17, + "name": "softmax", + "kind": "SoftMax", + "attrs": { + "axis": { + "type": "s64", + "value": -1 + } + }, + "inputs": [ + { + "id": 14, + "dtype": "f32", + "shape": [ + 32, + 2, + 8, + 384, + 384 + ], + "stride": [ + 2359296, + 1179648, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 16, + "dtype": "f32", + "shape": [ + 32, + 2, + 8, + 384, + 384 + ], + "stride": [ + 2359296, + 1179648, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 21, + "name": "reshape3", + "kind": "StaticReshape", + "attrs": { + "special_zero": { + "type": "bool", + "value": 0 + }, + "shape": { + "type": "s64[]", + "value": [ + 32, + 2, + 1, + 384, + 64 + ] + } + }, + "inputs": [ + { + "id": 18, + "dtype": "f32", + "shape": [ + 32, + 2, + 384, + 64 + ], + "stride": [ + 49152, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 19, + "dtype": "f32", + "shape": [ + 32, + 2, + 1, + 384, + 64 + ], + "stride": [ + 49152, + 24576, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 22, + "name": "bmm2", + "kind": "MatMul", + "attrs": { + "transpose_a": { + "type": "bool", + "value": 0 + }, + "transpose_b": { + "type": "bool", + "value": 0 + } + }, + "inputs": [ + { + "id": 16, + "dtype": "f32", + "shape": [ + 32, + 2, + 8, + 384, + 384 + ], + "stride": [ + 2359296, + 1179648, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 19, + "dtype": "f32", + "shape": [ + 32, + 2, + 1, + 384, + 64 + ], + "stride": [ + 49152, + 24576, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 20, + "dtype": "f32", + "shape": [ + 32, + 2, + 8, + 384, + 64 + ], + "stride": [ + 393216, + 196608, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 24, + "name": "reshape4", + "kind": "StaticReshape", + "attrs": { + "special_zero": { + "type": "bool", + "value": 0 + }, + "shape": { + "type": "s64[]", + "value": [ + 32, + 16, + 384, + 64 + ] + } + }, + "inputs": [ + { + "id": 20, + "dtype": "f32", + "shape": [ + 32, + 2, + 8, + 384, + 64 + ], + "stride": [ + 393216, + 196608, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 23, + "dtype": "f32", + "shape": [ + 32, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + } + ] +} diff --git a/tests/benchdnn/inputs/graph/pattern/harness_int8_all b/tests/benchdnn/inputs/graph/pattern/harness_int8_all index ef6bd90181c..d4fe03d5d0d 100644 --- a/tests/benchdnn/inputs/graph/pattern/harness_int8_all +++ b/tests/benchdnn/inputs/graph/pattern/harness_int8_all @@ -124,3 +124,4 
@@ --reset --op-attrs=3:zps:1 --in-shapes=5:512 --case=pattern/int8/int8_lnorm_tc_multiply_quantize.json #softmax --reset --case=pattern/int8/int8_softmax_add.json +--reset --op-attrs=3:zps:32 --case=pattern/int8/int8_softmax_add.json diff --git a/tests/benchdnn/inputs/reorder/test_reorder_int4 b/tests/benchdnn/inputs/reorder/test_reorder_int4 index d51340137ae..de773950647 100644 --- a/tests/benchdnn/inputs/reorder/test_reorder_int4 +++ b/tests/benchdnn/inputs/reorder/test_reorder_int4 @@ -6,6 +6,7 @@ --dtag=abx,bax 2x64x64x3x3 2x56x56x3x3 --dtag=abx,bax 4x16x16x3x3 2x16x6x3x2 2x2x10x2x3 --dtag=aBx16b 2x64x14x14 2x56x14x14 +--dtag=aCB16b32c 2x64x3 4x56x3 --dtag=gOIhw16i16o 2x64x64x3x3 2x56x56x3x3 --dtag=gOIhw8i16o2i 2x64x64x3x3 2x56x56x3x3 --dtag=gOIhw8o16i2o 2x64x64x3x3 2x56x56x3x3 @@ -21,6 +22,7 @@ --stag=bax,abx 2x64x64x3x3 2x56x56x3x3 --stag=bax,abx 4x16x16x3x3 2x16x6x3x2 2x2x10x2x3 --stag=aBx16b 2x64x14x14 2x56x14x14 +--stag=aCB16b16c 2x64x3 4x56x3 --stag=gOIhw16i16o 2x64x64x3x3 2x56x56x3x3 --stag=gOIhw8i16o2i 2x64x64x3x3 2x56x56x3x3 --stag=gOIhw8o16i2o 2x64x64x3x3 2x56x56x3x3 diff --git a/tests/benchdnn/matmul/matmul.cpp b/tests/benchdnn/matmul/matmul.cpp index 57ebb8d8e03..e47748474fb 100644 --- a/tests/benchdnn/matmul/matmul.cpp +++ b/tests/benchdnn/matmul/matmul.cpp @@ -741,9 +741,6 @@ void skip_invalid_prb(const prb_t *prb, res_t *res) { void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, const args_t &ref_args) { - const auto dt = prb->get_dt(kind); - const float trh = dt == dnnl_f32 ? 1e-6f : epsilon_dt(dt); - cmp.set_threshold(trh); cmp.set_zero_trust_percent(90.f); // TODO: why so bad filling? } diff --git a/tests/benchdnn/pool/ref_pool.cpp b/tests/benchdnn/pool/ref_pool.cpp index d63a3394b5a..4139753ea11 100644 --- a/tests/benchdnn/pool/ref_pool.cpp +++ b/tests/benchdnn/pool/ref_pool.cpp @@ -38,8 +38,6 @@ void compute_ref_fwd(const prb_t *prb, const args_t &args) { // XXX: this is a hack to let tests with padded area to pass for bf16 // dt due to the library initialize values with -max_dt, but not -INF. float max_value = lowest_dt(prb->dst_dt()); - if (is_nvidia_gpu() || is_amd_gpu()) - max_value = lowest_dt(prb->src_dt()); float avg_value = 0.; // Set initial value based on ws data type int ws_off = prb->kernel_size() <= UINT8_MAX ? 
UINT8_MAX : INT_MAX;
diff --git a/tests/benchdnn/utils/parser.cpp b/tests/benchdnn/utils/parser.cpp
index 97239220d50..3e56e207cef 100644
--- a/tests/benchdnn/utils/parser.cpp
+++ b/tests/benchdnn/utils/parser.cpp
@@ -369,8 +369,7 @@ attr_t::dropout_t parse_attr_dropout_func(const std::string &s) {
     }
 
     if (start_pos != std::string::npos) {
-        subs = get_substr(s, start_pos, '\0');
-        v.tag = subs;
+        v.tag = get_substr(s, start_pos, '\0');
         if (check_tag(v.tag) != OK) {
             BENCHDNN_PRINT(0, "%s \'%s\' %s\n", "Error: dropout mask tag",
diff --git a/tests/gtests/graph/unit/backend/dnnl/test_subgraph_pass.cpp b/tests/gtests/graph/unit/backend/dnnl/test_subgraph_pass.cpp
index ebe04c6c694..6df85928f02 100644
--- a/tests/gtests/graph/unit/backend/dnnl/test_subgraph_pass.cpp
+++ b/tests/gtests/graph/unit/backend/dnnl/test_subgraph_pass.cpp
@@ -870,7 +870,7 @@ TEST_P(int8_matmul_with_diff_inputs_t, Int8MatmulPasses) {
 
     auto subgraph = std::make_shared<dnnl_impl::subgraph_t>(
             agraph.get_partitions()[0]->get_ops(), p_eng, fpmath_mode::strict,
-            false, true);
+            true, true);
     ASSERT_EQ(subgraph->get_ops().size(), 5U);
 
     dnnl_impl::check_with_bias(subgraph);
@@ -992,7 +992,7 @@ TEST_P(matmul_with_diff_inputs_t, MatmulPasses) {
 
     auto subgraph = std::make_shared<dnnl_impl::subgraph_t>(
             agraph.get_partitions()[0]->get_ops(), p_eng, fpmath_mode::strict,
-            false, true);
+            true, true);
     ASSERT_EQ(subgraph->get_ops().size(), 2U);
 
     dnnl_impl::check_with_bias(subgraph);
@@ -2329,6 +2329,8 @@ TEST(test_subgraph_pass_subgraph_pass, FuseNCXConvolutionBinaryAddNC11PostSrc) {
             subgraph->get_ops().end(), [](const std::shared_ptr<op_t> &op) {
                 return op->get_kind() == dnnl_impl::op_kind::dnnl_convolution;
             });
+    ASSERT_NE(qconv_op, subgraph->get_ops().end());
+    ASSERT_TRUE((*qconv_op)->has_attr(dnnl_impl::op_attr::fusion_info_key));
     int64_t key = (*qconv_op)->get_attr<int64_t>(
             dnnl_impl::op_attr::fusion_info_key);
     auto &fusion_info = subgraph->fusion_info_mgr_.get_info(key);
@@ -2483,6 +2485,8 @@ TEST(test_subgraph_pass_subgraph_pass, FuseNXCConvolutionBinaryAddNC11PostSrc) {
             subgraph->get_ops().end(), [](const std::shared_ptr<op_t> &op) {
                 return op->get_kind() == dnnl_impl::op_kind::dnnl_convolution;
             });
+    ASSERT_NE(qconv_op, subgraph->get_ops().end());
+    ASSERT_TRUE((*qconv_op)->has_attr(dnnl_impl::op_attr::fusion_info_key));
     int64_t key = (*qconv_op)->get_attr<int64_t>(
             dnnl_impl::op_attr::fusion_info_key);
     auto &fusion_info = subgraph->fusion_info_mgr_.get_info(key);
diff --git a/tests/gtests/graph/unit/interface/test_backend.cpp b/tests/gtests/graph/unit/interface/test_backend.cpp
index 35d6006d54a..030d01f3ad1 100644
--- a/tests/gtests/graph/unit/interface/test_backend.cpp
+++ b/tests/gtests/graph/unit/interface/test_backend.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2022-2023 Intel Corporation
+* Copyright 2022-2024 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -48,5 +48,5 @@ TEST(test_interface_backend, CompareLogicalTensor) { TEST(test_interface_backend, RegisterBackend) { auto ®istry = graph::backend_registry_t::get_singleton(); auto bkds = registry.get_registered_backends(); - EXPECT_THROW(registry.register_backend(bkds[0]), std::runtime_error); + EXPECT_EQ(registry.register_backend(bkds[0]), graph::status::runtime_error); } diff --git a/tests/gtests/test_convolution_format_any.cpp b/tests/gtests/test_convolution_format_any.cpp index 8c5f9c90e1b..304de44f828 100644 --- a/tests/gtests/test_convolution_format_any.cpp +++ b/tests/gtests/test_convolution_format_any.cpp @@ -129,5 +129,7 @@ TEST_P(conv_any_fmt_test_float, TestsConvolutionAnyFmt) {} CPU_INSTANTIATE_TEST_SUITE_P(TestConvolutionAlexnetAnyFmtForward, conv_any_fmt_test_float, ::testing::Values(ALEXNET_SUITE(BLK))); +GPU_INSTANTIATE_TEST_SUITE_P(TestConvolutionAlexnetAnyFmtForward, + conv_any_fmt_test_float, ::testing::Values(ALEXNET_SUITE(BLK))); #endif } // namespace dnnl
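
Note on the backend registration change: register_backend() and the
register_*_backend() helpers now report a duplicate backend name through a
status_t return value instead of throwing std::runtime_error, which is what
the updated RegisterBackend test asserts. The sketch below is a minimal,
self-contained model of that contract, not oneDNN API; toy_backend_t,
toy_registry_t, and main() are illustrative stand-ins.

    #include <algorithm>
    #include <cassert>
    #include <mutex>
    #include <string>
    #include <vector>

    // Toy status codes mirroring the success/runtime_error pair used above.
    enum class status_t { success, runtime_error };

    struct toy_backend_t {
        std::string name;
        int priority;
    };

    class toy_registry_t {
    public:
        status_t register_backend(const toy_backend_t *abackend) {
            std::lock_guard<std::mutex> lock(m_);
            auto has_colliding_name = [&](const toy_backend_t *b) {
                return b->name == abackend->name;
            };
            // A colliding name is reported to the caller; the old contract
            // threw std::runtime_error at this point.
            if (std::find_if(backends_.begin(), backends_.end(),
                        has_colliding_name)
                    != backends_.end())
                return status_t::runtime_error;
            backends_.push_back(abackend);
            // Keep the list ordered by descending priority, as the real
            // registry does with its compare_priority helper.
            std::sort(backends_.begin(), backends_.end(),
                    [](const toy_backend_t *l, const toy_backend_t *r) {
                        return l->priority > r->priority;
                    });
            return status_t::success;
        }

    private:
        std::mutex m_;
        std::vector<const toy_backend_t *> backends_;
    };

    int main() {
        toy_registry_t registry;
        toy_backend_t dnnl_bkd {"dnnl", 8}, fake_bkd {"fake", 0};
        const status_t s1 = registry.register_backend(&dnnl_bkd);
        const status_t s2 = registry.register_backend(&fake_bkd);
        // Re-registering a colliding name yields a status code rather than
        // unwinding the stack.
        const status_t s3 = registry.register_backend(&dnnl_bkd);
        assert(s1 == status_t::success && s2 == status_t::success);
        assert(s3 == status_t::runtime_error);
        (void)s1; (void)s2; (void)s3;
        return 0;
    }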