Skip to content

Commit

Permalink
Fix onnx.Gather lowering with dynamic shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchen62 committed Aug 29, 2024
1 parent 98e0802 commit 38c3322
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
13 changes: 10 additions & 3 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1941,7 +1941,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
indicesCt = Torch::kUnknownSize;
break;
}

indicesCt *= sz;
}

Expand Down Expand Up @@ -1976,8 +1975,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
return success();
}

rewriter.replaceOpWithNewOp<Torch::AtenSqueezeOp>(binder.op, resultType,
gather);
// 0 indicesRank will always squeeze the axis dim
// Use PrimsSqueezeOp for the case of result with dynamic shape
SmallVector<Value> dimList({index});
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
dimList);
rewriter.replaceOpWithNewOp<Torch::PrimsSqueezeOp>(
binder.op, resultType, gather, dimValueList);
return success();
});
patterns.onOp(
Expand Down
3 changes: 2 additions & 1 deletion test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ func.func @test_gather_scalar(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.
// CHECK: %[[SEL:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1
// CHECK: %[[FLAT:.+]] = torch.aten.unsqueeze %[[SEL]], %[[ZERO]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ISEL:.+]] = torch.aten.index_select %arg0, %[[AXIS]], %[[FLAT]]
// CHECK: %[[RES:.+]] = torch.aten.squeeze %[[ISEL]] : !torch.vtensor<[1,4,5],f32> -> !torch.vtensor<[4,5],f32>
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[AXIS]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[RES:.+]] = torch.prims.squeeze %[[ISEL]], %[[DIMS]] : !torch.vtensor<[1,4,5],f32>, !torch.list<int> -> !torch.vtensor<[4,5],f32>
// CHECK: return %[[RES]]
%0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[], si64>) -> !torch.vtensor<[4,5],f32>
return %0 : !torch.vtensor<[4,5],f32>
Expand Down

0 comments on commit 38c3322

Please sign in to comment.