From 172801318252405154c6441933701f254c1991a1 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 15 Jan 2024 11:49:36 +0800 Subject: [PATCH] [CINN+PIR]Support Float/Double in ConvertAttribute (#60755) * [CINN+PIR]Support Float/Double in ConvertAttribute * fix typo * fix typo * fix typo --- paddle/cinn/hlir/framework/pir/utils.cc | 77 ++++++++++++------------- 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/paddle/cinn/hlir/framework/pir/utils.cc b/paddle/cinn/hlir/framework/pir/utils.cc index a0a6f5f15614b..bba7b3da0b3d4 100644 --- a/paddle/cinn/hlir/framework/pir/utils.cc +++ b/paddle/cinn/hlir/framework/pir/utils.cc @@ -123,22 +123,17 @@ std::vector<::pir::Value> CompatibleInfo::RealOperandSources( } } -utils::Attribute CompatibleInfo::ConvertAttribute( +#define CASE_ATTRIBUTE(val_type, attr_type) \ + std::vector res; \ + for (auto element : attr_vec) { \ + res.push_back(element.dyn_cast<::pir::attr_type>().data()); \ + } \ + dst_attr = res; + +static utils::Attribute ConvertArrayAttribute( const ::pir::Attribute& src_attr) { utils::Attribute dst_attr; - if (src_attr.isa<::pir::BoolAttribute>()) { - dst_attr = src_attr.dyn_cast<::pir::BoolAttribute>().data(); - } else if (src_attr.isa<::pir::FloatAttribute>()) { - dst_attr = src_attr.dyn_cast<::pir::FloatAttribute>().data(); - } else if (src_attr.isa<::pir::Int32Attribute>()) { - dst_attr = src_attr.dyn_cast<::pir::Int32Attribute>().data(); - } else if (src_attr.isa<::pir::StrAttribute>()) { - dst_attr = src_attr.dyn_cast<::pir::StrAttribute>().AsString(); - } else if (src_attr.isa<::pir::Int64Attribute>()) { - dst_attr = src_attr.dyn_cast<::pir::Int64Attribute>().data(); - } else if (src_attr.isa<::pir::DoubleAttribute>()) { - dst_attr = src_attr.dyn_cast<::pir::DoubleAttribute>().data(); - } else if (src_attr.isa()) { + if (src_attr.isa()) { auto& arr = src_attr.dyn_cast() .data() .GetData(); @@ -147,46 +142,49 @@ utils::Attribute CompatibleInfo::ConvertAttribute( } else if (src_attr.isa()) { auto dtype = src_attr.dyn_cast().data(); dst_attr = phi::DataTypeToString(dtype); - } else if (src_attr.isa<::pir::shape::SymbolAttribute>()) { - auto dst_attr = src_attr.dyn_cast<::pir::shape::SymbolAttribute>().data(); } else if (src_attr.isa<::pir::ArrayAttribute>()) { auto attr_vec = src_attr.dyn_cast<::pir::ArrayAttribute>().AsVector(); if (attr_vec.size() > 0) { if (attr_vec[0].isa<::pir::Int32Attribute>()) { - std::vector vec_int32; - for (auto vec_element : attr_vec) { - vec_int32.push_back( - vec_element.dyn_cast<::pir::Int32Attribute>().data()); - } - dst_attr = vec_int32; - + CASE_ATTRIBUTE(int32_t, Int32Attribute) } else if (attr_vec[0].isa<::pir::Int64Attribute>()) { - std::vector vec_int64; - int index = 0; - for (auto vec_element : attr_vec) { - vec_int64.push_back( - vec_element.dyn_cast<::pir::Int64Attribute>().data()); - } - dst_attr = vec_int64; + CASE_ATTRIBUTE(int64_t, Int64Attribute) } else if (attr_vec[0].isa<::pir::BoolAttribute>()) { - std::vector vec_bool; - int index = 0; - for (auto vec_element : attr_vec) { - vec_bool.push_back( - vec_element.dyn_cast<::pir::BoolAttribute>().data()); - } - dst_attr = vec_bool; + CASE_ATTRIBUTE(bool, BoolAttribute) + } else if (attr_vec[0].isa<::pir::FloatAttribute>()) { + CASE_ATTRIBUTE(float, FloatAttribute) + } else if (attr_vec[0].isa<::pir::DoubleAttribute>()) { + CASE_ATTRIBUTE(double, DoubleAttribute) } else { - LOG(FATAL) - << "only support bool/int32/int64 attribute in ArrayAttribute"; + LOG(FATAL) << "only support bool/int32/int64/float/double attribute in " + "ArrayAttribute"; } } } else { LOG(FATAL) << "unknown Attribute: " << src_attr; } + return dst_attr; +} +#undef CASE_ATTRIBUTE +#define CASE_SINGLE_ATTR(attr_type, func) \ + else if (src_attr.isa<::pir::attr_type>()) dst_attr = \ + src_attr.dyn_cast<::pir::attr_type>().func(); + +utils::Attribute CompatibleInfo::ConvertAttribute( + const ::pir::Attribute& src_attr) { + utils::Attribute dst_attr; + if (src_attr.isa<::pir::BoolAttribute>()) + dst_attr = src_attr.dyn_cast<::pir::BoolAttribute>().data(); + CASE_SINGLE_ATTR(FloatAttribute, data) + CASE_SINGLE_ATTR(DoubleAttribute, data) + CASE_SINGLE_ATTR(Int32Attribute, data) + CASE_SINGLE_ATTR(Int64Attribute, data) + CASE_SINGLE_ATTR(StrAttribute, AsString) + else dst_attr = ConvertArrayAttribute(src_attr); // NOLINT return dst_attr; } +#undef CASE_SINGLE_ATTR utils::AttributeMap CompatibleInfo::ConvertAttributes( const ::pir::Operation& op) { @@ -231,6 +229,7 @@ cinn::common::Type CompatibleInfo::ConvertIRType(::pir::Type type) { LOG(FATAL) << "unknown ir::Type " << type; } +#undef CASE_TYPE int CompatibleInfo::ShapeProduct(const std::vector& shape) { return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies());