5 changes: 5 additions & 0 deletions paddle/fluid/framework/infershape_utils.cc
@@ -795,6 +795,11 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(float, attr));
break;
case phi::AttributeType::FLOAT64:
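        // If the attribute is still stored as FLOAT (e.g. in a legacy program), widen it to double.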
if (AttrTypeID(attr) == framework::proto::AttrType::FLOAT) {
const auto val = PADDLE_GET_CONST(float, attr);
infer_meta_context.EmplaceBackAttr(static_cast<double>(val));
break;
}
infer_meta_context.EmplaceBackAttr(
PADDLE_GET_CONST(double, attr));
break;
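A minimal, self-contained sketch of the fallback this hunk (and the matching one in operator.cc below) introduces: when the stored attribute is still a float, it is widened to the double that the phi definition now expects. The names Attr and AsFloat64 are illustrative only, not Paddle APIs.

```cpp
#include <iostream>
#include <variant>

// Stand-in for the framework attribute variant (illustrative).
using Attr = std::variant<float, double>;

double AsFloat64(const Attr& attr) {
  // Programs saved before this change may still carry the value as float.
  if (std::holds_alternative<float>(attr)) {
    return static_cast<double>(std::get<float>(attr));
  }
  return std::get<double>(attr);
}

int main() {
  std::cout << AsFloat64(Attr{20.0f}) << " " << AsFloat64(Attr{20.0}) << "\n";
  return 0;
}
```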
@@ -53,6 +53,8 @@ static phi::Attribute ConvertPirAttribute2RuntimeAttribute(
return attr.dyn_cast<pir::Int32Attribute>().data();
} else if (attr_type_name == "pir::FloatAttribute") {
return attr.dyn_cast<pir::FloatAttribute>().data();
} else if (attr_type_name == "pir::DoubleAttribute") {
return attr.dyn_cast<pir::DoubleAttribute>().data();
} else if (attr_type_name == "pir::BoolAttribute") {
return attr.dyn_cast<pir::BoolAttribute>().data();
} else if (attr_type_name == "pir::StrAttribute") {
6 changes: 6 additions & 0 deletions paddle/fluid/framework/operator.cc
@@ -3514,6 +3514,12 @@ void OperatorWithKernel::BuildPhiKernelContext(
PADDLE_GET_CONST(float, attr_iter->second));
break;
case phi::AttributeType::FLOAT64:
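          // Same FLOAT -> double widening as in infershape_utils.cc, for attributes saved as float.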
if (AttrTypeID(attr_iter->second) ==
framework::proto::AttrType::FLOAT) {
const auto val = PADDLE_GET_CONST(float, attr_iter->second);
phi_kernel_context->EmplaceBackAttr(static_cast<double>(val));
break;
}
phi_kernel_context->EmplaceBackAttr(
PADDLE_GET_CONST(double, attr_iter->second));
break;
39 changes: 39 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -3921,6 +3921,43 @@ struct SyncCommStreamOpTranscriber : public OpTranscriber {
}
};

struct SoftPlusOpTranscriber : public OpTranscriber {
pir::AttributeMap TranslateOpAttribute(
pir::IrContext* ctx,
const std::string& normalized_op_name,
const OpAttributeInfoList& op_attr_infos,
const OpDesc& op_desc) override {
auto& attribute_translator = AttributeTranslator::instance();
auto& op_normalizer = OpNameNormalizer::instance();
pir::AttributeMap attribute_map = {};

for (const auto& info : op_attr_infos) {
auto legacy_attr_name =
op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name);
VLOG(10) << "[op: " << op_desc.Type()
<< "][attr] from: " << legacy_attr_name << " to: " << info.name;
if (op_desc.HasAttr(legacy_attr_name)) {
paddle::framework::Attribute legacy_attr =
op_desc.GetAttr(legacy_attr_name);
VLOG(10) << "attribute in " << op_desc.Type()
<< " name: " << legacy_attr_name << " " << legacy_attr.index();
pir::Attribute new_attr =
attribute_translator(info.type_name, legacy_attr);
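        // softplus's beta/threshold are defined as double in PIR; widen the translated float value.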
if (legacy_attr_name == "beta" || legacy_attr_name == "threshold") {
new_attr = pir::DoubleAttribute::get(
ctx,
static_cast<double>(
new_attr.dyn_cast<pir::FloatAttribute>().data()));
}
attribute_map[info.name] = new_attr;
} else {
this->HandleNonexistentAttribute(ctx, &attribute_map, info);
}
}
return attribute_map;
}
};

OpTranslator::OpTranslator() {
pir::IrContext* ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
@@ -4033,5 +4070,7 @@ OpTranslator::OpTranslator() {
WithXShapeAndAxisGradOpTranscriber<dialect::UnsqueezeGradOp>();

special_handlers["c_sync_comm_stream"] = SyncCommStreamOpTranscriber();
special_handlers["softplus"] = SoftPlusOpTranscriber();
special_handlers["softplus_grad"] = SoftPlusOpTranscriber();
}
} // namespace paddle::translator
2 changes: 2 additions & 0 deletions paddle/fluid/pir/drr/include/drr_pattern_context.h
@@ -297,6 +297,8 @@ class TEST_API ResultPattern {

Attribute Float32Attr(float value) const;

Attribute DoubleAttr(double value) const;

Attribute VectorInt64Attr(const std::vector<int64_t>& value) const;

Attribute VectorInt32Attr(const std::vector<int32_t>& value) const;
5 changes: 5 additions & 0 deletions paddle/fluid/pir/drr/src/pattern_context.cc
@@ -205,6 +205,11 @@ Attribute ResultPattern::Float32Attr(float value) const {
[=](const MatchContext& match_ctx) -> float { return value; });
}

Attribute ResultPattern::DoubleAttr(double value) const {
return ComputeAttr(
[=](const MatchContext& match_ctx) -> double { return value; });
}

Attribute ResultPattern::VectorInt64Attr(
const std::vector<int64_t>& value) const {
return ComputeAttr(
29 changes: 29 additions & 0 deletions paddle/fluid/pir/serialize_deserialize/0.yaml
@@ -0,0 +1,29 @@
op_patches:
- op_name : pd_op.softplus
actions:
- action : modify_attr
object : beta
type : pir::DoubleAttribute
data : 1.0
- action : modify_attr
object : threshold
type : pir::DoubleAttribute
data : 20.0
Comment on lines +2 to +11

Contributor:
Doesn't this mismatch the static-graph configuration here?

Contributor Author:
This change is for save/load version compatibility under PIR; for details, see https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/pir/serialize_deserialize/patch/Readme.md

- op_name : onednn_op.fused_softplus
actions:
- action : modify_attr
object : beta
type : pir::DoubleAttribute
data : 1.0
- action : modify_attr
object : threshold
type : pir::DoubleAttribute
data : 20.0
- action : modify_attr
object : fuse_alpha
type : pir::DoubleAttribute
data : 0.0
- action : modify_attr
object : fuse_beta
type : pir::DoubleAttribute
data : 0.0
2 changes: 1 addition & 1 deletion paddle/fluid/pir/serialize_deserialize/CMakeLists.txt
@@ -13,7 +13,7 @@ endif()

file(GLOB_RECURSE YAML_PATCH_FILES "*.yaml")
# change pir version when new patches are added
add_definitions(-DDEVELOP_VERSION=3)
add_definitions(-DDEVELOP_VERSION=0)
add_definitions(-DRELEASE_VERSION=3)
set(TEMPLATE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/patch/template.h.in)
set(PATCH_HEADER ${CMAKE_CURRENT_BINARY_DIR}/patch/patch.h)
38 changes: 25 additions & 13 deletions paddle/fluid/pir/transforms/onednn/softplus_activation_fuse_pass.cc
@@ -120,24 +120,36 @@ class SoftplusActivationFusePattern : public paddle::drr::DrrPatternBase {
{"beta", pat.Attr("beta")}, {"threshold", pat.Attr("threshold")}};

if (act_type_ == paddle::dialect::HardswishOp::name()) {
fused_attrs.emplace("fuse_alpha", res.Float32Attr(1.0f / 6.0f));
fused_attrs.emplace("fuse_beta", res.Float32Attr(1.0f / 2.0f));
fused_attrs.emplace("fuse_alpha", res.DoubleAttr(1.0 / 6.0));
fused_attrs.emplace("fuse_beta", res.DoubleAttr(1.0 / 2.0));
} else if (act_type_ == paddle::dialect::HardsigmoidOp::name()) {
fused_attrs.emplace("fuse_alpha", pat.Attr("fuse_alpha"));
fused_attrs.emplace("fuse_beta", pat.Attr("fuse_beta"));
const auto &fuse_alpha = res.ComputeAttr(
[](const paddle::drr::MatchContext &match_ctx) -> double {
return static_cast<double>(match_ctx.Attr<float>("fuse_alpha"));
});
const auto &fuse_beta = res.ComputeAttr(
[](const paddle::drr::MatchContext &match_ctx) -> double {
return static_cast<double>(match_ctx.Attr<float>("fuse_beta"));
});
fused_attrs.emplace("fuse_alpha", fuse_alpha);
fused_attrs.emplace("fuse_beta", fuse_beta);
} else if (act_type_ == paddle::dialect::LeakyRelu_Op::name() ||
act_type_ == paddle::dialect::LeakyReluOp::name()) {
fused_attrs.emplace("fuse_alpha", pat.Attr("fuse_alpha"));
const auto &fuse_alpha = res.ComputeAttr(
[](const paddle::drr::MatchContext &match_ctx) -> double {
return static_cast<double>(match_ctx.Attr<float>("fuse_alpha"));
});
fused_attrs.emplace("fuse_alpha", fuse_alpha);
} else if (act_type_ == paddle::dialect::SwishOp::name()) {
fused_attrs.emplace("fuse_alpha", res.Float32Attr(1.0f));
fused_attrs.emplace("fuse_alpha", res.DoubleAttr(1.0));
} else if (act_type_ == paddle::dialect::Relu6Op::name()) {
fused_attrs.emplace("fuse_beta", res.Float32Attr(6.0f));
fused_attrs.emplace("fuse_beta", res.DoubleAttr(6.0));
}

fused_attrs.insert(std::make_pair("fuse_activation",
res.StrAttr(activation_type[act_type_])));
fused_attrs.insert(std::make_pair("fuse_alpha", res.Float32Attr(0.0f)));
fused_attrs.insert(std::make_pair("fuse_beta", res.Float32Attr(0.0f)));
fused_attrs.insert(std::make_pair("fuse_alpha", res.DoubleAttr(0.0)));
fused_attrs.insert(std::make_pair("fuse_beta", res.DoubleAttr(0.0)));

const auto &fused_softplus = res.Op(fused_softplus_name_, fused_attrs);

@@ -188,8 +200,8 @@ class SoftplusGeluTanhFusePattern : public paddle::drr::DrrPatternBase {
{"beta", pat.Attr("beta")},
{"threshold", pat.Attr("threshold")},
{"fuse_activation", res.StrAttr("gelu_tanh")},
{"fuse_alpha", res.Float32Attr(0.0f)},
{"fuse_beta", res.Float32Attr(0.0f)}};
{"fuse_alpha", res.DoubleAttr(0.0)},
{"fuse_beta", res.DoubleAttr(0.0)}};

const auto &fused_softplus = res.Op(fused_softplus_name_, fused_attrs);

@@ -244,11 +256,11 @@ class SoftplusClipFusePattern : public paddle::drr::DrrPatternBase {
paddle::drr::ResultPattern res = pat.ResultPattern();

const auto &fuse_alpha = res.ComputeAttr(
[](const paddle::drr::MatchContext &match_ctx) -> float {
[](const paddle::drr::MatchContext &match_ctx) -> double {
return match_ctx.Attr<double>("value1");
});
const auto &fuse_beta = res.ComputeAttr(
[](const paddle::drr::MatchContext &match_ctx) -> float {
[](const paddle::drr::MatchContext &match_ctx) -> double {
return match_ctx.Attr<double>("value2");
});

6 changes: 6 additions & 0 deletions paddle/fluid/pybind/pir.cc
@@ -3286,6 +3286,12 @@ void BindDrrPatternContext(pybind11::module *m) {
return self.Float32Attr(value);
},
pybind11::arg("value"))
.def(
"DoubleAttr",
[](drr::ResultPattern &self, double value) {
return self.DoubleAttr(value);
},
pybind11::arg("value"))
.def(
"VectorInt32Attr",
[](drr::ResultPattern &self, const std::vector<int32_t> &value) {
8 changes: 4 additions & 4 deletions paddle/phi/infermeta/spmd_rules/elementwise.cc
@@ -708,15 +708,15 @@ SpmdInfo StanhGradInfoSpmd(const DistMetaTensor& x,

// softplus
SpmdInfo SoftplusInfoSpmd(const DistMetaTensor& x,
const float beta,
const float threshold) {
const double beta,
const double threshold) {
return ElementwiseUnaryInferSpmd(x);
}

SpmdInfo SoftplusGradInfoSpmd(const DistMetaTensor& x,
const DistMetaTensor& out_grad,
const float beta,
const float threshold) {
const double beta,
const double threshold) {
return ElementwiseUnaryGradInferSpmd(x, out_grad);
}

8 changes: 4 additions & 4 deletions paddle/phi/infermeta/spmd_rules/elementwise.h
@@ -104,12 +104,12 @@ SpmdInfo StanhGradInfoSpmd(const DistMetaTensor& x,
const float scale_b);

SpmdInfo SoftplusInfoSpmd(const DistMetaTensor& x,
const float beta,
const float threshold);
const double beta,
const double threshold);
SpmdInfo SoftplusGradInfoSpmd(const DistMetaTensor& x,
const DistMetaTensor& out_grad,
const float beta,
const float threshold);
const double beta,
const double threshold);

SpmdInfo SoftshrinkInfoSpmd(const DistMetaTensor& x, const float threshold);
SpmdInfo SoftshrinkGradInfoSpmd(const DistMetaTensor& x,
16 changes: 12 additions & 4 deletions paddle/phi/kernels/activation_grad_kernel.h
@@ -45,6 +45,15 @@ namespace phi {
float attr2, \
DenseTensor* dx);

#define DECLARE_ACT_GRAD_KERNEL_WITH_TWO_DOUBLE_ATTRS_DEPX(name, attr1, attr2) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& dout, \
double attr1, \
double attr2, \
DenseTensor* dx);

#define DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(name) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
@@ -266,11 +275,10 @@ void SoftplusDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
float beta,
float threshold,
double beta,
double threshold,
DenseTensor* dx,
DenseTensor* ddout);

DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Cos);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Tan);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Acos);
@@ -317,7 +325,7 @@ DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(LogitCUDA, eps);

DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(HardTanh, t_min, t_max);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(STanh, scale_a, scale_b);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus, beta, threshold);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_DOUBLE_ATTRS_DEPX(Softplus, beta, threshold);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT(HardSigmoid, slope, offset);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(ThresholdedRelu, threshold, value);

10 changes: 9 additions & 1 deletion paddle/phi/kernels/activation_kernel.h
@@ -40,6 +40,14 @@ namespace phi {
float attr2, \
DenseTensor* out);

#define DECLARE_ACTIVATION_KERNEL_WITH_TWO_DOUBLE_ATTRS(name, attr1, attr2) \
template <typename T, typename Context> \
void name##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
double attr1, \
double attr2, \
DenseTensor* out);

DECLARE_ACTIVATION_KERNEL(Sin)
DECLARE_ACTIVATION_KERNEL(Cos)
DECLARE_ACTIVATION_KERNEL(Tan)
@@ -83,7 +91,7 @@ DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Logit, eps)

DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardTanh, t_min, t_max)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(STanh, scale_a, scale_b)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(Softplus, beta, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_DOUBLE_ATTRS(Softplus, beta, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardSigmoid, slope, offset)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(ThresholdedRelu, threshold, value)

26 changes: 21 additions & 5 deletions paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -64,6 +64,23 @@ namespace phi {
dev_ctx, &x, nullptr, &dout, dx, functor); \
}

#define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_DOUBLE_ATTRS_DEPX( \
name, functor_class, attr1, attr2) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& dout, \
double attr1, \
double attr2, \
DenseTensor* dx) { \
funcs::functor_class<T> functor; \
auto attrs = functor.GetAttrs(); \
*(attrs[0].second) = attr1; \
*(attrs[1].second) = attr2; \
ActivationGradImpl<T, Context, funcs::functor_class<T>>( \
dev_ctx, &x, nullptr, &dout, dx, functor); \
}

#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(name, functor_class) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
@@ -178,11 +195,10 @@ DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(STanh,
STanhGradFunctor,
scale_a,
scale_b);

DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus,
SoftplusGradFunctor,
beta,
threshold);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_DOUBLE_ATTRS_DEPX(Softplus,
SoftplusGradFunctor,
beta,
threshold);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT(HardSigmoid,
HardSigmoidGradFunctor,
slope,
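For reference, the Softplus instantiation of the new CPU macro, DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_DOUBLE_ATTRS_DEPX(Softplus, SoftplusGradFunctor, beta, threshold), should expand to roughly the following (derived from the macro definition above; it assumes the functor's attribute pointers accept assignment from double):

```cpp
template <typename T, typename Context>
void SoftplusGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& dout,
                        double beta,
                        double threshold,
                        DenseTensor* dx) {
  funcs::SoftplusGradFunctor<T> functor;
  auto attrs = functor.GetAttrs();
  *(attrs[0].second) = beta;       // first attribute slot: beta
  *(attrs[1].second) = threshold;  // second attribute slot: threshold
  ActivationGradImpl<T, Context, funcs::SoftplusGradFunctor<T>>(
      dev_ctx, &x, nullptr, &dout, dx, functor);
}
```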