Add support for multiple dynamic reassociation dims for unflatten.int (#3504)

Addresses an issue with onnx.Gather lowering to linalg:
<nod-ai/SHARK-ModelDev#242>

The builder for tensor.expand_shape, without an explicitly provided
output shape, fails to infer an output shape in the case of multiple
dynamic reassociation dims. I tried adding the output shape explicitly
for tensor.expand_shape, but ran into compilation issues later on (see
<iree-org/iree#17760>).

This PR adds support by lowering the op to tensor.reshape when more than
one of the reassociation dims is dynamic.
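
As a rough sketch of the difference (not taken from the patch; the SSA names %self, %d0, %s0, %s1 and the extents are invented for illustration, with %d0 an index value and %s0, %s1 i64 sizes), unflattening a tensor<?x3xf32> at dim 0 can stay on tensor.expand_shape when at most one of the new dims is dynamic, but needs an explicit shape tensor and tensor.reshape when two of them are:

// One dynamic reassociation dim: the output shape is still inferable.
%e = tensor.expand_shape %self [[0, 1], [2]] output_shape [%d0, 4, 3]
    : tensor<?x3xf32> into tensor<?x4x3xf32>

// Two dynamic reassociation dims: materialize the target shape and reshape.
%c3 = arith.constant 3 : i64
%shape = tensor.from_elements %s0, %s1, %c3 : tensor<3xi64>
%r = tensor.reshape %self(%shape) : (tensor<?x3xf32>, tensor<3xi64>) -> tensor<?x?x3xf32>

Because tensor.reshape takes the target shape as a runtime 1-D tensor, it sidesteps the static shape inference that the expand_shape builder cannot do here.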
zjgarvey authored Jun 28, 2024
1 parent a1c4089 commit af236da
Showing 3 changed files with 84 additions and 26 deletions.
72 changes: 57 additions & 15 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -661,7 +661,8 @@ class ConvertAtenUnflattenIntOp
"Expected input type having sizes");
}
int inputRank = inputTensorType.getSizes().size();
int outputRank = outputTensorType.getSizes().size();
auto outputSizes = outputTensorType.getSizes();
int outputRank = outputSizes.size();

int64_t dimInt;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt)))
@@ -675,23 +676,64 @@ class ConvertAtenUnflattenIntOp
auto sizesOp = op.getSizes().getDefiningOp<Torch::PrimListConstructOp>();
int numSizes = sizesOp.getNumOperands();

SmallVector<ReassociationIndices> reassociations(inputRank);
if (inputRank > 0) {
for (int i = 0; i < dimInt; ++i)
reassociations[i].push_back(i);

for (int i = 0; i < numSizes; ++i)
reassociations[dimInt].push_back(i + dimInt);

for (int i = dimInt + numSizes; i < outputRank; ++i)
reassociations[i - numSizes + 1].push_back(i);
int64_t numDynamicReassocDims = 0;
for (int64_t i = dimInt; i < dimInt + numSizes; i++) {
if (outputSizes[i] == Torch::kUnknownSize)
numDynamicReassocDims++;
}

SmallVector<Value> reassocSizes;
if (!getListConstructElements(op.getSizes(), reassocSizes) &&
numDynamicReassocDims > 1)
return rewriter.notifyMatchFailure(
op, "Must be able to either infer expansion dims, or retrieve them "
"from list construct");

auto expandTy = getTypeConverter()->convertType(outputTensorType);
auto expand = rewriter
.create<tensor::ExpandShapeOp>(
loc, expandTy, adaptor.getSelf(), reassociations)
.getResult();
Value expand;
// When there are fewer than two dynamic reassociation dims, this lowers
// to tensor.expand_shape. Otherwise, it lowers to tensor.reshape.
// TODO: in the numDynamicReassocDims >= 2 case, lower to expand_shape with
// explicitly provided outputShape once
// https://github.com/iree-org/iree/issues/17760 is resolved.
if (numDynamicReassocDims < 2) {
SmallVector<ReassociationIndices> reassociations(inputRank);
if (inputRank > 0) {
for (int i = 0; i < dimInt; ++i)
reassociations[i].push_back(i);
for (int i = 0; i < numSizes; ++i)
reassociations[dimInt].push_back(i + dimInt);
for (int i = dimInt + numSizes; i < outputRank; ++i)
reassociations[i - numSizes + 1].push_back(i);
}
expand = rewriter
.create<tensor::ExpandShapeOp>(
loc, expandTy, adaptor.getSelf(), reassociations)
.getResult();
} else {
reassocSizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
reassocSizes);
SmallVector<Value> inputShape =
getTensorSizes(rewriter, loc, adaptor.getSelf());
inputShape = castIndexVectorToInt64Vector(rewriter, loc, inputShape);
SmallVector<Value> outputShape(inputShape.begin(),
inputShape.begin() + dimInt);
if (inputRank > 0) {
for (int i = 0; i < numSizes; ++i)
outputShape.push_back(reassocSizes[i]);
for (int i = dimInt + numSizes; i < outputRank; ++i)
outputShape.push_back(inputShape[i - numSizes + 1]);
}

RankedTensorType shapeType = RankedTensorType::get(
ArrayRef<int64_t>{outputRank}, rewriter.getIntegerType(64));
Value shapeValue =
rewriter.create<tensor::FromElementsOp>(loc, shapeType, outputShape);
expand = rewriter
.create<tensor::ReshapeOp>(loc, expandTy, adaptor.getSelf(),
shapeValue)
.getResult();
}
rewriter.replaceOp(op, expand);
return success();
}
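
To make the reassociation bookkeeping in the expand_shape branch above concrete, here is a hypothetical case (not part of the patch; names and extents are invented): with inputRank = 3, dimInt = 1, numSizes = 2, and outputRank = 4, the three loops produce reassociations = [[0], [1, 2], [3]], i.e. roughly:

%e = tensor.expand_shape %x [[0], [1, 2], [3]] output_shape [5, %d, 4, 7]
    : tensor<5x?x7xf32> into tensor<5x?x4x7xf32>

Only one of the unflattened extents is dynamic in this case, so the existing expand_shape path still applies.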
11 changes: 0 additions & 11 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -2197,17 +2197,6 @@
ONNX_XFAIL_SET = {
# Failure - cast error
"PermuteNegativeIndexModule_basic",
# Failure - expand multiple dynamic dims
"EmbeddingModuleF16_basic",
"EmbeddingModuleI32_basic",
"EmbeddingModuleI64_basic",
"IndexTensorHackedTwinModule3dInput_basic",
"IndexTensorHackedTwinModule_basic",
"IndexTensorModule3dInput_basic",
"IndexTensorModule_basic",
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
"IndexTensorSelectDimModule_basic",
# Failure - incorrect numerics
"AvgPool2dDivisorOverrideModule_basic",
"BroadcastDynamicDimModule_basic",
27 changes: 27 additions & 0 deletions test/Conversion/TorchToLinalg/view.mlir
@@ -281,3 +281,30 @@ func.func @torch.aten.view$dynamicInferredSame(%arg0: !torch.vtensor<[10,?,2,3],f32>,
%1 = torch.aten.view %arg0, %0 : !torch.vtensor<[10,?,2,3],f32>, !torch.list<int> -> !torch.vtensor<[2,5,?,6],f32>
return %1 : !torch.vtensor<[2,5,?,6],f32>
}

// -----

// this is to check a path for unflatten.int with two dynamic reassociation dims
// the IR here is generated from the onnx.Gather conversion
// CHECK-LABEL: @gather_graph
// CHECK: %[[fromelt:.*]] = tensor.from_elements
// CHECK-SAME: tensor<3xi64>
// CHECK: %[[reshape:.*]] = tensor.reshape
// CHECK-SAME: (tensor<?x3xf32>, tensor<3xi64>) -> tensor<?x?x3xf32>
func.func @gather_graph(%arg0: !torch.vtensor<[5,3],f32>, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?,3],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
%int-1 = torch.constant.int -1
%int5 = torch.constant.int 5
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.aten.lt.Scalar %arg1, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.vtensor<[?,?],i1>
%1 = torch.aten.add.Scalar %arg1, %int5, %int1 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[?,?],si64>
%2 = torch.aten.where.self %0, %1, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
%3 = torch.aten.size.int %2, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%4 = torch.aten.size.int %2, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
%5 = torch.prim.ListConstruct %3, %4 : (!torch.int, !torch.int) -> !torch.list<int>
%6 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%7 = torch.aten.view %2, %6 : !torch.vtensor<[?,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
%8 = torch.aten.index_select %arg0, %int0, %7 : !torch.vtensor<[5,3],f32>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,3],f32>
%9 = torch.aten.unflatten.int %8, %int0, %5 : !torch.vtensor<[?,3],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,3],f32>
return %9 : !torch.vtensor<[?,?,3],f32>
}
