Skip to content

Commit

Permalink
[CINN+PIR]Support Float/Double in ConvertAttribute (PaddlePaddle#60755)
Browse files Browse the repository at this point in the history
* [CINN+PIR]Support Float/Double in ConvertAttribute

* fix typo

* fix typo

* fix typo
  • Loading branch information
Aurelius84 authored Jan 15, 2024
1 parent 5017ba8 commit 1728013
Showing 1 changed file with 38 additions and 39 deletions.
77 changes: 38 additions & 39 deletions paddle/cinn/hlir/framework/pir/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,22 +123,17 @@ std::vector<::pir::Value> CompatibleInfo::RealOperandSources(
}
}

utils::Attribute CompatibleInfo::ConvertAttribute(
#define CASE_ATTRIBUTE(val_type, attr_type) \
std::vector<val_type> res; \
for (auto element : attr_vec) { \
res.push_back(element.dyn_cast<::pir::attr_type>().data()); \
} \
dst_attr = res;

static utils::Attribute ConvertArrayAttribute(
const ::pir::Attribute& src_attr) {
utils::Attribute dst_attr;
if (src_attr.isa<::pir::BoolAttribute>()) {
dst_attr = src_attr.dyn_cast<::pir::BoolAttribute>().data();
} else if (src_attr.isa<::pir::FloatAttribute>()) {
dst_attr = src_attr.dyn_cast<::pir::FloatAttribute>().data();
} else if (src_attr.isa<::pir::Int32Attribute>()) {
dst_attr = src_attr.dyn_cast<::pir::Int32Attribute>().data();
} else if (src_attr.isa<::pir::StrAttribute>()) {
dst_attr = src_attr.dyn_cast<::pir::StrAttribute>().AsString();
} else if (src_attr.isa<::pir::Int64Attribute>()) {
dst_attr = src_attr.dyn_cast<::pir::Int64Attribute>().data();
} else if (src_attr.isa<::pir::DoubleAttribute>()) {
dst_attr = src_attr.dyn_cast<::pir::DoubleAttribute>().data();
} else if (src_attr.isa<paddle::dialect::IntArrayAttribute>()) {
if (src_attr.isa<paddle::dialect::IntArrayAttribute>()) {
auto& arr = src_attr.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data()
.GetData();
Expand All @@ -147,46 +142,49 @@ utils::Attribute CompatibleInfo::ConvertAttribute(
} else if (src_attr.isa<paddle::dialect::DataTypeAttribute>()) {
auto dtype = src_attr.dyn_cast<paddle::dialect::DataTypeAttribute>().data();
dst_attr = phi::DataTypeToString(dtype);
} else if (src_attr.isa<::pir::shape::SymbolAttribute>()) {
auto dst_attr = src_attr.dyn_cast<::pir::shape::SymbolAttribute>().data();
} else if (src_attr.isa<::pir::ArrayAttribute>()) {
auto attr_vec = src_attr.dyn_cast<::pir::ArrayAttribute>().AsVector();
if (attr_vec.size() > 0) {
if (attr_vec[0].isa<::pir::Int32Attribute>()) {
std::vector<int> vec_int32;
for (auto vec_element : attr_vec) {
vec_int32.push_back(
vec_element.dyn_cast<::pir::Int32Attribute>().data());
}
dst_attr = vec_int32;

CASE_ATTRIBUTE(int32_t, Int32Attribute)
} else if (attr_vec[0].isa<::pir::Int64Attribute>()) {
std::vector<int64_t> vec_int64;
int index = 0;
for (auto vec_element : attr_vec) {
vec_int64.push_back(
vec_element.dyn_cast<::pir::Int64Attribute>().data());
}
dst_attr = vec_int64;
CASE_ATTRIBUTE(int64_t, Int64Attribute)
} else if (attr_vec[0].isa<::pir::BoolAttribute>()) {
std::vector<bool> vec_bool;
int index = 0;
for (auto vec_element : attr_vec) {
vec_bool.push_back(
vec_element.dyn_cast<::pir::BoolAttribute>().data());
}
dst_attr = vec_bool;
CASE_ATTRIBUTE(bool, BoolAttribute)
} else if (attr_vec[0].isa<::pir::FloatAttribute>()) {
CASE_ATTRIBUTE(float, FloatAttribute)
} else if (attr_vec[0].isa<::pir::DoubleAttribute>()) {
CASE_ATTRIBUTE(double, DoubleAttribute)
} else {
LOG(FATAL)
<< "only support bool/int32/int64 attribute in ArrayAttribute";
LOG(FATAL) << "only support bool/int32/int64/float/double attribute in "
"ArrayAttribute";
}
}
} else {
LOG(FATAL) << "unknown Attribute: " << src_attr;
}
return dst_attr;
}
#undef CASE_ATTRIBUTE

#define CASE_SINGLE_ATTR(attr_type, func) \
else if (src_attr.isa<::pir::attr_type>()) dst_attr = \
src_attr.dyn_cast<::pir::attr_type>().func();

utils::Attribute CompatibleInfo::ConvertAttribute(
const ::pir::Attribute& src_attr) {
utils::Attribute dst_attr;
if (src_attr.isa<::pir::BoolAttribute>())
dst_attr = src_attr.dyn_cast<::pir::BoolAttribute>().data();
CASE_SINGLE_ATTR(FloatAttribute, data)
CASE_SINGLE_ATTR(DoubleAttribute, data)
CASE_SINGLE_ATTR(Int32Attribute, data)
CASE_SINGLE_ATTR(Int64Attribute, data)
CASE_SINGLE_ATTR(StrAttribute, AsString)
else dst_attr = ConvertArrayAttribute(src_attr); // NOLINT
return dst_attr;
}
#undef CASE_SINGLE_ATTR

utils::AttributeMap CompatibleInfo::ConvertAttributes(
const ::pir::Operation& op) {
Expand Down Expand Up @@ -231,6 +229,7 @@ cinn::common::Type CompatibleInfo::ConvertIRType(::pir::Type type) {

LOG(FATAL) << "unknown ir::Type " << type;
}
#undef CASE_TYPE

int CompatibleInfo::ShapeProduct(const std::vector<int>& shape) {
return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
Expand Down

0 comments on commit 1728013

Please sign in to comment.