Skip to content

Commit

Permalink
Use general reorder impl
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Oct 18, 2024
1 parent 15053f7 commit a386ac2
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

#include "emitters/snippets/cpu_runtime_configurator.hpp"

#include "memory_desc/cpu_blocked_memory_desc.h"
#include "memory_desc/cpu_memory_desc_utils.h"
#include "memory_desc/dnnl_blocked_memory_desc.h"
#include "snippets/lowered/loop_manager.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "snippets/utils/utils.hpp"

#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
#include "transformations/snippets/x64/pass/lowered/adjust_brgemm_copy_b_loop_ports.hpp"
namespace ov {
namespace intel_cpu {
Expand Down Expand Up @@ -137,25 +139,35 @@ void CPURuntimeConfigurator::update_requested_descs(const ov::snippets::lowered:
const auto& params = linear_ir->get_parameters();
OPENVINO_ASSERT(params.size() == m_in_num);
for (size_t i = 0; i < m_in_num; ++i) {
// TODO: remove
if (i != 1) continue;
const auto& param = params[i];
auto consumers = param->get_output_port_connector(0)->get_consumers();
const bool has_brgemm_consumers =
const bool brgemm_with_extracted_repacking =
std::any_of(consumers.begin(), consumers.end(), [](const ov::snippets::lowered::ExpressionPort& port) {
return ov::is_type<ov::intel_cpu::BrgemmCPU>(port.get_expr()->get_node());
auto brgemm = ov::as_type_ptr<ov::intel_cpu::BrgemmCPU>(port.get_expr()->get_node());
return brgemm && brgemm_utils::with_repacking(brgemm->get_type());
});
if (has_brgemm_consumers) {
const auto& shape = param->get_output_port_descriptor(0)->get_shape();
VectorDims normalized_dims(3, 1);
*normalized_dims.rbegin() = *shape.rbegin();
*++normalized_dims.rbegin() = *++shape.rbegin();
normalized_dims[0] = std::accumulate(shape.begin(), shape.end() - 2, static_cast<Dim>(1), std::multiplies<Dim>());

const auto data_type = DnnlExtensionUtils::ElementTypeToDataType(param->get_node()->get_output_element_type(0));
// TODO: tag must be selected based on Brgemm params (inner block + vnni factor?)
const auto tag = dnnl::memory::format_tag::aCB16b64c2b;
optimal_descs[i] = std::make_shared<DnnlBlockedMemoryDesc>(Shape(normalized_dims), data_type, tag);
if (brgemm_with_extracted_repacking) {
const auto& desc = param->get_output_port_descriptor(0);
const auto& shape = desc->get_shape();
const auto& K = *++shape.rbegin();
const auto& N = *shape.rbegin();

const auto& precision = param->get_node()->get_output_element_type(0);
const auto vnni_factor = brgemm_utils::compute_vnni_factor(precision);
const auto n_block = brgemm_utils::repacking::compute_inner_n_block(precision);
// Firstly, batch dims are set
VectorDims requested_blocked_shape(shape.begin(), shape.end() - m_config->tile_rank);
// Then, the blocked dims are formed
requested_blocked_shape.insert(
requested_blocked_shape.end(),
{snippets::utils::div_up(K, vnni_factor), snippets::utils::div_up(N, n_block), n_block, vnni_factor});
// Please note: only planar layout is supported for now
const VectorDims order{0, 1, 2, 3, 3, 2};
auto cpu_desc = std::make_shared<ov::intel_cpu::CpuBlockedMemoryDesc>(precision,
Shape(shape),
requested_blocked_shape,
order);
optimal_descs[i] = MemoryDescUtils::convertToDnnlMemoryDesc(cpu_desc);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,13 @@ bool pass::AdjustBrgemmCopyBLoopPorts::run(const snippets::lowered::LinearIR& li
};

for (const auto& expr : linear_ir) {
const auto& node = expr->get_node();
// TODO: except current CopyB->Buffer->BrgemmCPU sequence,
// Parameter->BrgemmCPU(with CopyB outside the Subgraph) sequence must be handled.
if (!is_type<BrgemmCopyB>(node))
const auto brgemm = ov::as_type_ptr<BrgemmCPU>(expr->get_node());
if (!brgemm || brgemm_utils::stand_alone(brgemm->get_type()))
continue;
const auto& loop_ids = expr->get_loop_ids();
const auto& child_ports = expr->get_output_port(0).get_connected_ports();
// Note: this pass should be executed before Loop insertion, so there is no LooEnd fake dependency
OPENVINO_ASSERT(child_ports.size() == 1 &&
is_type<snippets::lowered::BufferExpression>(child_ports.begin()->get_expr()),
"BrgemmCopyB should have one BufferExpression child");
auto grandchild_ports = child_ports.begin()->get_expr()->get_output_port(0).get_connected_ports();
iterate_through_ports(loop_ids, grandchild_ports);
const auto& input_connector = expr->get_input_port_connector(1);
auto parent_out_ports = input_connector->get_consumers();
const auto& parent_loop_ids = input_connector->get_source().get_expr()->get_loop_ids();
iterate_through_ports(parent_loop_ids, parent_out_ports);
}

return modified;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ size_t K = std::getenv("K") ? std::atoi(std::getenv("K")) : 2;
size_t N = std::getenv("N") ? std::atoi(std::getenv("N")) : 32;
size_t B1 = std::getenv("B1") ? std::atoi(std::getenv("B1")) : 1;
size_t B2 = std::getenv("B2") ? std::atoi(std::getenv("B2")) : 1;
size_t B3 = std::getenv("B3") ? std::atoi(std::getenv("B3")) : 1;

std::vector<std::vector<ov::test::InputShape>> input_shapes{
{ {{}, {{B1, 1, 1, K}}}, {{}, {{B2, 5, K, N}}} },
{ {{}, {{B1, 1, 1, K}}}, {{}, {{B2, B3, K, N}}} },
// { {{}, {{2, 1, 3, 5}}}, {{}, {{1, 3, 5, 3}}} },
// { {{}, {{3, 1, 32, 14}}}, {{}, {{1, 3, 14, 37}}} },
// { {{}, {{1, 2, 37, 23}}}, {{}, {{2, 1, 23, 37}}} },
Expand Down

0 comments on commit a386ac2

Please sign in to comment.