diff --git a/externals/llvm-project b/externals/llvm-project index eae82ac259ee..5fcf907b3435 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit eae82ac259ee5a58bc4070a414bc53239e18bad0 +Subproject commit 5fcf907b34355980f77d7665a175b05fea7a6b7b diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index 30cc4db44181..2891a22eb817 100644 --- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -81,7 +81,7 @@ class AdjustCallingConventionForFunc } newResultTypes.push_back(type); } - rewriter.updateRootInPlace(func, [&] { + rewriter.modifyOpInPlace(func, [&] { func.setType(FunctionType::get( getContext(), conversion.getConvertedTypes(), newResultTypes)); // Clear out the type bounds, now that the type incorporates them. @@ -194,14 +194,12 @@ static LogicalResult adjustCallingConventions(func::FuncOp func, TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion( - [](Torch::TupleType type, - SmallVectorImpl &types) -> LogicalResult { + [](Torch::TupleType type, SmallVectorImpl &types) -> LogicalResult { llvm::append_range(types, type.getContainedTypes()); return success(); }); typeConverter.addConversion( - [](Torch::NoneType type, - SmallVectorImpl &types) -> LogicalResult { + [](Torch::NoneType type, SmallVectorImpl &types) -> LogicalResult { return success(); }); diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp index cd76275a745d..7db6bc6776b3 100644 --- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp +++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp @@ -175,7 +175,7 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock // Replace return type of view-like ops with value-semantics type variant. for (Operation *viewLikeOp : ops.viewLikeOps) { - rewriter.updateRootInPlace(viewLikeOp, [&] { + rewriter.modifyOpInPlace(viewLikeOp, [&] { Value result = viewLikeOp->getResult(0); auto resultType = result.getType().dyn_cast(); if (resultType) @@ -337,7 +337,7 @@ class RewriteViewLikeSubgraph // correctly copy them back to their mlir::func::ReturnOp's expected types. DenseMap originalTypes; for (Operation *op : viewLikeOps) { - rewriter.updateRootInPlace(op, [&]() { + rewriter.modifyOpInPlace(op, [&]() { if (auto nonValueTensorType = op->getResult(0).getType().dyn_cast()) { originalTypes[op->getResult(0)] = nonValueTensorType; diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index 8ba0479625d8..200f25c82c43 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -9,10 +9,10 @@ #include "PassDetail.h" +#include "ReifyAbstractInterpCalculationsUtils.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" -#include "ReifyAbstractInterpCalculationsUtils.h" #include "llvm/ADT/StringExtras.h" using namespace mlir; @@ -72,8 +72,8 @@ namespace { // immutable tensors. class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { public: - ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context, - const std::optional& extraLibrary) + ConvertHasValueSemanticsOpsToValueTensors( + MLIRContext *context, const std::optional &extraLibrary) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) { this->extraLibrary = extraLibrary; } @@ -87,7 +87,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { return rewriter.notifyMatchFailure(op, "does not have value semantics"); } - rewriter.startRootUpdate(op); + rewriter.startOpModification(op); // Convert all operands. SmallVector newOperands; for (OpOperand &opOperand : op->getOpOperands()) { @@ -105,7 +105,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { auto listConstruct = opOperand.get().getDefiningOp(); if (!listConstruct) { - rewriter.cancelRootUpdate(op); + rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( op, "unimplemented: list of non vtensor type not constructed " "from list construct"); @@ -120,7 +120,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { if (!llvm::all_of(listConstruct.getElements(), [](Value val) { return val.getType().isa(); })) { - rewriter.cancelRootUpdate(op); + rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( op, "unimplemented: list containing optional type is not " "handled."); @@ -138,7 +138,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { Type newListType = getContainerOrTensorTypeWithValueSemantics(listType); if (!newListType) { - rewriter.cancelRootUpdate(op); + rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( op, "Unable to convert list type to value semantics."); } @@ -154,7 +154,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { // from the non value tensor of the original optional value. auto derefine = opOperand.get().getDefiningOp(); if (!derefine) { - rewriter.cancelRootUpdate(op); + rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( op, "unimplemented: optional of non vtensor type not from " "derefine"); @@ -180,9 +180,10 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { rewriter.create(op->getLoc(), result); result.replaceAllUsesExcept(nonValueTensor, nonValueTensor); } - rewriter.finalizeRootUpdate(op); + rewriter.finalizeOpModification(op); return success(); } + private: std::optional extraLibrary; }; @@ -290,9 +291,9 @@ class ReduceTrailingUnderscoreInplaceVariant : public RewritePattern { Operation *newOp = rewriter.create(state); // Note: need to convert result to first input's dtype because mix precision // compute would result in different behaviors. - // For example: - // a = torch.randn(3, 3).half() # float16 - // b = torch.randn(3, 3) # float32 + // For example: + // a = torch.randn(3, 3).half() # float16 + // b = torch.randn(3, 3) # float32 // a += b # i.e. torch.ops.aten.add_(a, b), result is float16 // c = a + b # i.e. torch.ops.aten.add(a, b), result is float32 Value none = rewriter.create(op->getLoc()); @@ -300,7 +301,8 @@ class ReduceTrailingUnderscoreInplaceVariant : public RewritePattern { auto aDtype = rewriter.create(op->getLoc(), op->getOperand(0)); auto toDtype = rewriter.create( op->getLoc(), newOp->getResult(0).getType(), newOp->getResult(0), - aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); + aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); auto tensor = rewriter.create(op->getLoc(), toDtype); createOverwriteTensorContents(rewriter, op->getLoc(), tensor, op->getOperand(0));