Add hasDtype checks everywhere dtypes are used in decompositions (#1750)

Several decompositions assume that the operands of an op have dtypes
available; however, dtypes are only guaranteed to be present once the
graph has reached the backend contract. In general, any pass that runs
before the backend contract is reached should not assume dtypes are
available and should check with `hasDtype` first.

This commit adds `hasDtype` checks to every decomposition that uses
dtypes.
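
As a sketch of the pattern applied throughout (variable names are illustrative; the real checks live inside each decomposition's matchAndRewrite):

// Guard added throughout this file: before reading a dtype, verify it is
// present, and fail the match (rather than crash) if not, since passes
// running before the backend contract may see dtype-less tensors.
BaseTensorType resultType = op.getType().cast<BaseTensorType>();
if (!resultType.hasDtype()) {
  return rewriter.notifyMatchFailure(
      op, "expected result type to have a dtype");
}
Type resultDtype = resultType.getDtype(); // Safe: guarded above.

// Where a missing dtype is tolerable, e.g. when propagating the input's
// dtype (or lack of one) into a new type, getOptionalDtype() is used:
Type newType = resultType.getWithSizesAndDtype(resultType.getOptionalSizes(),
                                               resultType.getOptionalDtype());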
ramiro050 authored Jan 3, 2023
1 parent 273664d commit d44bdd2
Showing 1 changed file with 58 additions and 19 deletions.
lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -71,7 +71,7 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
   Type resultType = tensorType.getWithSizesAndDtype(
       sizes.size() == 0 ? std::optional<ArrayRef<int64_t>>()
                         : llvm::makeArrayRef(sizes),
-      tensorType.getDtype());
+      tensorType.getOptionalDtype());
   return resultType;
 }

@@ -407,6 +407,11 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
           op, "Expected a boolean value for half_to_float");

     BaseTensorType resultTensorType = op.getType().cast<BaseTensorType>();
+    if (!resultTensorType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
+    Type resultTensorDtype = resultTensorType.getDtype();
     // `torch.ops.aten._softmax`'s softmax with half to float conversion is not
     // supported on CPU, but we go ahead with the decomposing.
     // TODO: Add an e2e test once upstream support is added.
@@ -418,7 +423,7 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
       Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
       self = rewriter.create<AtenToDtypeOp>(
           loc, resultTensorType, self,
-          getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()),
+          getDtypeIntValueForType(rewriter, loc, resultTensorDtype),
           /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
     }
     Value result = getSoftmaxResult(op, self, resultTensorType, rewriter);
@@ -558,8 +563,8 @@ class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
       return failure();
     BaseTensorType valueTensorType =
         inputType
-            .getWithSizesAndDtype(indicesTensorType.getSizes(),
-                                  inputType.getDtype())
+            .getWithSizesAndDtype(indicesTensorType.getOptionalSizes(),
+                                  inputType.getOptionalDtype())
             .cast<BaseTensorType>();

     // If the dim type is `NoneType` i.e. reduce along all the dimensions.
@@ -568,7 +573,9 @@ class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
     // 0th dimension.
     if (dim.getType().isa<Torch::NoneType>()) {
       BaseTensorType flattenType =
-          inputType.getWithSizesAndDtype({kUnknownSize}, inputType.getDtype())
+          inputType
+              .getWithSizesAndDtype({kUnknownSize},
+                                    inputType.getOptionalDtype())
               .cast<BaseTensorType>();
       dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
       Value end = rewriter.create<ConstantIntOp>(
@@ -923,7 +930,7 @@ class DecomposeAtenRollOp : public OpRewritePattern<AtenRollOp> {
     sizes.append(inputShape.begin(), inputShape.end());
     sizes[cstDim] = kUnknownSize;
     Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
-                                               selfTy.getDtype());
+                                               selfTy.getOptionalDtype());
     Value slice0 = rewriter.create<AtenSliceTensorOp>(
         loc, sliceTy, input, dim, negShift, constNone, constOne);
     Value slice1 = rewriter.create<AtenSliceTensorOp>(
@@ -1057,7 +1064,7 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
       reshapedSizes.push_back(scaledSize);
     }

-    Type dtype = self.getType().cast<ValueTensorType>().getDtype();
+    Type dtype = self.getType().cast<ValueTensorType>().getOptionalDtype();
     Type unsqueezedType = ValueTensorType::get(
         context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
     Type expandedType = ValueTensorType::get(
@@ -1493,10 +1500,8 @@ class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
     }

     // TODO: Handle integer type operands.
-    if (!input.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
+    auto inputType = input.getType().cast<BaseTensorType>();
+    if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
       return rewriter.notifyMatchFailure(
           op, "unimplemented: non-floating point dtype");
     }
@@ -2067,7 +2072,7 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
     for (int i = 0; i < axis; i++)
       meanVarSizes[i] = input.getSizes()[i];
     auto meanVarType = input.getWithSizesAndDtype(
-        llvm::makeArrayRef(meanVarSizes), input.getDtype());
+        llvm::makeArrayRef(meanVarSizes), input.getOptionalDtype());
     auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
         loc, op.getType(), meanVarType, meanVarType, op.getInput(),
         op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps());
@@ -2302,7 +2307,7 @@ class DecomposeAtenNativeBatchNormOp

     SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
     runningStatsShapeInt[1] = kUnknownSize;
-    Type dtype = input.getType().cast<ValueTensorType>().getDtype();
+    Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
     Type reshapeType = ValueTensorType::get(
         context, llvm::makeArrayRef(runningStatsShapeInt), dtype);

@@ -2419,6 +2424,10 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
     if (dtype.getType().isa<Torch::NoneType>()) {
       BaseTensorType tensorType =
           op.getSelf().getType().template cast<BaseTensorType>();
+      if (!tensorType.hasDtype()) {
+        return rewriter.notifyMatchFailure(
+            op, "expected input tensor to have a dtype");
+      }
       dtype =
           getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
     }
@@ -2439,6 +2448,10 @@ class DecomposeAtenFullOp : public OpRewritePattern<AtenFullOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
+    if (!outTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
     SmallVector<int64_t> empty;
     auto dtype =
         getTypeForTorchType(op.getContext(), op.getFillValue().getType());
@@ -2479,7 +2492,7 @@ class DecomposeAtenLinearOp : public OpRewritePattern<AtenLinearOp> {
     SmallVector<int64_t> transposeShape =
         llvm::to_vector(llvm::reverse(weightType.getSizes()));
     Type transposeType = weightType.getWithSizesAndDtype(
-        llvm::makeArrayRef(transposeShape), weightType.getDtype());
+        llvm::makeArrayRef(transposeShape), weightType.getOptionalDtype());
     Value transposeWeight =
         rewriter.create<AtenTOp>(loc, transposeType, weight);

@@ -2542,6 +2555,10 @@ class DecomposeAtenFullLikeOp : public OpRewritePattern<AtenFullLikeOp> {
   LogicalResult matchAndRewrite(AtenFullLikeOp op,
                                 PatternRewriter &rewriter) const override {
     BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
+    if (!outTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
     SmallVector<int64_t> empty;
     auto dtype =
         getTypeForTorchType(op.getContext(), op.getFillValue().getType());
@@ -2598,7 +2615,12 @@ class DecomposeAten_ToCopyOp : public OpRewritePattern<Aten_ToCopyOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(Aten_ToCopyOp op,
                                 PatternRewriter &rewriter) const override {
-    Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
+    auto resultType = op.getType().cast<BaseTensorType>();
+    if (!resultType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
+    Type resultDtype = resultType.getDtype();
     Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
                                                    resultDtype);
     Value emptyTensor = rewriter.create<AtenFullLikeOp>(
@@ -2618,7 +2640,12 @@ class DecomposeAtenCopyOp : public OpRewritePattern<AtenCopyOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenCopyOp op,
                                 PatternRewriter &rewriter) const override {
-    Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
+    auto resultType = op.getType().cast<BaseTensorType>();
+    if (!resultType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
+    Type resultDtype = resultType.getDtype();
     Value srcToDtype =
         convertTensorToDtype(rewriter, op.getLoc(), op.getSrc(), resultDtype);
     rewriter.replaceOpWithNewOp<AtenExpandAsOp>(op, op.getType(), srcToDtype,
@@ -2638,6 +2665,10 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
     Value dtype = op.getDtype();
     if (dtype.getType().isa<Torch::NoneType>()) {
       BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
+      if (!tensorType.hasDtype()) {
+        return rewriter.notifyMatchFailure(
+            op, "expected input tensor to have a dtype");
+      }
       dtype =
           getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
     }
@@ -2980,6 +3011,10 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
   BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
   Type outputType = op.getType();
   BaseTensorType outputTensorType = outputType.cast<BaseTensorType>();
+  if (!outputTensorType.hasDtype()) {
+    return rewriter.notifyMatchFailure(op,
+                                       "expected result type to have a dtype");
+  }
   Type newOutputType = outputTensorType.getWithSizesAndDtype(
       outputTensorType.getSizes(), rewriter.getF64Type());
   if (!inputTensorTy.hasDtype() ||
@@ -3169,8 +3204,8 @@ class DecomposeAtenSelectScatterOp
     } else {
       sizes.resize(srcShape.size() + 1, kUnknownSize);
     }
-    Type srcType = srcTensorType.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
-                                                      srcTensorType.getDtype());
+    Type srcType = srcTensorType.getWithSizesAndDtype(
+        llvm::makeArrayRef(sizes), srcTensorType.getOptionalDtype());
     src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
     rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
         op, op.getSelf().getType(), self, src, dim, start, startPlusOne,
@@ -3269,7 +3304,7 @@ class DecomposeAtenMseLossOp : public OpRewritePattern<AtenMseLossOp> {
     BaseTensorType subType =
         inputType
             .getWithSizesAndDtype(llvm::makeArrayRef(inputType.getSizes()),
-                                  resultType.getDtype())
+                                  resultType.getOptionalDtype())
             .cast<BaseTensorType>();

     Value sub = createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
@@ -3305,6 +3340,10 @@ class DecomposeAtenRandintLowOp : public OpRewritePattern<AtenRandintLowOp> {
     Location loc = op.getLoc();
     Type resultType = op.getType();
     BaseTensorType resultTensorType = resultType.cast<BaseTensorType>();
+    if (!resultTensorType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }

     int64_t cstLow, cstHigh;
     if (!matchPattern(op.getLow(), m_TorchConstantInt(&cstLow)))
