diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 712e7e813b68..4a4f9a5619f2 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -71,7 +71,7 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
   Type resultType = tensorType.getWithSizesAndDtype(
       sizes.size() == 0 ? std::optional<ArrayRef<int64_t>>()
                         : llvm::makeArrayRef(sizes),
-      tensorType.getDtype());
+      tensorType.getOptionalDtype());
   return resultType;
 }
 
@@ -407,6 +407,11 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
           op, "Expected a boolean value for half_to_float");
 
     BaseTensorType resultTensorType = op.getType().cast<BaseTensorType>();
+    if (!resultTensorType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
+    Type resultTensorDtype = resultTensorType.getDtype();
     // `torch.ops.aten._softmax`'s softmax with half to float conversion is not
     // supported on CPU, but we go ahead with the decomposing.
     // TODO: Add an e2e test once upstream support is added.
@@ -418,7 +423,7 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
       Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
       self = rewriter.create<AtenToDtypeOp>(
           loc, resultTensorType, self,
-          getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()),
+          getDtypeIntValueForType(rewriter, loc, resultTensorDtype),
          /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
     }
     Value result = getSoftmaxResult(op, self, resultTensorType, rewriter);
@@ -558,8 +563,8 @@ class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
       return failure();
 
     BaseTensorType valueTensorType =
         inputType
-            .getWithSizesAndDtype(indicesTensorType.getSizes(),
-                                  inputType.getDtype())
+            .getWithSizesAndDtype(indicesTensorType.getOptionalSizes(),
+                                  inputType.getOptionalDtype())
             .cast<BaseTensorType>();
     // If the dim type is `NoneType` i.e. reduce along all the dimensions.
@@ -568,7 +573,9 @@ class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
     // 0th dimension.
     if (dim.getType().isa<Torch::NoneType>()) {
       BaseTensorType flattenType =
-          inputType.getWithSizesAndDtype({kUnknownSize}, inputType.getDtype())
+          inputType
+              .getWithSizesAndDtype({kUnknownSize},
+                                    inputType.getOptionalDtype())
               .cast<BaseTensorType>();
       dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
       Value end = rewriter.create<ConstantIntOp>(
@@ -923,7 +930,7 @@ class DecomposeAtenRollOp : public OpRewritePattern<AtenRollOp> {
       sizes.append(inputShape.begin(), inputShape.end());
       sizes[cstDim] = kUnknownSize;
       Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
-                                                 selfTy.getDtype());
+                                                 selfTy.getOptionalDtype());
       Value slice0 = rewriter.create<AtenSliceTensorOp>(
           loc, sliceTy, input, dim, negShift, constNone, constOne);
       Value slice1 = rewriter.create<AtenSliceTensorOp>(
@@ -1057,7 +1064,7 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
       reshapedSizes.push_back(scaledSize);
     }
 
-    Type dtype = self.getType().cast<BaseTensorType>().getDtype();
+    Type dtype = self.getType().cast<BaseTensorType>().getOptionalDtype();
     Type unsqueezedType = ValueTensorType::get(
         context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
     Type expandedType = ValueTensorType::get(
@@ -1493,10 +1500,8 @@ class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
     }
 
     // TODO: Handle integer type operands.
-    if (!input.getType()
-             .cast<BaseTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
+    auto inputType = input.getType().cast<BaseTensorType>();
+    if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
       return rewriter.notifyMatchFailure(
           op, "unimplemented: non-floating point dtype");
     }
@@ -2067,7 +2072,7 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
     for (int i = 0; i < axis; i++)
       meanVarSizes[i] = input.getSizes()[i];
     auto meanVarType = input.getWithSizesAndDtype(
-        llvm::makeArrayRef(meanVarSizes), input.getDtype());
+        llvm::makeArrayRef(meanVarSizes), input.getOptionalDtype());
     auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
         loc, op.getType(), meanVarType, meanVarType, op.getInput(),
         op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps());
@@ -2302,7 +2307,7 @@ class DecomposeAtenNativeBatchNormOp
 
     SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
     runningStatsShapeInt[1] = kUnknownSize;
-    Type dtype = input.getType().cast<BaseTensorType>().getDtype();
+    Type dtype = input.getType().cast<BaseTensorType>().getOptionalDtype();
     Type reshapeType = ValueTensorType::get(
         context, llvm::makeArrayRef(runningStatsShapeInt), dtype);
 
@@ -2419,6 +2424,10 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
     if (dtype.getType().isa<Torch::NoneType>()) {
       BaseTensorType tensorType =
           op.getSelf().getType().template cast<BaseTensorType>();
+      if (!tensorType.hasDtype()) {
+        return rewriter.notifyMatchFailure(
+            op, "expected input tensor to have a dtype");
+      }
       dtype = getDtypeIntValueForType(rewriter, op.getLoc(),
                                       tensorType.getDtype());
     }
@@ -2439,6 +2448,10 @@ class DecomposeAtenFullOp : public OpRewritePattern<AtenFullOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
+    if (!outTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
     SmallVector<int64_t> empty;
     auto dtype =
         getTypeForTorchType(op.getContext(), op.getFillValue().getType());
@@ -2479,7 +2492,7 @@ class DecomposeAtenLinearOp : public OpRewritePattern<AtenLinearOp> {
     SmallVector<int64_t> transposeShape =
         llvm::to_vector(llvm::reverse(weightType.getSizes()));
     Type transposeType = weightType.getWithSizesAndDtype(
-        llvm::makeArrayRef(transposeShape), weightType.getDtype());
+        llvm::makeArrayRef(transposeShape), weightType.getOptionalDtype());
     Value transposeWeight =
         rewriter.create<AtenTOp>(loc, transposeType, weight);
 
@@ -2542,6 +2555,10 @@ class DecomposeAtenFullLikeOp : public OpRewritePattern<AtenFullLikeOp> {
   LogicalResult matchAndRewrite(AtenFullLikeOp op,
                                 PatternRewriter &rewriter) const override {
     BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
+    if (!outTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
     SmallVector<int64_t> empty;
     auto dtype =
         getTypeForTorchType(op.getContext(), op.getFillValue().getType());
@@ -2598,7 +2615,12 @@ class DecomposeAten_ToCopyOp : public OpRewritePattern<Aten_ToCopyOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(Aten_ToCopyOp op,
                                 PatternRewriter &rewriter) const override {
-    Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
+    auto resultType = op.getType().cast<BaseTensorType>();
+    if (!resultType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
+    Type resultDtype = resultType.getDtype();
     Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
                                                    resultDtype);
     Value emptyTensor = rewriter.create<AtenFullLikeOp>(
@@ -2618,7 +2640,12 @@ class DecomposeAtenCopyOp : public OpRewritePattern<AtenCopyOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenCopyOp op,
                                 PatternRewriter &rewriter) const override {
-    Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
+    auto resultType = op.getType().cast<BaseTensorType>();
+    if (!resultType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
+    Type resultDtype = resultType.getDtype();
     Value srcToDtype =
         convertTensorToDtype(rewriter, op.getLoc(), op.getSrc(), resultDtype);
     rewriter.replaceOpWithNewOp<AtenExpandAsOp>(op, op.getType(), srcToDtype,
@@ -2638,6 +2665,10 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
     Value dtype = op.getDtype();
     if (dtype.getType().isa<Torch::NoneType>()) {
       BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
+      if (!tensorType.hasDtype()) {
+        return rewriter.notifyMatchFailure(
+            op, "expected input tensor to have a dtype");
+      }
       dtype = getDtypeIntValueForType(rewriter, op.getLoc(),
                                       tensorType.getDtype());
     }
@@ -2980,6 +3011,10 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
   BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
   Type outputType = op.getType();
   BaseTensorType outputTensorType = outputType.cast<BaseTensorType>();
+  if (!outputTensorType.hasDtype()) {
+    return rewriter.notifyMatchFailure(op,
+                                       "expected result type to have a dtype");
+  }
   Type newOutputType = outputTensorType.getWithSizesAndDtype(
       outputTensorType.getSizes(), rewriter.getF64Type());
   if (!inputTensorTy.hasDtype() ||
@@ -3169,8 +3204,8 @@ class DecomposeAtenSelectScatterOp
     } else {
       sizes.resize(srcShape.size() + 1, kUnknownSize);
     }
-    Type srcType = srcTensorType.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
-                                                      srcTensorType.getDtype());
+    Type srcType = srcTensorType.getWithSizesAndDtype(
+        llvm::makeArrayRef(sizes), srcTensorType.getOptionalDtype());
     src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
     rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
         op, op.getSelf().getType(), self, src, dim, start, startPlusOne,
@@ -3269,7 +3304,7 @@ class DecomposeAtenMseLossOp : public OpRewritePattern<AtenMseLossOp> {
     BaseTensorType subType =
         inputType
             .getWithSizesAndDtype(llvm::makeArrayRef(inputType.getSizes()),
-                                  resultType.getDtype())
+                                  resultType.getOptionalDtype())
             .cast<BaseTensorType>();
     Value sub =
         createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
@@ -3305,6 +3340,10 @@ class DecomposeAtenRandintLowOp : public OpRewritePattern<AtenRandintLowOp> {
     Location loc = op.getLoc();
     Type resultType = op.getType();
    BaseTensorType resultTensorType = resultType.cast<BaseTensorType>();
+    if (!resultTensorType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
 
     int64_t cstLow, cstHigh;
     if (!matchPattern(op.getLow(), m_TorchConstantInt(&cstLow)))
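Note on the recurring pattern: every hunk above makes one of two changes. Where a decomposition merely propagates a dtype into a derived type, `getDtype()` (which asserts when the tensor type carries no dtype) is swapped for `getOptionalDtype()` (which returns a null `Type` that `getWithSizesAndDtype()` accepts). Where a concrete dtype is genuinely required, a `hasDtype()` guard is added so the pattern fails the match instead of crashing the pass. A minimal sketch of both idioms follows; `AtenFooOp` and the final rewrite are hypothetical stand-ins, while the accessors and helpers are the ones exercised in the hunks above.

// Sketch only: `AtenFooOp` is a placeholder op, not part of the change.
class DecomposeAtenFooOp : public OpRewritePattern<AtenFooOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenFooOp op,
                                PatternRewriter &rewriter) const override {
    auto resultType = op.getType().cast<BaseTensorType>();

    // Idiom 1: getDtype() asserts when no dtype is attached, so guard it
    // with hasDtype() and bail out gracefully on unrefined types.
    if (!resultType.hasDtype()) {
      return rewriter.notifyMatchFailure(
          op, "expected result type to have a dtype");
    }
    Value dtypeInt =
        getDtypeIntValueForType(rewriter, op.getLoc(), resultType.getDtype());

    // Idiom 2: when only propagating the (possibly missing) dtype into a
    // derived type, no guard is needed; getOptionalDtype() yields a null
    // Type and getWithSizesAndDtype() accepts it.
    Type derivedType = resultType.getWithSizesAndDtype(
        resultType.getOptionalSizes(), resultType.getOptionalDtype());

    // Hypothetical rewrite, mirroring the aten.to.dtype construction in the
    // _softmax hunk; a real pattern would emit its decomposed op sequence.
    Value cstFalse = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
    Value none = rewriter.create<ConstantNoneOp>(op.getLoc());
    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
        op, derivedType, op.getSelf(), dtypeInt,
        /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
    return success();
  }
};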