
Commit deed9d3

[Precision Depth Alignment] fix beta and threshold of paddle.nn.functional.softplus to double (PaddlePaddle#75426)
* fix beta and threshold of Softplus to double
* fix test_softplus_activation_fuse_pass v1
* fix test_activation_zero
* fix float of SoftplusDoubleGradKernel to double
* add op_patches for softplus
* add yaml for ops/yaml/legacy
* fix infershape/operator for FLOAT64
* fix
* add SoftPlusOpTranscriber
* fix
* fix
* fix1
* fix2
* fix coverage
* fix coverage2
1 parent 0ee9730 · commit deed9d3

33 files changed: +448 −92 lines
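For reference, the attributes this commit widens feed the documented softplus definition, softplus(x) = (1/beta) * log(1 + exp(beta * x)), with a linear fallback to x once beta * x exceeds threshold. The sketch below is illustrative only, not Paddle kernel code; it just shows that computation with beta and threshold carried as double, the type the commit aligns the attributes to.

// Illustrative sketch only; not Paddle's kernel. It evaluates the softplus
// formula with `beta` and `threshold` held as double, matching the new
// attribute type introduced by this commit.
#include <cmath>
#include <iostream>

double Softplus(double x, double beta = 1.0, double threshold = 20.0) {
  const double scaled = beta * x;
  if (scaled > threshold) {
    return x;  // linear fallback for large inputs, avoids overflow in exp()
  }
  return std::log1p(std::exp(scaled)) / beta;
}

int main() {
  std::cout << Softplus(0.5) << "\n";             // default beta/threshold
  std::cout << Softplus(0.5, 2.0, 15.0) << "\n";  // explicit double attributes
  return 0;
}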

paddle/fluid/framework/infershape_utils.cc

Lines changed: 5 additions & 0 deletions
@@ -795,6 +795,11 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
           infer_meta_context.EmplaceBackAttr(PADDLE_GET_CONST(float, attr));
           break;
         case phi::AttributeType::FLOAT64:
+          if (AttrTypeID(attr) == framework::proto::AttrType::FLOAT) {
+            const auto val = PADDLE_GET_CONST(float, attr);
+            infer_meta_context.EmplaceBackAttr(static_cast<double>(val));
+            break;
+          }
           infer_meta_context.EmplaceBackAttr(
               PADDLE_GET_CONST(double, attr));
           break;
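The hunk above (and the analogous one in paddle/fluid/framework/operator.cc below) handles programs that still carry beta/threshold as 32-bit float attributes: when the op definition now declares FLOAT64, the stored float is widened to double before being placed in the context. Below is a minimal standalone sketch of that promotion idea, using std::variant and a hypothetical GetAsDouble helper in place of Paddle's attribute machinery; it is not Paddle's API.

// Standalone illustration of the float -> double attribute promotion;
// std::variant and GetAsDouble stand in for Paddle's attribute types.
#include <iostream>
#include <variant>

using Attr = std::variant<float, double>;

double GetAsDouble(const Attr& attr) {
  if (std::holds_alternative<float>(attr)) {
    // Legacy programs stored the value as float; widen it to double.
    return static_cast<double>(std::get<float>(attr));
  }
  return std::get<double>(attr);  // already stored as double
}

int main() {
  Attr legacy_beta = 1.0f;    // attribute saved by an older program
  Attr new_threshold = 20.0;  // attribute saved after this change
  std::cout << GetAsDouble(legacy_beta) << " "
            << GetAsDouble(new_threshold) << "\n";
  return 0;
}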

paddle/fluid/framework/new_executor/instruction/onednn/onednn_instruction.cc

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,8 @@ static phi::Attribute ConvertPirAttribute2RuntimeAttribute(
     return attr.dyn_cast<pir::Int32Attribute>().data();
   } else if (attr_type_name == "pir::FloatAttribute") {
     return attr.dyn_cast<pir::FloatAttribute>().data();
+  } else if (attr_type_name == "pir::DoubleAttribute") {
+    return attr.dyn_cast<pir::DoubleAttribute>().data();
   } else if (attr_type_name == "pir::BoolAttribute") {
     return attr.dyn_cast<pir::BoolAttribute>().data();
   } else if (attr_type_name == "pir::StrAttribute") {

paddle/fluid/framework/operator.cc

Lines changed: 6 additions & 0 deletions
@@ -3514,6 +3514,12 @@ void OperatorWithKernel::BuildPhiKernelContext(
               PADDLE_GET_CONST(float, attr_iter->second));
           break;
         case phi::AttributeType::FLOAT64:
+          if (AttrTypeID(attr_iter->second) ==
+              framework::proto::AttrType::FLOAT) {
+            const auto val = PADDLE_GET_CONST(float, attr_iter->second);
+            phi_kernel_context->EmplaceBackAttr(static_cast<double>(val));
+            break;
+          }
           phi_kernel_context->EmplaceBackAttr(
               PADDLE_GET_CONST(double, attr_iter->second));
           break;

paddle/fluid/ir_adaptor/translator/op_translator.cc

Lines changed: 39 additions & 0 deletions
@@ -3921,6 +3921,43 @@ struct SyncCommStreamOpTranscriber : public OpTranscriber {
   }
 };
 
+struct SoftPlusOpTranscriber : public OpTranscriber {
+  pir::AttributeMap TranslateOpAttribute(
+      pir::IrContext* ctx,
+      const std::string& normalized_op_name,
+      const OpAttributeInfoList& op_attr_infos,
+      const OpDesc& op_desc) override {
+    auto& attribute_translator = AttributeTranslator::instance();
+    auto& op_normalizer = OpNameNormalizer::instance();
+    pir::AttributeMap attribute_map = {};
+
+    for (const auto& info : op_attr_infos) {
+      auto legacy_attr_name =
+          op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name);
+      VLOG(10) << "[op: " << op_desc.Type()
+               << "][attr] from: " << legacy_attr_name << " to: " << info.name;
+      if (op_desc.HasAttr(legacy_attr_name)) {
+        paddle::framework::Attribute legacy_attr =
+            op_desc.GetAttr(legacy_attr_name);
+        VLOG(10) << "attribute in " << op_desc.Type()
+                 << " name: " << legacy_attr_name << " " << legacy_attr.index();
+        pir::Attribute new_attr =
+            attribute_translator(info.type_name, legacy_attr);
+        if (legacy_attr_name == "beta" || legacy_attr_name == "threshold") {
+          new_attr = pir::DoubleAttribute::get(
+              ctx,
+              static_cast<double>(
+                  new_attr.dyn_cast<pir::FloatAttribute>().data()));
+        }
+        attribute_map[info.name] = new_attr;
+      } else {
+        this->HandleNonexistentAttribute(ctx, &attribute_map, info);
+      }
+    }
+    return attribute_map;
+  }
+};
+
 OpTranslator::OpTranslator() {
   pir::IrContext* ctx = pir::IrContext::Instance();
   ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
@@ -4033,5 +4070,7 @@ OpTranslator::OpTranslator() {
       WithXShapeAndAxisGradOpTranscriber<dialect::UnsqueezeGradOp>();
 
   special_handlers["c_sync_comm_stream"] = SyncCommStreamOpTranscriber();
+  special_handlers["softplus"] = SoftPlusOpTranscriber();
+  special_handlers["softplus_grad"] = SoftPlusOpTranscriber();
 }
 }  // namespace paddle::translator

paddle/fluid/pir/drr/include/drr_pattern_context.h

Lines changed: 2 additions & 0 deletions
@@ -297,6 +297,8 @@ class TEST_API ResultPattern {
 
   Attribute Float32Attr(float value) const;
 
+  Attribute DoubleAttr(double value) const;
+
   Attribute VectorInt64Attr(const std::vector<int64_t>& value) const;
 
   Attribute VectorInt32Attr(const std::vector<int32_t>& value) const;

paddle/fluid/pir/drr/src/pattern_context.cc

Lines changed: 5 additions & 0 deletions
@@ -205,6 +205,11 @@ Attribute ResultPattern::Float32Attr(float value) const {
       [=](const MatchContext& match_ctx) -> float { return value; });
 }
 
+Attribute ResultPattern::DoubleAttr(double value) const {
+  return ComputeAttr(
+      [=](const MatchContext& match_ctx) -> double { return value; });
+}
+
 Attribute ResultPattern::VectorInt64Attr(
     const std::vector<int64_t>& value) const {
   return ComputeAttr(
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+op_patches:
+  - op_name : pd_op.softplus
+    actions:
+      - action : modify_attr
+        object : beta
+        type : pir::DoubleAttribute
+        data : 1.0
+      - action : modify_attr
+        object : threshold
+        type : pir::DoubleAttribute
+        data : 20.0
+  - op_name : onednn_op.fused_softplus
+    actions:
+      - action : modify_attr
+        object : beta
+        type : pir::DoubleAttribute
+        data : 1.0
+      - action : modify_attr
+        object : threshold
+        type : pir::DoubleAttribute
+        data : 20.0
+      - action : modify_attr
+        object : fuse_alpha
+        type : pir::DoubleAttribute
+        data : 0.0
+      - action : modify_attr
+        object : fuse_beta
+        type : pir::DoubleAttribute
+        data : 0.0

paddle/fluid/pir/serialize_deserialize/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ endif()
 
 file(GLOB_RECURSE YAML_PATCH_FILES "*.yaml")
 # change pir version when new patches are added
-add_definitions(-DDEVELOP_VERSION=3)
+add_definitions(-DDEVELOP_VERSION=0)
 add_definitions(-DRELEASE_VERSION=3)
 set(TEMPLATE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/patch/template.h.in)
 set(PATCH_HEADER ${CMAKE_CURRENT_BINARY_DIR}/patch/patch.h)

paddle/fluid/pir/transforms/onednn/softplus_activation_fuse_pass.cc

Lines changed: 25 additions & 13 deletions
@@ -120,24 +120,36 @@ class SoftplusActivationFusePattern : public paddle::drr::DrrPatternBase {
         {"beta", pat.Attr("beta")}, {"threshold", pat.Attr("threshold")}};
 
     if (act_type_ == paddle::dialect::HardswishOp::name()) {
-      fused_attrs.emplace("fuse_alpha", res.Float32Attr(1.0f / 6.0f));
-      fused_attrs.emplace("fuse_beta", res.Float32Attr(1.0f / 2.0f));
+      fused_attrs.emplace("fuse_alpha", res.DoubleAttr(1.0 / 6.0));
+      fused_attrs.emplace("fuse_beta", res.DoubleAttr(1.0 / 2.0));
     } else if (act_type_ == paddle::dialect::HardsigmoidOp::name()) {
-      fused_attrs.emplace("fuse_alpha", pat.Attr("fuse_alpha"));
-      fused_attrs.emplace("fuse_beta", pat.Attr("fuse_beta"));
+      const auto &fuse_alpha = res.ComputeAttr(
+          [](const paddle::drr::MatchContext &match_ctx) -> double {
+            return static_cast<double>(match_ctx.Attr<float>("fuse_alpha"));
+          });
+      const auto &fuse_beta = res.ComputeAttr(
+          [](const paddle::drr::MatchContext &match_ctx) -> double {
+            return static_cast<double>(match_ctx.Attr<float>("fuse_beta"));
+          });
+      fused_attrs.emplace("fuse_alpha", fuse_alpha);
+      fused_attrs.emplace("fuse_beta", fuse_beta);
     } else if (act_type_ == paddle::dialect::LeakyRelu_Op::name() ||
                act_type_ == paddle::dialect::LeakyReluOp::name()) {
-      fused_attrs.emplace("fuse_alpha", pat.Attr("fuse_alpha"));
+      const auto &fuse_alpha = res.ComputeAttr(
+          [](const paddle::drr::MatchContext &match_ctx) -> double {
+            return static_cast<double>(match_ctx.Attr<float>("fuse_alpha"));
+          });
+      fused_attrs.emplace("fuse_alpha", fuse_alpha);
     } else if (act_type_ == paddle::dialect::SwishOp::name()) {
-      fused_attrs.emplace("fuse_alpha", res.Float32Attr(1.0f));
+      fused_attrs.emplace("fuse_alpha", res.DoubleAttr(1.0));
     } else if (act_type_ == paddle::dialect::Relu6Op::name()) {
-      fused_attrs.emplace("fuse_beta", res.Float32Attr(6.0f));
+      fused_attrs.emplace("fuse_beta", res.DoubleAttr(6.0));
     }
 
     fused_attrs.insert(std::make_pair("fuse_activation",
                                       res.StrAttr(activation_type[act_type_])));
-    fused_attrs.insert(std::make_pair("fuse_alpha", res.Float32Attr(0.0f)));
-    fused_attrs.insert(std::make_pair("fuse_beta", res.Float32Attr(0.0f)));
+    fused_attrs.insert(std::make_pair("fuse_alpha", res.DoubleAttr(0.0)));
+    fused_attrs.insert(std::make_pair("fuse_beta", res.DoubleAttr(0.0)));
 
     const auto &fused_softplus = res.Op(fused_softplus_name_, fused_attrs);
 
@@ -188,8 +200,8 @@ class SoftplusGeluTanhFusePattern : public paddle::drr::DrrPatternBase {
         {"beta", pat.Attr("beta")},
         {"threshold", pat.Attr("threshold")},
         {"fuse_activation", res.StrAttr("gelu_tanh")},
-        {"fuse_alpha", res.Float32Attr(0.0f)},
-        {"fuse_beta", res.Float32Attr(0.0f)}};
+        {"fuse_alpha", res.DoubleAttr(0.0)},
+        {"fuse_beta", res.DoubleAttr(0.0)}};
 
     const auto &fused_softplus = res.Op(fused_softplus_name_, fused_attrs);
 
@@ -244,11 +256,11 @@ class SoftplusClipFusePattern : public paddle::drr::DrrPatternBase {
     paddle::drr::ResultPattern res = pat.ResultPattern();
 
     const auto &fuse_alpha = res.ComputeAttr(
-        [](const paddle::drr::MatchContext &match_ctx) -> float {
+        [](const paddle::drr::MatchContext &match_ctx) -> double {
           return match_ctx.Attr<double>("value1");
         });
     const auto &fuse_beta = res.ComputeAttr(
-        [](const paddle::drr::MatchContext &match_ctx) -> float {
+        [](const paddle::drr::MatchContext &match_ctx) -> double {
          return match_ctx.Attr<double>("value2");
         });
 
paddle/fluid/pybind/pir.cc

Lines changed: 6 additions & 0 deletions
@@ -3286,6 +3286,12 @@ void BindDrrPatternContext(pybind11::module *m) {
             return self.Float32Attr(value);
           },
           pybind11::arg("value"))
+      .def(
+          "DoubleAttr",
+          [](drr::ResultPattern &self, double value) {
+            return self.DoubleAttr(value);
+          },
+          pybind11::arg("value"))
       .def(
           "VectorInt32Attr",
           [](drr::ResultPattern &self, const std::vector<int32_t> &value) {
