Skip to content

Commit 70d6236

Browse files
committed
reduce repetitive match
1 parent f950ef4 commit 70d6236

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

paddle/fluid/pir/transforms/onednn/cpu_bfloat16_placement_pass.cc

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ class OneDNNBf16PlacementPattern : public pir::RewritePattern {
101101
auto mkldnn_data_type = op_attr.at("mkldnn_data_type")
102102
.dyn_cast<pir::StrAttribute>()
103103
.AsString();
104-
if (mkldnn_data_type == "int8") {
104+
// Reduce repetitive match
105+
if (mkldnn_data_type != "float32") {
105106
return false;
106107
}
107108
}
@@ -159,16 +160,12 @@ class OneDNNBf16PlacementPattern : public pir::RewritePattern {
159160
.dyn_cast<paddle::dialect::DenseTensorType>()
160161
.dtype();
161162
// Only float input can be converted to bfloat16
162-
if (!input_dtype.isa<pir::Float32Type>()) {
163-
return false;
164-
}
163+
if (!input_dtype.isa<pir::Float32Type>()) return false;
165164
}
166165
} else if (type.isa<paddle::dialect::DenseTensorType>()) {
167166
pir::Type op_dtype = pir::GetDataTypeFromValue(value);
168167
// Only float input can be converted to bfloat16
169-
if (!op_dtype.isa<pir::Float32Type>()) {
170-
return false;
171-
}
168+
if (!op_dtype.isa<pir::Float32Type>()) return false;
172169
} else {
173170
return false;
174171
}

0 commit comments

Comments
 (0)