@@ -5854,13 +5854,12 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
5854
5854
/* end=*/ numNonzero,
5855
5855
/* step=*/ constantOne);
5856
5856
5857
- // Convert flattened indices back to multi-dimensional indices
5858
- // original_shape = t.shape
5859
- // input_shape_tensor = torch.tensor(original_shape)
5857
+ // TODO fix multidim dynamic support. The following code only work for
5858
+ // static multidim. Convert flattened indices back to multi-dimensional
5859
+ // indices original_shape = t.shape input_shape_tensor =
5860
+ // torch.tensor(original_shape)
5860
5861
auto shapeType = Torch::ValueTensorType::get (
5861
5862
rewriter.getContext (), SmallVector<int64_t >{inputRank}, intType);
5862
- // Value inputShapeTensor =
5863
- // rewriter.create<Torch::Aten_ShapeAsTensorOp>(loc, shapeType, input);
5864
5863
SmallVector<Value> shapeValues;
5865
5864
for (int i = 0 ; i < inputRank; i++) {
5866
5865
auto constantI =
@@ -5869,8 +5868,6 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
5869
5868
/* dim=*/ constantI);
5870
5869
shapeValues.push_back (shape);
5871
5870
}
5872
- // Value shape0 = rewriter.create<AtenSizeIntOp>(loc, input,
5873
- // /*dim=*/constantZero);
5874
5871
Value shapeTensorList = rewriter.create <Torch::PrimListConstructOp>(
5875
5872
loc, Torch::ListType::get (shapeValues[0 ].getType ()), shapeValues);
5876
5873
Value inputShapeTensor = rewriter.create <Torch::AtenTensorOp>(
@@ -5886,7 +5883,7 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
5886
5883
5887
5884
// strides = torch.cat([strides[1:-1], torch.tensor([1])])
5888
5885
auto oneTensorType = ValueTensorType::get (rewriter.getContext (),
5889
- SmallVector<int64_t >{}, intType);
5886
+ SmallVector<int64_t >{1 }, intType);
5890
5887
Value oneTensor = rewriter.create <AtenScalarTensorOp>(
5891
5888
loc, oneTensorType, constantOne, intTypeValue, noneCst, noneCst,
5892
5889
noneCst);
0 commit comments