From 689efc89175cc339ca6a1df88be7d24172906c32 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Sat, 8 Jun 2024 09:36:32 +0800 Subject: [PATCH] [Torch] fix toBuiltinTensor() (#3415) * Let `toBuiltinTensor()` reflects the original dtype of `!torch.vtensor`. * Backend handles dtype conversion themselves. --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 18 ++++---- lib/Conversion/TorchToLinalg/Linear.cpp | 19 ++++---- lib/Dialect/Torch/IR/TorchOps.cpp | 46 ++++++++----------- lib/Dialect/Torch/IR/TorchTypes.cpp | 11 ++--- .../Transforms/BackendTypeConversion.cpp | 22 ++++++++- 5 files changed, 60 insertions(+), 56 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index b26e1ea3a5f1..b6cc7cdd0ac9 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -737,7 +737,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( std::numeric_limits::lowest())) return failure(); auto minSplatAttr = SplatElementsAttr::get( - resultType.toBuiltinTensor().clone(resultDtype), + resultType.toBuiltinTensor(), rewriter.getFloatAttr(resultDtype, minValue)); min = rewriter.create( binder.getLoc(), resultType, minSplatAttr); @@ -748,7 +748,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( std::numeric_limits::max())) return failure(); auto maxSplatAttr = SplatElementsAttr::get( - resultType.toBuiltinTensor().clone(resultDtype), + resultType.toBuiltinTensor(), rewriter.getFloatAttr(resultDtype, maxValue)); max = rewriter.create( binder.getLoc(), resultType, maxSplatAttr); @@ -861,7 +861,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (binder.op->hasAttr("torch.onnx.value_float") && !binder.f32FloatAttr(floatValue, "value_float", 0.0)) { auto splatAttr = - SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype), + SplatElementsAttr::get(resultType.toBuiltinTensor(), rewriter.getFloatAttr(dtype, floatValue)); rewriter.replaceOpWithNewOp( binder.op, resultType, splatAttr); @@ -872,7 +872,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (binder.op->hasAttr("torch.onnx.value_int") && !binder.s64IntegerAttr(intValue, "value_int", 0)) { auto splatAttr = - SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype), + SplatElementsAttr::get(resultType.toBuiltinTensor(), rewriter.getIntegerAttr(dtype, intValue)); rewriter.replaceOpWithNewOp( binder.op, resultType, splatAttr); @@ -932,8 +932,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( for (auto intVal : intValues) { apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal)); } - auto attr = DenseElementsAttr::get( - resultType.toBuiltinTensor().clone(dtype), apValues); + auto attr = + DenseElementsAttr::get(resultType.toBuiltinTensor(), apValues); rewriter.replaceOpWithNewOp( binder.op, resultType, attr); return success(); @@ -2272,9 +2272,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // Extract the fill value and dtype // ONNX requires value attr to be a tensor if (!attr) { - attr = DenseElementsAttr::get( - resultType.toBuiltinTensor().clone(resultDType), - rewriter.getFloatAttr(resultDType, 0.0)); + attr = + DenseElementsAttr::get(resultType.toBuiltinTensor(), + rewriter.getFloatAttr(resultDType, 0.0)); } // If its a dense resource attr we need to convert to a dense type: diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index aa560402877f..318c2bec361f 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -146,12 +146,11 @@ class ConvertAtenMmOp : public OpConversionPattern { "mismatching contracting dimension for torch.aten.mm")); } - auto resultTy = cast(op.getType()); - auto resultDTy = resultTy.toBuiltinTensor().getElementType(); - Type newResultType = getTypeConverter()->convertType(op.getType()); - Type elementType = cast(newResultType).getElementType(); - auto accumulatorDType = getDefaultAccType(rewriter, resultDTy); - if (accumulatorDType != resultDTy) { + TensorType resultType = + cast(getTypeConverter()->convertType(op.getType())); + Type elementType = resultType.getElementType(); + auto accumulatorDType = getDefaultAccType(rewriter, elementType); + if (accumulatorDType != resultType.getElementType()) { elementType = accumulatorDType; } Value zeroFill = createZeroInitTensor( @@ -197,18 +196,16 @@ class ConvertAtenMmOp : public OpConversionPattern { .getResult(0); } - if (accumulatorDType != resultDTy) { - Type resultElementType = - cast(newResultType).getElementType(); + if (accumulatorDType != resultType.getElementType()) { matmul = torch_to_linalg::convertTensorToElementType( - rewriter, loc, matmul, resultElementType); + rewriter, loc, matmul, resultType.getElementType()); } // When constructed with just dynamic sizes, EmptyOp will have a result // type which has all `?`'s for dimensions, which might not be the result // type of `op`. The constraints on later linalg ops means that the result // of the MatmulOp will have this type too. So cast it to the desired type // so that in the end we have the original result type. - rewriter.replaceOpWithNewOp(op, newResultType, matmul); + rewriter.replaceOpWithNewOp(op, resultType, matmul); return success(); } diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 994722f3ea6f..61a0857a8894 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1311,7 +1311,7 @@ static OpFoldResult naryFolderHelper(ArrayRef operands, Type ty, return nullptr; auto dty = resultTy.getDtype(); - auto resultBTy = resultTy.toBuiltinTensor().clone(dty); + auto resultBTy = resultTy.toBuiltinTensor(); auto fpTy = dyn_cast(dty); auto intTy = dyn_cast(dty); @@ -1521,7 +1521,7 @@ OpFoldResult AtenEqTensorOp::fold(FoldAdaptor adaptor) { if (!ty || !ty.hasDtype() || !ty.hasSizes()) return nullptr; - auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); + auto bty = ty.toBuiltinTensor(); if (!bty.hasStaticShape()) return nullptr; @@ -1635,7 +1635,6 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, return nullptr; auto ctx = lhs.getContext(); - auto resultETy = resultTy.getDtype(); auto tensorETy = cast(lhs.getType()).getElementType(); if (lhs.isSplat()) { if (auto intAttr = dyn_cast(rhs)) { @@ -1647,8 +1646,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, unsign ? tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign); auto resultBool = intFolder(tensorAP, scalarAP, unsign); auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - resultAP); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP); } if (auto floatAttr = dyn_cast(rhs)) { @@ -1657,8 +1655,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, auto resultBool = fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - resultAP); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP); } return nullptr; } @@ -1681,8 +1678,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, auto resultBool = intFolder(tensorAP, scalarAP, unsign); values.push_back(resultBool); } - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - values); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values); } if (auto floatAttr = dyn_cast(rhs)) { @@ -1693,8 +1689,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); values.push_back(resultBool); } - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - values); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values); } return nullptr; @@ -1844,7 +1839,7 @@ static OpFoldResult unaryPromoteFolder(DenseElementsAttr operand, if (!fpTy && !intTy) return nullptr; - auto resultBTy = resultTy.toBuiltinTensor().clone(resultTy.getDtype()); + auto resultBTy = resultTy.toBuiltinTensor(); bool splat = operand.isSplat(); bool withinMaxFold = resultBTy.hasStaticShape() && resultBTy.getNumElements() <= kMaxFold; @@ -2192,7 +2187,7 @@ OpFoldResult AtenSelectIntOp::fold(FoldAdaptor adaptor) { return nullptr; auto selfTy = cast(self.getType()); - auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); + auto bty = ty.toBuiltinTensor(); if (!bty.hasStaticShape()) return nullptr; @@ -2656,8 +2651,7 @@ LogicalResult AtenSortOp::fold(FoldAdaptor adaptor, if (!indicesTensorType.hasDtype()) return failure(); - auto indicesType = - indicesTensorType.toBuiltinTensor().clone(indicesTensorType.getDtype()); + auto indicesType = indicesTensorType.toBuiltinTensor(); if (!indicesType || !indicesType.hasStaticShape()) return failure(); @@ -3612,9 +3606,8 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { return nullptr; if (input && input.isSplat()) - return DenseElementsAttr::get( - outType.toBuiltinTensor().clone(inType.getDtype()), - input.getSplatValue()); + return DenseElementsAttr::get(outType.toBuiltinTensor(), + input.getSplatValue()); int count = 1; for (auto dim : outType.getSizes()) @@ -3652,8 +3645,7 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { for (int i = begin; i < limit; i += stride) values.push_back(input.getValues()[i]); - return DenseElementsAttr::get( - outType.toBuiltinTensor().clone(inType.getDtype()), values); + return DenseElementsAttr::get(outType.toBuiltinTensor(), values); } // If the input and output shapes are the same we can just fold: @@ -3923,7 +3915,7 @@ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) { if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; Type eTy = resultTy.getDtype(); - ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); + ShapedType shapedTy = resultTy.toBuiltinTensor(); SmallVector data; if (matchPattern(getData(), m_TorchListOfConstantInts(data)) && @@ -3944,7 +3936,7 @@ OpFoldResult AtenTensorIntOp::fold(FoldAdaptor adaptor) { if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; Type eTy = resultTy.getDtype(); - ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); + ShapedType shapedTy = resultTy.toBuiltinTensor(); int64_t data; if (matchPattern(getT(), m_TorchConstantInt(&data))) { @@ -3964,7 +3956,7 @@ OpFoldResult AtenTensorFloatOp::fold(FoldAdaptor adaptor) { if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; Type eTy = resultTy.getDtype(); - ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); + ShapedType shapedTy = resultTy.toBuiltinTensor(); double data; if (matchPattern(getT(), m_TorchConstantFloat(&data))) { @@ -4137,7 +4129,7 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) { : selfAttr.getValues()[indexInt]; auto dty = resultTy.getDtype(); - auto attrTy = resultTy.toBuiltinTensor().clone(dty); + auto attrTy = resultTy.toBuiltinTensor(); if (auto floatAttr = dyn_cast(splattr)) return DenseElementsAttr::get( attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble())); @@ -4330,7 +4322,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { if (!valueDense.isSplat()) return nullptr; auto splattr = valueDense.getSplatValue(); - auto attrty = ty.toBuiltinTensor().clone(dty); + auto attrty = ty.toBuiltinTensor(); return DenseElementsAttr::get(attrty, splattr); } @@ -4338,7 +4330,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { if (!isa(dty)) return nullptr; int64_t intval = intAttr.getInt(); - auto attrty = ty.toBuiltinTensor().clone(dty); + auto attrty = ty.toBuiltinTensor(); return DenseElementsAttr::get(attrty, IntegerAttr::get(dty, intval)); } @@ -4346,7 +4338,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { if (!isa(dty)) return nullptr; double dblval = fpAttr.getValueAsDouble(); - auto attrty = ty.toBuiltinTensor().clone(dty); + auto attrty = ty.toBuiltinTensor(); return DenseElementsAttr::get(attrty, FloatAttr::get(dty, dblval)); } diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index 6735bb37e48b..12aea1589a4d 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -453,12 +453,7 @@ ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) { } static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { - if (auto floatType = dyn_cast(dtype)) { - return dtype; - } else if (auto integerType = dyn_cast(dtype)) { - return IntegerType::get(context, integerType.getWidth(), - IntegerType::Signless); - } else if (isa(dtype)) { + if (isa(dtype)) { return dtype; } @@ -480,11 +475,11 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { TensorType ValueTensorType::toBuiltinTensor() const { if (!hasDtype()) return nullptr; - if (!hasSizes()) - return UnrankedTensorType::get(getDtype()); Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype()); if (!elementType) return nullptr; + if (!hasSizes()) + return UnrankedTensorType::get(elementType); return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType, getOptionalSparsity()); } diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp index deeef0658a52..c4f22715ab34 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp @@ -164,7 +164,18 @@ void mlir::torch::TorchConversion::setupBackendTypeConversion( ConversionTarget &target, TypeConverter &typeConverter) { auto valueTensorTypeConversion = [](Torch::ValueTensorType type) -> std::optional { - return type.toBuiltinTensor(); + auto builtinType = type.toBuiltinTensor(); + if (!builtinType) + return std::nullopt; + + // convert any integer type to signless + if (type.getDtype().isInteger()) { + return builtinType.clone(IntegerType::get( + builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(), + IntegerType::Signless)); + } + + return builtinType; }; setupValueTensorToBuiltinTensorConversion(target, typeConverter, valueTensorTypeConversion); @@ -180,9 +191,18 @@ void mlir::torch::TorchConversion::setupBackendTypeConversionForStablehlo( auto valueTensorTypeConversion = [](Torch::ValueTensorType type) -> std::optional { auto builtinType = type.toBuiltinTensor(); + if (!builtinType) + return std::nullopt; + + // convert signed integer type to signless, keep unsigned as unsigned if (type.getDtype().isUnsignedInteger()) { return builtinType.clone(type.getDtype()); + } else if (type.getDtype().isSignedInteger()) { + return builtinType.clone(IntegerType::get( + builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(), + IntegerType::Signless)); } + return builtinType; }; setupValueTensorToBuiltinTensorConversion(target, typeConverter,