Skip to content

[mlir][tosa] Improve invalid operator data types error message #140756

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ class TosaProfileCompliance {
SmallVector<StringRef>
stringifyProfile(const SmallVector<ArrayRef<T>> &profileSet);

static llvm::SmallString<7> stringifyTypeInfo(const TypeInfo &typeInfo);

private:
template <typename T>
FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
Expand Down
63 changes: 62 additions & 1 deletion mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,9 +485,52 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
CheckCondition condition = CheckCondition::invalid;
const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);

if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
!maybeProfDef.value().size() && !maybeExtDef.value().size())
!maybeProfDef.value().size() && !maybeExtDef.value().size()) {
std::string message;
llvm::raw_string_ostream os(message);
os << "illegal: operation operand/result data types did not align with any "
"profile or extension, got (";

ProfileInfoDepot depot(op);
SmallVector<TypeInfo> current = depot.getInfo();
for (const auto &typeInfo : llvm::drop_end(current))
os << stringifyTypeInfo(typeInfo) << ",";
os << stringifyTypeInfo(current.back()) << ")";

// avoid polluting the error message output by outputting only
// the best match
const std::string opName = op->getName().getStringRef().str();
int maxMatches = -1;
SmallVector<TypeInfo> bestTypeInfo;
const auto searchBestMatch = [&](auto map) {
for (const auto &complianceInfos : map[opName]) {
for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
const int matches = llvm::count_if(
llvm::zip_equal(current, typeInfos), [&](const auto zipType) {
return isSameTypeInfo(std::get<0>(zipType),
std::get<1>(zipType));
});
if (matches > maxMatches) {
maxMatches = matches;
bestTypeInfo = typeInfos;
}
}
}
};
searchBestMatch(getProfileComplianceMap<Profile>());
searchBestMatch(getProfileComplianceMap<Extension>());

os << ", did you mean (";
for (const auto &typeInfo : llvm::drop_end(bestTypeInfo))
os << stringifyTypeInfo(typeInfo) << ",";
os << stringifyTypeInfo(bestTypeInfo.back()) << ")? ";
os << "Otherwise, please refer to the 'supported data types' for '"
<< opName << "' in the specification.";
op->emitOpError(message);
return failure();
}

return success();
}
Expand Down Expand Up @@ -562,3 +605,21 @@ SmallVector<StringRef> TosaProfileCompliance::stringifyProfile(

return debugStrings;
}

llvm::SmallString<7>
TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) {
if (typeInfo.typeID == mlir::IntegerType::getTypeID()) {
return {"i" + llvm::utostr(typeInfo.bitWidth)};
} else if (typeInfo.typeID == mlir::Float16Type::getTypeID()) {
return {"f16"};
} else if (typeInfo.typeID == mlir::Float32Type::getTypeID()) {
return {"f32"};
} else if (typeInfo.typeID == mlir::BFloat16Type::getTypeID()) {
return {"bf16"};
} else if (typeInfo.typeID == mlir::Float8E4M3FNType::getTypeID()) {
return {"fp8e4m3"};
} else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) {
return {"fp8e5m2"};
}
llvm_unreachable("unknown type");
}
4 changes: 1 addition & 3 deletions mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1248,10 +1248,8 @@ void TosaValidation::runOnOperation() {
return signalPassFailure();

if (!allowInvalidOpDatatypeCombinations &&
failed(profileComp.checkInvalid(op))) {
op->emitOpError("illegal: operand/result data types not supported");
failed(profileComp.checkInvalid(op)))
return signalPassFailure();
}

// Some uses of TOSA rely on the constant operands of particular
// operations.
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func.func @test_conv2d(%arg0: tensor<*xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2:

func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
%zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.conv2d' op illegal: operand/result data types not supported}}
// expected-error@+1 {{'tosa.conv2d' op illegal: operation operand/result data types did not align with any profile or extension, got (i8,i8,i8,i8,i8,i32,i8), did you mean (i8,i8,i32,i8,i8,i32,i32)?}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
Expand Down Expand Up @@ -1888,7 +1888,7 @@ func.func @test_scalar_tile(%arg0: tensor<f32>) -> tensor<*xf32> {

// CHECK-LABEL: test_add_i1
func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
// expected-error@+1 {{'tosa.add' op illegal: operand/result data types not supported}}
// expected-error@+1 {{'tosa.add' op illegal: operation operand/result data types did not align with any profile or extension, got (i1,i1,i1), did you mean (i32,i32,i32)? Otherwise, please refer to the 'supported data types' for 'tosa.add' in the specification.}}
%0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
return %0 : tensor<13x21x3xi1>
}
Expand All @@ -1897,7 +1897,7 @@ func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) ->

// CHECK-LABEL: test_mul_out_i16
func.func @test_mul_out_i16(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> {
// expected-error@+1 {{'tosa.mul' op illegal: operand/result data types not supported}}
// expected-error@+1 {{'tosa.mul' op illegal: operation operand/result data types did not align with any profile or extension, got (i8,i8,i16), did you mean (i8,i8,i32)?}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
Expand Down
Loading