Skip to content

Commit e6282af

Browse files
committed
Delete some comments. Fix static multi-dim support.
1 parent 6293e1c commit e6282af

File tree

2 files changed

+5
-9
lines changed

2 files changed

+5
-9
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5854,13 +5854,12 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
58545854
/*end=*/numNonzero,
58555855
/*step=*/constantOne);
58565856

5857-
// Convert flattened indices back to multi-dimensional indices
5858-
// original_shape = t.shape
5859-
// input_shape_tensor = torch.tensor(original_shape)
5857+
// TODO fix multidim dynamic support. The following code only work for
5858+
// static multidim. Convert flattened indices back to multi-dimensional
5859+
// indices original_shape = t.shape input_shape_tensor =
5860+
// torch.tensor(original_shape)
58605861
auto shapeType = Torch::ValueTensorType::get(
58615862
rewriter.getContext(), SmallVector<int64_t>{inputRank}, intType);
5862-
// Value inputShapeTensor =
5863-
// rewriter.create<Torch::Aten_ShapeAsTensorOp>(loc, shapeType, input);
58645863
SmallVector<Value> shapeValues;
58655864
for (int i = 0; i < inputRank; i++) {
58665865
auto constantI =
@@ -5869,8 +5868,6 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
58695868
/*dim=*/constantI);
58705869
shapeValues.push_back(shape);
58715870
}
5872-
// Value shape0 = rewriter.create<AtenSizeIntOp>(loc, input,
5873-
// /*dim=*/constantZero);
58745871
Value shapeTensorList = rewriter.create<Torch::PrimListConstructOp>(
58755872
loc, Torch::ListType::get(shapeValues[0].getType()), shapeValues);
58765873
Value inputShapeTensor = rewriter.create<Torch::AtenTensorOp>(
@@ -5886,7 +5883,7 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
58865883

58875884
// strides = torch.cat([strides[1:-1], torch.tensor([1])])
58885885
auto oneTensorType = ValueTensorType::get(rewriter.getContext(),
5889-
SmallVector<int64_t>{}, intType);
5886+
SmallVector<int64_t>{1}, intType);
58905887
Value oneTensor = rewriter.create<AtenScalarTensorOp>(
58915888
loc, oneTensorType, constantOne, intTypeValue, noneCst, noneCst,
58925889
noneCst);

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3020,7 +3020,6 @@
30203020
"LinalgNormKeepDimComplexModule_basic",
30213021
"LinalgVectorNormComplexModule_basic",
30223022
"LogSoftmaxBackwardModule_basic",
3023-
"MaskedScatterStaticBasic_basic",
30243023
"MaxPool1dCeilModeTrueModule_basic",
30253024
"MaxPool1dModule_basic",
30263025
"MaxPool2dCeilModeTrueModule_basic",

0 commit comments

Comments
 (0)