File tree Expand file tree Collapse file tree 1 file changed +4
-7
lines changed
paddle/fluid/pir/transforms/onednn Expand file tree Collapse file tree 1 file changed +4
-7
lines changed Original file line number Diff line number Diff line change @@ -101,7 +101,8 @@ class OneDNNBf16PlacementPattern : public pir::RewritePattern {
101
101
auto mkldnn_data_type = op_attr.at (" mkldnn_data_type" )
102
102
.dyn_cast <pir::StrAttribute>()
103
103
.AsString ();
104
- if (mkldnn_data_type == " int8" ) {
104
+ // Reduce repetitive match
105
+ if (mkldnn_data_type != " float32" ) {
105
106
return false ;
106
107
}
107
108
}
@@ -159,16 +160,12 @@ class OneDNNBf16PlacementPattern : public pir::RewritePattern {
159
160
.dyn_cast <paddle::dialect::DenseTensorType>()
160
161
.dtype ();
161
162
// Only float input can be converted to bfloat16
162
- if (!input_dtype.isa <pir::Float32Type>()) {
163
- return false ;
164
- }
163
+ if (!input_dtype.isa <pir::Float32Type>()) return false ;
165
164
}
166
165
} else if (type.isa <paddle::dialect::DenseTensorType>()) {
167
166
pir::Type op_dtype = pir::GetDataTypeFromValue (value);
168
167
// Only float input can be converted to bfloat16
169
- if (!op_dtype.isa <pir::Float32Type>()) {
170
- return false ;
171
- }
168
+ if (!op_dtype.isa <pir::Float32Type>()) return false ;
172
169
} else {
173
170
return false ;
174
171
}
You can’t perform that action at this time.
0 commit comments