@@ -120,24 +120,36 @@ class SoftplusActivationFusePattern : public paddle::drr::DrrPatternBase {
120120 {" beta" , pat.Attr (" beta" )}, {" threshold" , pat.Attr (" threshold" )}};
121121
122122 if (act_type_ == paddle::dialect::HardswishOp::name ()) {
123- fused_attrs.emplace (" fuse_alpha" , res.Float32Attr (1 .0f / 6 .0f ));
124- fused_attrs.emplace (" fuse_beta" , res.Float32Attr (1 .0f / 2 .0f ));
123+ fused_attrs.emplace (" fuse_alpha" , res.DoubleAttr (1.0 / 6.0 ));
124+ fused_attrs.emplace (" fuse_beta" , res.DoubleAttr (1.0 / 2.0 ));
125125 } else if (act_type_ == paddle::dialect::HardsigmoidOp::name ()) {
126- fused_attrs.emplace (" fuse_alpha" , pat.Attr (" fuse_alpha" ));
127- fused_attrs.emplace (" fuse_beta" , pat.Attr (" fuse_beta" ));
126+ const auto &fuse_alpha = res.ComputeAttr (
127+ [](const paddle::drr::MatchContext &match_ctx) -> double {
128+ return static_cast <double >(match_ctx.Attr <float >(" fuse_alpha" ));
129+ });
130+ const auto &fuse_beta = res.ComputeAttr (
131+ [](const paddle::drr::MatchContext &match_ctx) -> double {
132+ return static_cast <double >(match_ctx.Attr <float >(" fuse_beta" ));
133+ });
134+ fused_attrs.emplace (" fuse_alpha" , fuse_alpha);
135+ fused_attrs.emplace (" fuse_beta" , fuse_beta);
128136 } else if (act_type_ == paddle::dialect::LeakyRelu_Op::name () ||
129137 act_type_ == paddle::dialect::LeakyReluOp::name ()) {
130- fused_attrs.emplace (" fuse_alpha" , pat.Attr (" fuse_alpha" ));
138+ const auto &fuse_alpha = res.ComputeAttr (
139+ [](const paddle::drr::MatchContext &match_ctx) -> double {
140+ return static_cast <double >(match_ctx.Attr <float >(" fuse_alpha" ));
141+ });
142+ fused_attrs.emplace (" fuse_alpha" , fuse_alpha);
131143 } else if (act_type_ == paddle::dialect::SwishOp::name ()) {
132- fused_attrs.emplace (" fuse_alpha" , res.Float32Attr (1 .0f ));
144+ fused_attrs.emplace (" fuse_alpha" , res.DoubleAttr (1.0 ));
133145 } else if (act_type_ == paddle::dialect::Relu6Op::name ()) {
134- fused_attrs.emplace (" fuse_beta" , res.Float32Attr (6 .0f ));
146+ fused_attrs.emplace (" fuse_beta" , res.DoubleAttr (6.0 ));
135147 }
136148
137149 fused_attrs.insert (std::make_pair (" fuse_activation" ,
138150 res.StrAttr (activation_type[act_type_])));
139- fused_attrs.insert (std::make_pair (" fuse_alpha" , res.Float32Attr (0 .0f )));
140- fused_attrs.insert (std::make_pair (" fuse_beta" , res.Float32Attr (0 .0f )));
151+ fused_attrs.insert (std::make_pair (" fuse_alpha" , res.DoubleAttr (0.0 )));
152+ fused_attrs.insert (std::make_pair (" fuse_beta" , res.DoubleAttr (0.0 )));
141153
142154 const auto &fused_softplus = res.Op (fused_softplus_name_, fused_attrs);
143155
@@ -188,8 +200,8 @@ class SoftplusGeluTanhFusePattern : public paddle::drr::DrrPatternBase {
188200 {" beta" , pat.Attr (" beta" )},
189201 {" threshold" , pat.Attr (" threshold" )},
190202 {" fuse_activation" , res.StrAttr (" gelu_tanh" )},
191- {" fuse_alpha" , res.Float32Attr (0 .0f )},
192- {" fuse_beta" , res.Float32Attr (0 .0f )}};
203+ {" fuse_alpha" , res.DoubleAttr (0.0 )},
204+ {" fuse_beta" , res.DoubleAttr (0.0 )}};
193205
194206 const auto &fused_softplus = res.Op (fused_softplus_name_, fused_attrs);
195207
@@ -244,11 +256,11 @@ class SoftplusClipFusePattern : public paddle::drr::DrrPatternBase {
244256 paddle::drr::ResultPattern res = pat.ResultPattern ();
245257
246258 const auto &fuse_alpha = res.ComputeAttr (
247- [](const paddle::drr::MatchContext &match_ctx) -> float {
259+ [](const paddle::drr::MatchContext &match_ctx) -> double {
248260 return match_ctx.Attr <double >(" value1" );
249261 });
250262 const auto &fuse_beta = res.ComputeAttr (
251- [](const paddle::drr::MatchContext &match_ctx) -> float {
263+ [](const paddle::drr::MatchContext &match_ctx) -> double {
252264 return match_ctx.Attr <double >(" value2" );
253265 });
254266
0 commit comments