Add hasDtype checks everywhere dtypes are used in decompositions #1750

Merged
merged 1 commit on Jan 3, 2023
77 changes: 58 additions & 19 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
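
Every hunk below applies the same guard: before a decomposition pattern reads a tensor type's dtype, it checks hasDtype() and emits a match failure instead of asserting on an unrefined type. A minimal sketch of the idiom (illustrative names, not an excerpt from any single hunk):

    // Inside some pattern's matchAndRewrite(SomeOp op, PatternRewriter &rewriter):
    BaseTensorType resultType = op.getType().cast<BaseTensorType>();
    if (!resultType.hasDtype()) {
      // Unrefined tensor types may carry no dtype; fail the match gracefully
      // rather than crashing in getDtype().
      return rewriter.notifyMatchFailure(
          op, "expected result type to have a dtype");
    }
    Type resultDtype = resultType.getDtype(); // safe: guarded by hasDtype()

Where a dtype is only forwarded into a newly constructed type rather than inspected, the hunks instead switch from getDtype() to getOptionalDtype(), which tolerates a missing dtype.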
@@ -71,7 +71,7 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
   Type resultType = tensorType.getWithSizesAndDtype(
       sizes.size() == 0 ? std::optional<ArrayRef<int64_t>>()
                         : llvm::makeArrayRef(sizes),
-      tensorType.getDtype());
+      tensorType.getOptionalDtype());
   return resultType;
 }

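The hunk above is the forwarding case: getWithSizesAndDtype() accepts an absent dtype, so passing getOptionalDtype() lets "dtype unknown" propagate into the derived type instead of tripping the assertion inside getDtype(). A sketch of that contract, with signatures assumed from this file's usage:

    // `newSizes` and `value` are hypothetical; only the call shape matters here.
    BaseTensorType tensorType = value.getType().cast<BaseTensorType>();
    Type derived = tensorType.getWithSizesAndDtype(
        llvm::makeArrayRef(newSizes),     // the reduced shape
        tensorType.getOptionalDtype());   // may be null; that is allowed here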
@@ -407,6 +407,11 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
           op, "Expected a boolean value for half_to_float");

     BaseTensorType resultTensorType = op.getType().cast<BaseTensorType>();
+    if (!resultTensorType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
+    Type resultTensorDtype = resultTensorType.getDtype();
     // `torch.ops.aten._softmax`'s softmax with half to float conversion is not
     // supported on CPU, but we go ahead with the decomposing.
     // TODO: Add an e2e test once upstream support is added.
@@ -418,7 +423,7 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
       Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
       self = rewriter.create<AtenToDtypeOp>(
           loc, resultTensorType, self,
-          getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()),
+          getDtypeIntValueForType(rewriter, loc, resultTensorDtype),
           /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
     }
     Value result = getSoftmaxResult(op, self, resultTensorType, rewriter);
@@ -558,8 +563,8 @@ class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
       return failure();
     BaseTensorType valueTensorType =
         inputType
-            .getWithSizesAndDtype(indicesTensorType.getSizes(),
-                                  inputType.getDtype())
+            .getWithSizesAndDtype(indicesTensorType.getOptionalSizes(),
+                                  inputType.getOptionalDtype())
             .cast<BaseTensorType>();

     // If the dim type is `NoneType` i.e. reduce along all the dimensions.
@@ -568,7 +573,9 @@ class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
     // 0th dimension.
     if (dim.getType().isa<Torch::NoneType>()) {
       BaseTensorType flattenType =
-          inputType.getWithSizesAndDtype({kUnknownSize}, inputType.getDtype())
+          inputType
+              .getWithSizesAndDtype({kUnknownSize},
+                                    inputType.getOptionalDtype())
               .cast<BaseTensorType>();
       dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
       Value end = rewriter.create<ConstantIntOp>(
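
The two ArgMax hunks above extend the same idea to shapes: getSizes() requires the type to carry a sizes list, while getOptionalSizes() forwards a possibly-absent one. A sketch of the distinction, using assumed variable names:

    BaseTensorType ty = value.getType().cast<BaseTensorType>();
    if (ty.hasSizes()) {
      ArrayRef<int64_t> sizes = ty.getSizes(); // safe: guarded by hasSizes()
    }
    // No guard needed when merely forwarding into a new type:
    Type newTy = ty.getWithSizesAndDtype(ty.getOptionalSizes(),
                                         ty.getOptionalDtype());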
@@ -850,7 +857,7 @@ class DecomposeAtenRollOp : public OpRewritePattern<AtenRollOp> {
     sizes.append(inputShape.begin(), inputShape.end());
     sizes[cstDim] = kUnknownSize;
     Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
-                                               selfTy.getDtype());
+                                               selfTy.getOptionalDtype());
     Value slice0 = rewriter.create<AtenSliceTensorOp>(
         loc, sliceTy, input, dim, negShift, constNone, constOne);
     Value slice1 = rewriter.create<AtenSliceTensorOp>(
@@ -984,7 +991,7 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
       reshapedSizes.push_back(scaledSize);
     }

-    Type dtype = self.getType().cast<ValueTensorType>().getDtype();
+    Type dtype = self.getType().cast<ValueTensorType>().getOptionalDtype();
     Type unsqueezedType = ValueTensorType::get(
         context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
     Type expandedType = ValueTensorType::get(
@@ -1420,10 +1427,8 @@ class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
     }

     // TODO: Handle integer type operands.
-    if (!input.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
+    auto inputType = input.getType().cast<BaseTensorType>();
+    if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
       return rewriter.notifyMatchFailure(
           op, "unimplemented: non-floating point dtype");
     }
@@ -1994,7 +1999,7 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
     for (int i = 0; i < axis; i++)
       meanVarSizes[i] = input.getSizes()[i];
     auto meanVarType = input.getWithSizesAndDtype(
-        llvm::makeArrayRef(meanVarSizes), input.getDtype());
+        llvm::makeArrayRef(meanVarSizes), input.getOptionalDtype());
     auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
         loc, op.getType(), meanVarType, meanVarType, op.getInput(),
         op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps());
@@ -2229,7 +2234,7 @@ class DecomposeAtenNativeBatchNormOp

     SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
     runningStatsShapeInt[1] = kUnknownSize;
-    Type dtype = input.getType().cast<ValueTensorType>().getDtype();
+    Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
     Type reshapeType = ValueTensorType::get(
         context, llvm::makeArrayRef(runningStatsShapeInt), dtype);

@@ -2346,6 +2351,10 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
     if (dtype.getType().isa<Torch::NoneType>()) {
       BaseTensorType tensorType =
           op.getSelf().getType().template cast<BaseTensorType>();
+      if (!tensorType.hasDtype()) {
+        return rewriter.notifyMatchFailure(
+            op, "expected input tensor to have a dtype");
+      }
       dtype =
           getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
     }
@@ -2366,6 +2375,10 @@ class DecomposeAtenFullOp : public OpRewritePattern<AtenFullOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
+    if (!outTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
     SmallVector<int64_t> empty;
     auto dtype =
         getTypeForTorchType(op.getContext(), op.getFillValue().getType());
@@ -2406,7 +2419,7 @@ class DecomposeAtenLinearOp : public OpRewritePattern<AtenLinearOp> {
     SmallVector<int64_t> transposeShape =
         llvm::to_vector(llvm::reverse(weightType.getSizes()));
     Type transposeType = weightType.getWithSizesAndDtype(
-        llvm::makeArrayRef(transposeShape), weightType.getDtype());
+        llvm::makeArrayRef(transposeShape), weightType.getOptionalDtype());
     Value transposeWeight =
         rewriter.create<AtenTOp>(loc, transposeType, weight);

@@ -2469,6 +2482,10 @@ class DecomposeAtenFullLikeOp : public OpRewritePattern<AtenFullLikeOp> {
   LogicalResult matchAndRewrite(AtenFullLikeOp op,
                                 PatternRewriter &rewriter) const override {
     BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
+    if (!outTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
     SmallVector<int64_t> empty;
     auto dtype =
         getTypeForTorchType(op.getContext(), op.getFillValue().getType());
@@ -2525,7 +2542,12 @@ class DecomposeAten_ToCopyOp : public OpRewritePattern<Aten_ToCopyOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(Aten_ToCopyOp op,
                                 PatternRewriter &rewriter) const override {
-    Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
+    auto resultType = op.getType().cast<BaseTensorType>();
+    if (!resultType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
+    Type resultDtype = resultType.getDtype();
     Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
                                                    resultDtype);
     Value emptyTensor = rewriter.create<AtenFullLikeOp>(
@@ -2545,7 +2567,12 @@ class DecomposeAtenCopyOp : public OpRewritePattern<AtenCopyOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenCopyOp op,
                                 PatternRewriter &rewriter) const override {
-    Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
+    auto resultType = op.getType().cast<BaseTensorType>();
+    if (!resultType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
+    Type resultDtype = resultType.getDtype();
     Value srcToDtype =
         convertTensorToDtype(rewriter, op.getLoc(), op.getSrc(), resultDtype);
     rewriter.replaceOpWithNewOp<AtenExpandAsOp>(op, op.getType(), srcToDtype,
@@ -2565,6 +2592,10 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
     Value dtype = op.getDtype();
     if (dtype.getType().isa<Torch::NoneType>()) {
       BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
+      if (!tensorType.hasDtype()) {
+        return rewriter.notifyMatchFailure(
+            op, "expected input tensor to have a dtype");
+      }
       dtype =
           getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
     }
@@ -2907,6 +2938,10 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
   BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
   Type outputType = op.getType();
   BaseTensorType outputTensorType = outputType.cast<BaseTensorType>();
+  if (!outputTensorType.hasDtype()) {
+    return rewriter.notifyMatchFailure(op,
+                                       "expected result type to have a dtype");
+  }
   Type newOutputType = outputTensorType.getWithSizesAndDtype(
       outputTensorType.getSizes(), rewriter.getF64Type());
   if (!inputTensorTy.hasDtype() ||
@@ -3096,8 +3131,8 @@ class DecomposeAtenSelectScatterOp
     } else {
       sizes.resize(srcShape.size() + 1, kUnknownSize);
     }
-    Type srcType = srcTensorType.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
-                                                      srcTensorType.getDtype());
+    Type srcType = srcTensorType.getWithSizesAndDtype(
+        llvm::makeArrayRef(sizes), srcTensorType.getOptionalDtype());
     src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
     rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
         op, op.getSelf().getType(), self, src, dim, start, startPlusOne,
@@ -3196,7 +3231,7 @@ class DecomposeAtenMseLossOp : public OpRewritePattern<AtenMseLossOp> {
     BaseTensorType subType =
         inputType
            .getWithSizesAndDtype(llvm::makeArrayRef(inputType.getSizes()),
-                                 resultType.getDtype())
+                                 resultType.getOptionalDtype())
            .cast<BaseTensorType>();

     Value sub = createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
Expand Down Expand Up @@ -3232,6 +3267,10 @@ class DecomposeAtenRandintLowOp : public OpRewritePattern<AtenRandintLowOp> {
Location loc = op.getLoc();
Type resultType = op.getType();
BaseTensorType resultTensorType = resultType.cast<BaseTensorType>();
if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected result type to have a dtype");
}

int64_t cstLow, cstHigh;
if (!matchPattern(op.getLow(), m_TorchConstantInt(&cstLow)))