Skip to content

Commit 0bcc6d1

Browse files
authored
Add maximize-value-semantics support for multiple non-value tensor inputs (#659)
This commit adds maximize-value-semantics support for ops such as `aten.view_as` and `aten.expand_as` that take two non-value tensors as operands.
1 parent 92da498 commit 0bcc6d1

File tree

4 files changed

+123
-68
lines changed

4 files changed

+123
-68
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2859,6 +2859,20 @@ def Torch_AtenExpandOp : Torch_Op<"aten.expand", [
28592859
let assemblyFormat = "$self `,` $size `,` $implicit attr-dict `:` qualified(type($self)) `,` qualified(type($size)) `,` qualified(type($implicit)) `->` qualified(type($result))";
28602860
}
28612861

2862+
def Torch_AtenExpandAsOp : Torch_Op<"aten.expand_as", [
2863+
AllowsTypeRefinement
2864+
]> {
2865+
let summary = "Generated op for `aten::expand_as : (Tensor, Tensor) -> (Tensor)`";
2866+
let arguments = (ins
2867+
AnyTorchTensorType:$self,
2868+
AnyTorchTensorType:$other
2869+
);
2870+
let results = (outs
2871+
AnyTorchTensorType:$result
2872+
);
2873+
let assemblyFormat = "$self `,` $other attr-dict `:` qualified(type($self)) `,` qualified(type($other)) `->` qualified(type($result))";
2874+
}
2875+
28622876
def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [
28632877
AllowsTypeRefinement,
28642878
ReadOnly

lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp

Lines changed: 84 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,18 @@ using namespace mlir;
2121
using namespace mlir::torch;
2222
using namespace mlir::torch::Torch;
2323

24+
static Value assertNonValueTensor(Value tensor) {
25+
assert(tensor.getType().isa<NonValueTensorType>() &&
26+
"tensor is expected to be a non-value tensor");
27+
return tensor;
28+
}
29+
2430
static bool isViewLikeOp(Operation *op) {
2531
// AtenContiguousOp might return a view, so this is conservatively
2632
// correct. We could potentially be more precise and identify the cases
2733
// that it does not return a view and treat those as having value
2834
// semantics.
29-
return isa<AtenBroadcastToOp, AtenContiguousOp, AtenExpandOp,
35+
return isa<AtenBroadcastToOp, AtenContiguousOp, AtenExpandAsOp, AtenExpandOp,
3036
AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp,
3137
AtenSelectIntOp, AtenSliceTensorOp, AtenSqueezeDimOp,
3238
AtenSqueezeOp, AtenTOp, AtenToDtypeOp, AtenTransposeIntOp,
@@ -39,6 +45,8 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
3945
public:
4046
using OpRewritePattern::OpRewritePattern;
4147

48+
// Used to represent all of the interpreted ops that have at least
49+
// one non-value tensor as input or output.
4250
struct InterpretedOps {
4351
SmallVector<Operation *> copyLikeOps;
4452
SmallVector<Operation *> viewLikeOps;
@@ -50,69 +58,52 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
5058
// interpretation within a single basic block. If rewriting is
5159
// possible, the interpreted ops are returned split into their
5260
// respective categories.
53-
static FailureOr<InterpretedOps>
54-
abstractlyInterpretSlice(CopyToNonValueTensorOp copyToNonValueTensor,
55-
SmallVector<Operation *> nonValueTensorUsers,
56-
PatternRewriter &rewriter) {
61+
static FailureOr<InterpretedOps> abstractlyInterpretSlice(
62+
CopyToNonValueTensorOp copyToNonValueTensor,
63+
const DenseMap<Operation *, SmallVector<Value>> &nonValueTensorsUsedByOp,
64+
PatternRewriter &rewriter) {
5765
// Sort by order in the block, so we can abstractly interpret the ops.
66+
SmallVector<Operation *> nonValueTensorUsers(
67+
llvm::make_first_range(nonValueTensorsUsedByOp));
5868
llvm::sort(nonValueTensorUsers, [](Operation *lhs, Operation *rhs) {
5969
return lhs->isBeforeInBlock(rhs);
6070
});
6171

6272
// We track the available aliases at each point as well as split the
6373
// users into view-like, copy-to-value, and overwrite ops as we walk
6474
// forward.
65-
//
66-
// We also need to track all seen aliases to make sure that we only rewrite
67-
// those operands of a ReturnOp, if present (a ReturnOp can return tensors
68-
// from multiple different slices).
6975
InterpretedOps result;
7076
result.copyLikeOps.push_back(copyToNonValueTensor);
71-
DenseSet<Value> availableAliases{copyToNonValueTensor.result()};
72-
DenseSet<Value> seenAliases{copyToNonValueTensor.result()};
77+
DenseSet<Value> availableAliases{
78+
assertNonValueTensor(copyToNonValueTensor.result())};
7379
for (Operation *user : nonValueTensorUsers) {
74-
if (isViewLikeOp(user)) {
75-
Value operand = user->getOperand(0);
80+
for (Value operand : nonValueTensorsUsedByOp.lookup(user)) {
7681
if (!availableAliases.contains(operand)) {
7782
return rewriter.notifyMatchFailure(
7883
copyToNonValueTensor,
79-
"operand of view-like op is not a valid tensor alias");
84+
"operand of op is not a valid tensor alias");
8085
}
81-
86+
}
87+
if (isViewLikeOp(user)) {
88+
Value userResult = user->getResult(0);
8289
// View-like ops produce a new alias available to later ops.
83-
availableAliases.insert(user->getResult(0));
84-
seenAliases.insert(user->getResult(0));
90+
// However, if the view-like op has been partially converted
91+
// to use value semantics (which happens for example with ops
92+
// that take two aliases as input), then it is possible that the
93+
// op no longer generates an alias.
94+
if (userResult.getType().isa<NonValueTensorType>())
95+
availableAliases.insert(userResult);
8596
result.viewLikeOps.push_back(user);
8697
} else if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
87-
if (!availableAliases.contains(copyToValueTensor.operand())) {
88-
return rewriter.notifyMatchFailure(
89-
copyToNonValueTensor,
90-
"operand of copyToValueTensorOp is not a valid tensor alias");
91-
}
9298
result.copyLikeOps.push_back(copyToValueTensor);
9399
} else if (auto overwrite = dyn_cast<OverwriteTensorContentsOp>(user)) {
94-
Value overwritten = overwrite.overwritten();
95-
if (!availableAliases.contains(overwritten)) {
96-
return rewriter.notifyMatchFailure(
97-
copyToNonValueTensor, "overwritten tensor is not a valid alias");
98-
}
99-
100100
// To simplify the analysis, we only support the case where the
101101
// only aliases used after an overwrite are the aliases generated
102102
// after plus the alias being overwritten.
103103
availableAliases.clear();
104-
availableAliases.insert(overwritten);
104+
availableAliases.insert(assertNonValueTensor(overwrite.overwritten()));
105105
result.overwriteTensorContentsOps.push_back(overwrite);
106106
} else if (auto returnOp = dyn_cast<mlir::ReturnOp>(user)) {
107-
for (Value operand : returnOp->getOperands()) {
108-
if (!seenAliases.contains(operand))
109-
continue;
110-
if (!availableAliases.contains(operand)) {
111-
return rewriter.notifyMatchFailure(
112-
copyToNonValueTensor,
113-
"operand of ReturnOp is not a valid tensor alias");
114-
}
115-
}
116107
result.returnOp = returnOp;
117108
} else {
118109
return rewriter.notifyMatchFailure(
@@ -146,10 +137,7 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
146137
// overwritten tensor.
147138
for (OverwriteTensorContentsOp overwrite :
148139
llvm::reverse(ops.overwriteTensorContentsOps)) {
149-
Value overwritten = overwrite.overwritten();
150-
assert(overwritten.getType().dyn_cast<NonValueTensorType>() &&
151-
"the analysis assumes that overwritten remains a nonValueTensor "
152-
"throughout the rewriting");
140+
Value overwritten = assertNonValueTensor(overwrite.overwritten());
153141
overwritten.replaceUsesWithIf(
154142
overwrite.value(), [&](const OpOperand &operand) {
155143
return !operand.getOwner()->isBeforeInBlock(overwrite);
@@ -165,9 +153,8 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
165153
rewriter.updateRootInPlace(viewLikeOp, [&] {
166154
Value result = viewLikeOp->getResult(0);
167155
auto resultType = result.getType().dyn_cast<NonValueTensorType>();
168-
assert(resultType && "all view-like ops considered must have result of "
169-
"type `NonValueTensorType` before rewriting");
170-
result.setType(resultType.getWithValueSemantics());
156+
if (resultType)
157+
result.setType(resultType.getWithValueSemantics());
171158
});
172159
}
173160
if (ops.returnOp.hasValue()) {
@@ -192,47 +179,76 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
192179
// terminating at CopyToValueTensorOp's, possibly with intervening view-like
193180
// ops and overwrites. This also catches the special case of a
194181
// CopyToNonValueTensorOp that trivially feeds into CopyToValueTensorOp's.
195-
SmallVector<Operation *> nonValueTensorUsers;
196-
auto workList = llvm::to_vector(copy.result().getUsers());
182+
DenseMap<Operation *, SmallVector<Value>> nonValueTensorsUsedByOp;
183+
184+
// Some view-like ops take more than one non-value tensor as input (such as
185+
// `aten.view_as`). For these ops, we assume that the tensor view that gets
186+
// returned by the op is a view of the first operand of the op.
187+
188+
// View-like ops that return a non-value tensor and have a view of the
189+
// operand of `copy.to_tensor` as the first operand.
190+
DenseSet<Operation *> validViewLikeOps;
191+
// View-like ops that return a non-value tensor and have a view of the
192+
// operand of `copy.to_tensor` as an operand other than the first operand.
193+
DenseSet<Operation *> viewLikeOpsToCheck;
194+
195+
using OpOperandRefs = SmallVector<std::reference_wrapper<OpOperand>>;
196+
OpOperandRefs workList(copy.result().getUses());
197197
while (!workList.empty()) {
198-
Operation *op = workList.pop_back_val();
198+
OpOperand &operand = workList.pop_back_val();
199+
Operation *op = operand.getOwner();
199200
if (op->getBlock() != copy->getBlock()) {
200201
return rewriter.notifyMatchFailure(
201202
copy, "can only analyze within a single basic block");
202203
}
203-
nonValueTensorUsers.push_back(op);
204204

205205
if (isViewLikeOp(op)) {
206-
auto isTensor = [](const Value operand) {
207-
return operand.getType().isa<BaseTensorType>();
208-
};
209-
210-
// We currently only support view-like ops with one tensor input and one
211-
// tensor output, meaning that the tensor use-def chains form a tree.
212-
// This will not be the case for an op like `torch.aten.view_as`, so
213-
// we will need to add a set to prune duplicate visitation.
214-
if (llvm::count_if(op->getOperands(), isTensor) != 1 ||
215-
llvm::count_if(op->getResults(), isTensor) != 1 ||
216-
!isTensor(op->getOperand(0)) || !isTensor(op->getResult(0))) {
206+
// We currently only support view-like ops with one tensor output.
207+
if (op->getNumResults() != 1 ||
208+
!op->getResult(0).getType().isa<BaseTensorType>()) {
217209
return rewriter.notifyMatchFailure(
218-
copy, "unsupported: view-like ops must have one tensor input and "
219-
"one tensor output, and the tensor input/output must be "
220-
"the first operand/result");
210+
copy, "unsupported: view-like ops must have one tensor output, "
211+
"and the tensor output must be the first result");
221212
}
222213

223-
llvm::append_range(workList, op->getResult(0).getUsers());
214+
Value opResult = op->getResult(0);
215+
// There are cases where a view-like op will be partially converted to
216+
// value semantics, resulting in at least one of the inputs being a
217+
// non-value tensor and the output being a value tensor. If this is the
218+
// case then there is no need to look at the users of the result of the
219+
// op.
220+
if (opResult.getType().isa<NonValueTensorType>()) {
221+
if (operand.getOperandNumber() == 0) {
222+
validViewLikeOps.insert(op);
223+
llvm::append_range(workList, opResult.getUses());
224+
} else {
225+
viewLikeOpsToCheck.insert(op);
226+
}
227+
}
224228
}
229+
230+
nonValueTensorsUsedByOp[op].push_back(
231+
assertNonValueTensor(operand.get()));
225232
}
226233

227234
// Nothing to do if there is just a ReturnOp -- we know that we won't be
228235
// rewriting anything, since we must preserve the ReturnOp's original type.
229-
if (llvm::hasSingleElement(nonValueTensorUsers) &&
230-
isa<mlir::ReturnOp>(nonValueTensorUsers[0])) {
236+
if (llvm::hasSingleElement(nonValueTensorsUsedByOp) &&
237+
isa<mlir::ReturnOp>(nonValueTensorsUsedByOp.begin()->first)) {
231238
return failure();
232239
}
233240

234-
FailureOr<InterpretedOps> interpretedOps = abstractlyInterpretSlice(
235-
copy, std::move(nonValueTensorUsers), rewriter);
241+
if (llvm::any_of(viewLikeOpsToCheck, [&](Operation *op) {
242+
return !validViewLikeOps.contains(op);
243+
})) {
244+
return rewriter.notifyMatchFailure(
245+
copy, "if a view-like op returns a non-value tensor, the first "
246+
"operand must be a view of the operand of the `copy.to_tensor` "
247+
"op");
248+
}
249+
250+
FailureOr<InterpretedOps> interpretedOps =
251+
abstractlyInterpretSlice(copy, nonValueTensorsUsedByOp, rewriter);
236252
if (failed(LogicalResult(interpretedOps)))
237253
return failure();
238254
rewriteSlice(*interpretedOps, rewriter);

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,7 @@ def emit_with_mutating_variants(key, **kwargs):
397397
emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
398398
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
399399
emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
400+
emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
400401
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
401402
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
402403
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")

test/Dialect/Torch/maximize-value-semantics.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,3 +220,27 @@ func @viewlike$unmodeled_op(%arg0: !torch.vtensor) -> !torch.vtensor {
220220
%2 = torch.copy.to_vtensor %1 : !torch.vtensor
221221
return %2 : !torch.vtensor
222222
}
223+
224+
// CHECK-LABEL: func @viewlike$two_inputs_one_copy(
225+
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
226+
// CHECK: %[[EXPAND_AS:.*]] = torch.aten.expand_as %[[ARG]], %[[ARG]] : !torch.vtensor, !torch.vtensor -> !torch.vtensor
227+
// CHECK: return %[[EXPAND_AS]] : !torch.vtensor
228+
func @viewlike$two_inputs_one_copy(%arg0: !torch.vtensor) -> !torch.vtensor {
229+
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
230+
%1 = torch.aten.expand_as %0, %0 : !torch.tensor, !torch.tensor -> !torch.tensor
231+
%2 = torch.copy.to_vtensor %1 : !torch.vtensor
232+
return %2 : !torch.vtensor
233+
}
234+
235+
// CHECK-LABEL: func @viewlike$two_inputs_two_copies(
236+
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
237+
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
238+
// CHECK: %[[EXPAND_AS:.*]] = torch.aten.expand_as %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.vtensor -> !torch.vtensor
239+
// CHECK: return %[[EXPAND_AS]] : !torch.vtensor
240+
func @viewlike$two_inputs_two_copies(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
241+
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
242+
%1 = torch.copy.to_tensor %arg1 : !torch.tensor
243+
%2 = torch.aten.expand_as %0, %1 : !torch.tensor, !torch.tensor -> !torch.tensor
244+
%3 = torch.copy.to_vtensor %2 : !torch.vtensor
245+
return %3 : !torch.vtensor
246+
}

0 commit comments

Comments
 (0)