From d849cbad14730c2af35ced046df03be2fe9a87eb Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Fri, 20 Jan 2023 10:40:13 -0800 Subject: [PATCH] Make `getTypeForScalarType` safer by returning `FailureOr` (#1814) One of the potential values for a `torch_upstream::ScalarType` is `Undefined`. This means that conversion of a `ScalarType` to another type is a computation that can fail. To enforce handling of the failure case, this commit makes the two helper functions that convert `ScalarType`s into other types return `failure()` when the `ScalarType` is `Undefined`. --- include/torch-mlir/Dialect/Torch/Utils/Utils.h | 2 +- .../TorchToLinalg/TensorConstructors.cpp | 14 ++++++++++++-- .../Torch/Transforms/DecomposeComplexOps.cpp | 8 +++++--- lib/Dialect/Torch/Transforms/RefineTypes.cpp | 11 ++++++++--- .../Torch/Transforms/SimplifyDtypeCalculations.cpp | 9 +++++++-- lib/Dialect/Torch/Utils/Utils.cpp | 10 +++++++--- 6 files changed, 40 insertions(+), 14 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index a5cbcf52ccbe..4e3f3ceccf59 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -26,7 +26,7 @@ bool getListConstructElements(Value v, SmallVectorImpl &elems); std::optional matchLegalConstantIndexIntoListOfSize(Value v, int64_t length); torch_upstream::ScalarType getScalarTypeForType(Type type); -Type getTypeForScalarType( +FailureOr getTypeForScalarType( MLIRContext *context, torch_upstream::ScalarType dtypeInt, mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed); diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index 42ec8657e567..e861a1877e99 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -127,9 +127,14 @@ class ConvertConstantTensorAllocOp : public OpConversionPattern { if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) return rewriter.notifyMatchFailure( op, "unimplemented: dtype must be a constant integer or none"); - resultElementType = getTypeForScalarType( + FailureOr maybeResultElementType = getTypeForScalarType( op->getContext(), (torch_upstream::ScalarType)dtypeInt, IntegerType::Signless); + if (failed(maybeResultElementType)) { + return rewriter.notifyMatchFailure( + op, "unable to convert `dtypeInt` to builtin type"); + } + resultElementType = *maybeResultElementType; } // Create an uninitialized tensor of `resultSize` shape and fill it with @@ -227,9 +232,14 @@ class ConvertAtenEmptyMemoryFormatOp if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) return rewriter.notifyMatchFailure( op, "unimplemented: dtype must be a constant integer or none"); - resultElementType = getTypeForScalarType( + FailureOr maybeResultElementType = getTypeForScalarType( op->getContext(), (torch_upstream::ScalarType)dtypeInt, IntegerType::Signless); + if (failed(maybeResultElementType)) { + return rewriter.notifyMatchFailure( + op, "unable to convert `dtypeInt` to builtin type"); + } + resultElementType = *maybeResultElementType; } // Create an uninitialized tensor of `resultSize` shape. diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 34137ed129f9..6f5984110d99 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -33,9 +33,11 @@ static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) { int64_t dtypeInt; if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) return false; - Type resDtype = + FailureOr resDtype = getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); - return resDtype.isa(); + if (failed(resDtype)) + return false; + return resDtype->isa(); } // Helper function to compute the return type of the reduction function. @@ -3803,4 +3805,4 @@ std::unique_ptr> mlir::torch::Torch::createDecomposeComplexOpsPass( ArrayRef legalOps) { return std::make_unique(legalOps); -} \ No newline at end of file +} diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index f21c0dc64095..568c99f84dfa 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -81,7 +81,9 @@ using namespace mlir::torch::Torch; // ----------------------------------------------------------------------------- static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) { - return getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); + FailureOr result = + getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); + return failed(result) ? Type() : *result; } static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype, @@ -563,7 +565,9 @@ static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) { /*skipRankCheck=*/true); state = updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state); - return getTypeForScalarType(scalarType.getContext(), result_type(state)); + FailureOr result = + getTypeForScalarType(scalarType.getContext(), result_type(state)); + return failed(result) ? Type() : *result; } static SmallVector> @@ -600,7 +604,8 @@ static Type getPromotedResultType(MLIRContext *context, return Type(); state = updateResultTypeState(tensor, rankIsNonZero, state, skipRankCheck); } - return getTypeForScalarType(context, result_type(state)); + FailureOr result = getTypeForScalarType(context, result_type(state)); + return failed(result) ? Type() : *result; } static Type getPromotedResultTypeAssumingNonZeroRank( diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp index 9a29b976a6e9..3c4a334b5708 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp @@ -46,10 +46,15 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op, impliedTypeFromDtype = *torchType; } else if (auto originalResultType = result.getType().dyn_cast()) { + FailureOr builtinType = + getTypeForScalarType(op->getContext(), dtypeScalarType); + if (failed(builtinType)) { + return rewriter.notifyMatchFailure( + op, "Failed to convert `dtypeScalarType` to a builtin type"); + } impliedTypeFromDtype = originalResultType.cast().getWithSizesAndDtype( - originalResultType.getOptionalSizes(), - getTypeForScalarType(op->getContext(), dtypeScalarType)); + originalResultType.getOptionalSizes(), *builtinType); } else { return rewriter.notifyMatchFailure(op, "Unimplemented: Expected result type to " diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 1d67b24e871c..d7fdf9481d5f 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -83,9 +83,10 @@ Type Torch::getTypeForTorchType( llvm::report_fatal_error("unhandled type for getTypeForTorchType"); } -Type Torch::getTypeForScalarType( - MLIRContext *context, torch_upstream::ScalarType dtypeInt, - mlir::IntegerType::SignednessSemantics signedness) { +FailureOr +Torch::getTypeForScalarType(MLIRContext *context, + torch_upstream::ScalarType dtypeInt, + mlir::IntegerType::SignednessSemantics signedness) { switch (dtypeInt) { case torch_upstream::ScalarType::Float: return Float32Type::get(context); @@ -110,6 +111,8 @@ Type Torch::getTypeForScalarType( return mlir::ComplexType::get(Float64Type::get(context)); case torch_upstream::ScalarType::ComplexDouble: return mlir::ComplexType::get(Float128Type::get(context)); + case torch_upstream::ScalarType::Undefined: + return failure(); default: llvm::report_fatal_error("unhandled type for getTypeForScalarType"); } @@ -123,6 +126,7 @@ Torch::getTorchTypeForScalarType(MLIRContext *context, return Torch::FloatType::get(context); case torch_upstream::ScalarType::Long: return Torch::IntType::get(context); + case torch_upstream::ScalarType::Undefined: default: return failure(); }