@@ -21,12 +21,18 @@ using namespace mlir;
2121using namespace mlir ::torch;
2222using namespace mlir ::torch::Torch;
2323
24+ static Value assertNonValueTensor (Value tensor) {
25+ assert (tensor.getType ().isa <NonValueTensorType>() &&
26+ " tensor is expected to be a non-value tensor" );
27+ return tensor;
28+ }
29+
2430static bool isViewLikeOp (Operation *op) {
2531 // AtenContiguousOp might return a view, so this is conservatively
2632 // correct. We could potentially be more precise and identify the cases
2733 // that it does not return a view and treat those as having value
2834 // semantics.
29- return isa<AtenBroadcastToOp, AtenContiguousOp, AtenExpandOp,
35+ return isa<AtenBroadcastToOp, AtenContiguousOp, AtenExpandAsOp, AtenExpandOp,
3036 AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp,
3137 AtenSelectIntOp, AtenSliceTensorOp, AtenSqueezeDimOp,
3238 AtenSqueezeOp, AtenTOp, AtenToDtypeOp, AtenTransposeIntOp,
@@ -39,6 +45,8 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
3945public:
4046 using OpRewritePattern::OpRewritePattern;
4147
48+ // Used to represent all of the interpreted ops that have at least
49+ // one non-value tensor as input or output.
4250 struct InterpretedOps {
4351 SmallVector<Operation *> copyLikeOps;
4452 SmallVector<Operation *> viewLikeOps;
@@ -50,69 +58,52 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
5058 // interpretation within a single basic block. If rewriting is
5159 // possible, the interpreted ops are returned split into their
5260 // respective categories.
53- static FailureOr<InterpretedOps>
54- abstractlyInterpretSlice ( CopyToNonValueTensorOp copyToNonValueTensor,
55- SmallVector <Operation *> nonValueTensorUsers ,
56- PatternRewriter &rewriter) {
61+ static FailureOr<InterpretedOps> abstractlyInterpretSlice (
62+ CopyToNonValueTensorOp copyToNonValueTensor,
63+ const DenseMap <Operation *, SmallVector<Value>> &nonValueTensorsUsedByOp ,
64+ PatternRewriter &rewriter) {
5765 // Sort by order in the block, so we can abstractly interpret the ops.
66+ SmallVector<Operation *> nonValueTensorUsers (
67+ llvm::make_first_range (nonValueTensorsUsedByOp));
5868 llvm::sort (nonValueTensorUsers, [](Operation *lhs, Operation *rhs) {
5969 return lhs->isBeforeInBlock (rhs);
6070 });
6171
6272 // We track the available aliases at each point as well as split the
6373 // users into view-like, copy-to-value, and overwrite ops as we walk
6474 // forward.
65- //
66- // We also need to track all seen aliases to make sure that we only rewrite
67- // those operands of a ReturnOp, if present (a ReturnOp can return tensors
68- // from multiple different slices).
6975 InterpretedOps result;
7076 result.copyLikeOps .push_back (copyToNonValueTensor);
71- DenseSet<Value> availableAliases{copyToNonValueTensor. result ()};
72- DenseSet<Value> seenAliases{ copyToNonValueTensor.result ()};
77+ DenseSet<Value> availableAliases{
78+ assertNonValueTensor ( copyToNonValueTensor.result () )};
7379 for (Operation *user : nonValueTensorUsers) {
74- if (isViewLikeOp (user)) {
75- Value operand = user->getOperand (0 );
80+ for (Value operand : nonValueTensorsUsedByOp.lookup (user)) {
7681 if (!availableAliases.contains (operand)) {
7782 return rewriter.notifyMatchFailure (
7883 copyToNonValueTensor,
79- " operand of view-like op is not a valid tensor alias" );
84+ " operand of op is not a valid tensor alias" );
8085 }
81-
86+ }
87+ if (isViewLikeOp (user)) {
88+ Value userResult = user->getResult (0 );
8289 // View-like ops produce a new alias available to later ops.
83- availableAliases.insert (user->getResult (0 ));
84- seenAliases.insert (user->getResult (0 ));
90+ // However, if the view-like op has been partially converted
91+ // to use value semantics (which happens for example with ops
92+ // that take two aliases as input), then it is possible that the
93+ // op no longer generates an alias.
94+ if (userResult.getType ().isa <NonValueTensorType>())
95+ availableAliases.insert (userResult);
8596 result.viewLikeOps .push_back (user);
8697 } else if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
87- if (!availableAliases.contains (copyToValueTensor.operand ())) {
88- return rewriter.notifyMatchFailure (
89- copyToNonValueTensor,
90- " operand of copyToValueTensorOp is not a valid tensor alias" );
91- }
9298 result.copyLikeOps .push_back (copyToValueTensor);
9399 } else if (auto overwrite = dyn_cast<OverwriteTensorContentsOp>(user)) {
94- Value overwritten = overwrite.overwritten ();
95- if (!availableAliases.contains (overwritten)) {
96- return rewriter.notifyMatchFailure (
97- copyToNonValueTensor, " overwritten tensor is not a valid alias" );
98- }
99-
100100 // To simplify the analysis, we only support the case where the
101101 // only aliases used after an overwrite are the aliases generated
102102 // after plus the alias being overwritten.
103103 availableAliases.clear ();
104- availableAliases.insert (overwritten);
104+ availableAliases.insert (assertNonValueTensor (overwrite. overwritten ()) );
105105 result.overwriteTensorContentsOps .push_back (overwrite);
106106 } else if (auto returnOp = dyn_cast<mlir::ReturnOp>(user)) {
107- for (Value operand : returnOp->getOperands ()) {
108- if (!seenAliases.contains (operand))
109- continue ;
110- if (!availableAliases.contains (operand)) {
111- return rewriter.notifyMatchFailure (
112- copyToNonValueTensor,
113- " operand of ReturnOp is not a valid tensor alias" );
114- }
115- }
116107 result.returnOp = returnOp;
117108 } else {
118109 return rewriter.notifyMatchFailure (
@@ -146,10 +137,7 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
146137 // overwritten tensor.
147138 for (OverwriteTensorContentsOp overwrite :
148139 llvm::reverse (ops.overwriteTensorContentsOps )) {
149- Value overwritten = overwrite.overwritten ();
150- assert (overwritten.getType ().dyn_cast <NonValueTensorType>() &&
151- " the analysis assumes that overwritten remains a nonValueTensor "
152- " throughout the rewriting" );
140+ Value overwritten = assertNonValueTensor (overwrite.overwritten ());
153141 overwritten.replaceUsesWithIf (
154142 overwrite.value (), [&](const OpOperand &operand) {
155143 return !operand.getOwner ()->isBeforeInBlock (overwrite);
@@ -165,9 +153,8 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
165153 rewriter.updateRootInPlace (viewLikeOp, [&] {
166154 Value result = viewLikeOp->getResult (0 );
167155 auto resultType = result.getType ().dyn_cast <NonValueTensorType>();
168- assert (resultType && " all view-like ops considered must have result of "
169- " type `NonValueTensorType` before rewriting" );
170- result.setType (resultType.getWithValueSemantics ());
156+ if (resultType)
157+ result.setType (resultType.getWithValueSemantics ());
171158 });
172159 }
173160 if (ops.returnOp .hasValue ()) {
@@ -192,47 +179,76 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
192179 // terminating at CopyToValueTensorOp's, possibly with intervening view-like
193180 // ops and overwrites. This also catches the special case of a
194181 // CopyToNonValueTensorOp that trivially feeds into CopyToValueTensorOp's.
195- SmallVector<Operation *> nonValueTensorUsers;
196- auto workList = llvm::to_vector (copy.result ().getUsers ());
182+ DenseMap<Operation *, SmallVector<Value>> nonValueTensorsUsedByOp;
183+
184+ // Some view-like ops take more than one non-value tensor as input (such as
185+ // `aten.view_as`). For these ops, we assume that the tensor view that gets
186+ // returned by the op is a view of the first operand of the op.
187+
188+ // View-like ops that return a non-value tensor and have a view of the
189+ // operand of `copy.to_tensor` as the first operand.
190+ DenseSet<Operation *> validViewLikeOps;
191+ // View-like ops that return a non-value tensor and have a view of the
192+ // operand of `copy.to_tensor` as an operand other than the first operand.
193+ DenseSet<Operation *> viewLikeOpsToCheck;
194+
195+ using OpOperandRefs = SmallVector<std::reference_wrapper<OpOperand>>;
196+ OpOperandRefs workList (copy.result ().getUses ());
197197 while (!workList.empty ()) {
198- Operation *op = workList.pop_back_val ();
198+ OpOperand &operand = workList.pop_back_val ();
199+ Operation *op = operand.getOwner ();
199200 if (op->getBlock () != copy->getBlock ()) {
200201 return rewriter.notifyMatchFailure (
201202 copy, " can only analyze within a single basic block" );
202203 }
203- nonValueTensorUsers.push_back (op);
204204
205205 if (isViewLikeOp (op)) {
206- auto isTensor = [](const Value operand) {
207- return operand.getType ().isa <BaseTensorType>();
208- };
209-
210- // We currently only support view-like ops with one tensor input and one
211- // tensor output, meaning that the tensor use-def chains form a tree.
212- // This will not be the case for an op like `torch.aten.view_as`, so
213- // we will need to add a set to prune duplicate visitation.
214- if (llvm::count_if (op->getOperands (), isTensor) != 1 ||
215- llvm::count_if (op->getResults (), isTensor) != 1 ||
216- !isTensor (op->getOperand (0 )) || !isTensor (op->getResult (0 ))) {
206+ // We currently only support view-like ops with one tensor output.
207+ if (op->getNumResults () != 1 ||
208+ !op->getResult (0 ).getType ().isa <BaseTensorType>()) {
217209 return rewriter.notifyMatchFailure (
218- copy, " unsupported: view-like ops must have one tensor input and "
219- " one tensor output, and the tensor input/output must be "
220- " the first operand/result" );
210+ copy, " unsupported: view-like ops must have one tensor output, "
211+ " and the tensor output must be the first result" );
221212 }
222213
223- llvm::append_range (workList, op->getResult (0 ).getUsers ());
214+ Value opResult = op->getResult (0 );
215+ // There are cases where a view-like op will be partially converted to
216+ // value semantics, resulting in at least one of the inputs being a
217+ // non-value tensor and the output being a value tensor. If this is the
218+ // case then there is no need to look at the users of the result of the
219+ // op.
220+ if (opResult.getType ().isa <NonValueTensorType>()) {
221+ if (operand.getOperandNumber () == 0 ) {
222+ validViewLikeOps.insert (op);
223+ llvm::append_range (workList, opResult.getUses ());
224+ } else {
225+ viewLikeOpsToCheck.insert (op);
226+ }
227+ }
224228 }
229+
230+ nonValueTensorsUsedByOp[op].push_back (
231+ assertNonValueTensor (operand.get ()));
225232 }
226233
227234 // Nothing to do if there is just a ReturnOp -- we know that we won't be
228235 // rewriting anything, since we must preserve the ReturnOp's original type.
229- if (llvm::hasSingleElement (nonValueTensorUsers ) &&
230- isa<mlir::ReturnOp>(nonValueTensorUsers[ 0 ] )) {
236+ if (llvm::hasSingleElement (nonValueTensorsUsedByOp ) &&
237+ isa<mlir::ReturnOp>(nonValueTensorsUsedByOp. begin ()-> first )) {
231238 return failure ();
232239 }
233240
234- FailureOr<InterpretedOps> interpretedOps = abstractlyInterpretSlice (
235- copy, std::move (nonValueTensorUsers), rewriter);
241+ if (llvm::any_of (viewLikeOpsToCheck, [&](Operation *op) {
242+ return !validViewLikeOps.contains (op);
243+ })) {
244+ return rewriter.notifyMatchFailure (
245+ copy, " if a view-like op returns a non-value tensor, the first "
246+ " operand must be a view of the operand of the `copy.to_tensor` "
247+ " op" );
248+ }
249+
250+ FailureOr<InterpretedOps> interpretedOps =
251+ abstractlyInterpretSlice (copy, nonValueTensorsUsedByOp, rewriter);
236252 if (failed (LogicalResult (interpretedOps)))
237253 return failure ();
238254 rewriteSlice (*interpretedOps, rewriter);
0 commit comments