Skip to content

Commit

Permalink
Make getTypeForScalarType safer by returning FailureOr<Type> (llv…
Browse files Browse the repository at this point in the history
…m#1814)

One of the potential values for a `torch_upstream::ScalarType` is
`Undefined`. This means that conversion of a `ScalarType` to another
type is a computation that can fail. To enforce handling of the
failure case, this commit makes the two helper functions that convert
`ScalarType`s into other types return `failure()` when the
`ScalarType` is `Undefined`.
  • Loading branch information
ramiro050 authored Jan 20, 2023
1 parent d3c6183 commit d849cba
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 14 deletions.
2 changes: 1 addition & 1 deletion include/torch-mlir/Dialect/Torch/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
std::optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v,
int64_t length);
torch_upstream::ScalarType getScalarTypeForType(Type type);
Type getTypeForScalarType(
FailureOr<Type> getTypeForScalarType(
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);

Expand Down
14 changes: 12 additions & 2 deletions lib/Conversion/TorchToLinalg/TensorConstructors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,14 @@ class ConvertConstantTensorAllocOp : public OpConversionPattern<OpTy> {
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
return rewriter.notifyMatchFailure(
op, "unimplemented: dtype must be a constant integer or none");
resultElementType = getTypeForScalarType(
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
IntegerType::Signless);
if (failed(maybeResultElementType)) {
return rewriter.notifyMatchFailure(
op, "unable to convert `dtypeInt` to builtin type");
}
resultElementType = *maybeResultElementType;
}

// Create an uninitialized tensor of `resultSize` shape and fill it with
Expand Down Expand Up @@ -227,9 +232,14 @@ class ConvertAtenEmptyMemoryFormatOp
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
return rewriter.notifyMatchFailure(
op, "unimplemented: dtype must be a constant integer or none");
resultElementType = getTypeForScalarType(
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
IntegerType::Signless);
if (failed(maybeResultElementType)) {
return rewriter.notifyMatchFailure(
op, "unable to convert `dtypeInt` to builtin type");
}
resultElementType = *maybeResultElementType;
}

// Create an uninitialized tensor of `resultSize` shape.
Expand Down
8 changes: 5 additions & 3 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
int64_t dtypeInt;
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
return false;
Type resDtype =
FailureOr<Type> resDtype =
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
return resDtype.isa<mlir::FloatType>();
if (failed(resDtype))
return false;
return resDtype->isa<mlir::FloatType>();
}

// Helper function to compute the return type of the reduction function.
Expand Down Expand Up @@ -3803,4 +3805,4 @@ std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createDecomposeComplexOpsPass(
ArrayRef<std::string> legalOps) {
return std::make_unique<DecomposeComplexOpsPass>(legalOps);
}
}
11 changes: 8 additions & 3 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ using namespace mlir::torch::Torch;
// -----------------------------------------------------------------------------

static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
return getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
FailureOr<Type> result =
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
return failed(result) ? Type() : *result;
}

static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype,
Expand Down Expand Up @@ -563,7 +565,9 @@ static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) {
/*skipRankCheck=*/true);
state =
updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state);
return getTypeForScalarType(scalarType.getContext(), result_type(state));
FailureOr<Type> result =
getTypeForScalarType(scalarType.getContext(), result_type(state));
return failed(result) ? Type() : *result;
}

static SmallVector<std::optional<bool>>
Expand Down Expand Up @@ -600,7 +604,8 @@ static Type getPromotedResultType(MLIRContext *context,
return Type();
state = updateResultTypeState(tensor, rankIsNonZero, state, skipRankCheck);
}
return getTypeForScalarType(context, result_type(state));
FailureOr<Type> result = getTypeForScalarType(context, result_type(state));
return failed(result) ? Type() : *result;
}

static Type getPromotedResultTypeAssumingNonZeroRank(
Expand Down
9 changes: 7 additions & 2 deletions lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,15 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
impliedTypeFromDtype = *torchType;
} else if (auto originalResultType =
result.getType().dyn_cast<BaseTensorType>()) {
FailureOr<Type> builtinType =
getTypeForScalarType(op->getContext(), dtypeScalarType);
if (failed(builtinType)) {
return rewriter.notifyMatchFailure(
op, "Failed to convert `dtypeScalarType` to a builtin type");
}
impliedTypeFromDtype =
originalResultType.cast<BaseTensorType>().getWithSizesAndDtype(
originalResultType.getOptionalSizes(),
getTypeForScalarType(op->getContext(), dtypeScalarType));
originalResultType.getOptionalSizes(), *builtinType);
} else {
return rewriter.notifyMatchFailure(op,
"Unimplemented: Expected result type to "
Expand Down
10 changes: 7 additions & 3 deletions lib/Dialect/Torch/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,10 @@ Type Torch::getTypeForTorchType(
llvm::report_fatal_error("unhandled type for getTypeForTorchType");
}

Type Torch::getTypeForScalarType(
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
mlir::IntegerType::SignednessSemantics signedness) {
FailureOr<Type>
Torch::getTypeForScalarType(MLIRContext *context,
torch_upstream::ScalarType dtypeInt,
mlir::IntegerType::SignednessSemantics signedness) {
switch (dtypeInt) {
case torch_upstream::ScalarType::Float:
return Float32Type::get(context);
Expand All @@ -110,6 +111,8 @@ Type Torch::getTypeForScalarType(
return mlir::ComplexType::get(Float64Type::get(context));
case torch_upstream::ScalarType::ComplexDouble:
return mlir::ComplexType::get(Float128Type::get(context));
case torch_upstream::ScalarType::Undefined:
return failure();
default:
llvm::report_fatal_error("unhandled type for getTypeForScalarType");
}
Expand All @@ -123,6 +126,7 @@ Torch::getTorchTypeForScalarType(MLIRContext *context,
return Torch::FloatType::get(context);
case torch_upstream::ScalarType::Long:
return Torch::IntType::get(context);
case torch_upstream::ScalarType::Undefined:
default:
return failure();
}
Expand Down

0 comments on commit d849cba

Please sign in to comment.