[MLIR][TORCH] Add OnnxToTorch support for BlackmanWindow function #3181

Merged: 7 commits, Apr 30, 2024

Changes from 1 commit
Address Comments
vinayakdsci committed Apr 26, 2024
commit 7b22a0c0a5c2c206a8aed4a3a5d7be5712eaa790
6 changes: 3 additions & 3 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -2223,7 +2223,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             none, none);

         // Required constants
-        const double piDouble = 2.0 * llvm::numbers::pi;
+        constexpr double pi = llvm::numbers::pi;
         Value alpha = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(),
             rewriter.getFloatAttr(rewriter.getF64Type(), 0.42));
@@ -2235,10 +2235,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
         Value twicePi = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(),
-            rewriter.getFloatAttr(rewriter.getF64Type(), piDouble));
+            rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));
         Value fourPi = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(),
-            rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * piDouble));
+            rewriter.getFloatAttr(rewriter.getF64Type(), 4.0 * pi));

         // Calculate the window function
         Value productTimesTwoPi = rewriter.create<Torch::AtenMulScalarOp>(
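For reference, the constants above implement the standard Blackman window with the default ONNX coefficients. Writing N for the effective window length (the requested size in the periodic case, size - 1 in the symmetric case), the lowering computes

$$ w[n] = 0.42 - 0.5\cos\!\left(\frac{2\pi n}{N}\right) + 0.08\cos\!\left(\frac{4\pi n}{N}\right), \qquad n = 0, \dots, \text{size} - 1, $$

so alpha (0.42) is the additive coefficient, negHalf (-0.5) scales the 2*pi cosine term, and beta (0.08) scales the 4*pi cosine term.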
100 changes: 50 additions & 50 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -2001,32 +2001,32 @@ func.func @test_eyelike_dynamic(%arg0: !torch.vtensor<[3,?],f32>) -> !torch.vten

 // CHECK-LABEL: func.func @test_blackmanwindow_symmetric
 func.func @test_blackmanwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-// CHECK: %[[INT1_0:.+]] = torch.constant.int 1
-// CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %arg0, %[[INT1_0]], %[[INT1_0]] : !torch.vtensor<[],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
-// CHECK: %[[INT1_0:.+]] = torch.constant.int 1
-// CHECK: %[[INT0:.+]] = torch.constant.int 0
-// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SUB]] : !torch.vtensor<[],si32> -> !torch.int
-// CHECK: %[[NONE_0:.+]] = torch.constant.none
-// CHECK: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[INT0]], %[[ITEM]], %[[INT1_0]], %[[NONE_0]], %[[NONE_0]], %[[NONE_0]], %[[NONE_0]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
-// CHECK: %[[ALPHA:.+]] = torch.constant.float 4.200000e-01
-// CHECK: %[[BETA:.+]] = torch.constant.float 8.000000e-02
-// CHECK: %[[NEGHALF:.+]] = torch.constant.float -5.000000e-01
-// CHECK: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862
-// CHECK: %[[FOURPI:.+]] = torch.constant.float 12.566370614359172
-// CHECK: %[[MULTWOPI:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
-// CHECK: %[[MULFOURPI:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[FOURPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
-// CHECK: %[[DIVFOURPI:.+]] = torch.aten.div.Tensor %[[MULFOURPI]], %[[SUB]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],si32> -> !torch.vtensor<[10],f32>
-// CHECK: %[[DIVTWOPI:.+]] = torch.aten.div.Tensor %[[MULTWOPI]], %[[SUB]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],si32> -> !torch.vtensor<[10],f32>
-// CHECK: %[[COS_0:.+]] = torch.aten.cos %[[DIVTWOPI]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
-// CHECK: %[[MULCOS_0:.+]] = torch.aten.mul.Scalar %[[COS_0]], %[[NEGHALF]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
-// CHECK: %[[COS_1:.+]] = torch.aten.cos %[[DIVFOURPI]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
-// CHECK: %[[MULCOS_1:.+]] = torch.aten.mul.Scalar %[[COS_1]], %[[BETA]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
-// CHECK: %[[INT1_1:.+]] = torch.constant.int 1
-// CHECK: %[[ADDCOS:.+]] = torch.aten.add.Tensor %[[MULCOS_0]], %[[MULCOS_1]], %[[INT1_1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[10],f32>
-// CHECK: %[[ADDALPHA:.+]] = torch.aten.add.Scalar %[[ADDCOS]], %[[ALPHA]], %[[INT1_1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.int -> !torch.vtensor<[10],f32>
-// CHECK: %[[INT6:.+]] = torch.constant.int 6
-// CHECK: %[[NONE_1:.+]] = torch.constant.none
-// CHECK: %[[FALSE:.+]] = torch.constant.bool false
+// CHECK-DAG: %[[INT1_0:.+]] = torch.constant.int 1
+// CHECK-DAG: %[[SUB:.+]] = torch.aten.sub.Scalar %arg0, %[[INT1_0]], %[[INT1_0]] : !torch.vtensor<[],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
+// CHECK-DAG: %[[INT1_0:.+]] = torch.constant.int 1
+// CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
+// CHECK-DAG: %[[ITEM:.+]] = torch.aten.item %[[SUB]] : !torch.vtensor<[],si32> -> !torch.int
+// CHECK-DAG: %[[NONE_0:.+]] = torch.constant.none
+// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[INT0]], %[[ITEM]], %[[INT1_0]], %[[NONE_0]], %[[NONE_0]], %[[NONE_0]], %[[NONE_0]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[ALPHA:.+]] = torch.constant.float 4.200000e-01
+// CHECK-DAG: %[[BETA:.+]] = torch.constant.float 8.000000e-02
+// CHECK-DAG: %[[NEGHALF:.+]] = torch.constant.float -5.000000e-01
+// CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862
+// CHECK-DAG: %[[FOURPI:.+]] = torch.constant.float 12.566370614359172
+// CHECK-DAG: %[[MULTWOPI:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[MULFOURPI:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[FOURPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[DIVFOURPI:.+]] = torch.aten.div.Tensor %[[MULFOURPI]], %[[SUB]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],si32> -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[DIVTWOPI:.+]] = torch.aten.div.Tensor %[[MULTWOPI]], %[[SUB]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],si32> -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[COS_0:.+]] = torch.aten.cos %[[DIVTWOPI]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[MULCOS_0:.+]] = torch.aten.mul.Scalar %[[COS_0]], %[[NEGHALF]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[COS_1:.+]] = torch.aten.cos %[[DIVFOURPI]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[MULCOS_1:.+]] = torch.aten.mul.Scalar %[[COS_1]], %[[BETA]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[INT1_1:.+]] = torch.constant.int 1
+// CHECK-DAG: %[[ADDCOS:.+]] = torch.aten.add.Tensor %[[MULCOS_0]], %[[MULCOS_1]], %[[INT1_1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[ADDALPHA:.+]] = torch.aten.add.Scalar %[[ADDCOS]], %[[ALPHA]], %[[INT1_1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.int -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6
+// CHECK-DAG: %[[NONE_1:.+]] = torch.constant.none
+// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
 // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[ADDALPHA]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
 // CHECK: return %[[CAST]] : !torch.vtensor<[10],f32>
 %0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
@@ -2037,30 +2037,30 @@ func.func @test_blackmanwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !tor

 // CHECK-LABEL: func.func @test_blackmanwindow
 func.func @test_blackmanwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-// CHECK: %[[INT1_0:.+]] = torch.constant.int 1
-// CHECK: %[[INT0:.+]] = torch.constant.int 0
-// CHECK: %[[ITEM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
-// CHECK: %[[NONE_0:.+]] = torch.constant.none
-// CHECK: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[INT0]], %[[ITEM]], %[[INT1_0]], %[[NONE_0]], %[[NONE_0]], %[[NONE_0]], %[[NONE_0]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
-// CHECK: %[[ALPHA:.+]] = torch.constant.float 4.200000e-01
-// CHECK: %[[BETA:.+]] = torch.constant.float 8.000000e-02
-// CHECK: %[[NEGHALF:.+]] = torch.constant.float -5.000000e-01
-// CHECK: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862
-// CHECK: %[[FOURPI:.+]] = torch.constant.float 12.566370614359172
-// CHECK: %[[MULTWOPI:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
-// CHECK: %[[MULFOURPI:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[FOURPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
-// CHECK: %[[DIVFOURPI:.+]] = torch.aten.div.Tensor %[[MULFOURPI]], %arg0 : !torch.vtensor<[10],f32>, !torch.vtensor<[],si32> -> !torch.vtensor<[10],f32>
-// CHECK: %[[DIVTWOPI:.+]] = torch.aten.div.Tensor %[[MULTWOPI]], %arg0 : !torch.vtensor<[10],f32>, !torch.vtensor<[],si32> -> !torch.vtensor<[10],f32>
-// CHECK: %[[COS_0:.+]] = torch.aten.cos %[[DIVTWOPI]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
-// CHECK: %[[MULCOS_0:.+]] = torch.aten.mul.Scalar %[[COS_0]], %[[NEGHALF]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
-// CHECK: %[[COS_1:.+]] = torch.aten.cos %[[DIVFOURPI]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
-// CHECK: %[[MULCOS_1:.+]] = torch.aten.mul.Scalar %[[COS_1]], %[[BETA]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
-// CHECK: %[[INT1_1:.+]] = torch.constant.int 1
-// CHECK: %[[ADDCOS:.+]] = torch.aten.add.Tensor %[[MULCOS_0]], %[[MULCOS_1]], %[[INT1_1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[10],f32>
-// CHECK: %[[ADDALPHA:.+]] = torch.aten.add.Scalar %[[ADDCOS]], %[[ALPHA]], %[[INT1_1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.int -> !torch.vtensor<[10],f32>
-// CHECK: %[[INT6:.+]] = torch.constant.int 6
-// CHECK: %[[NONE_1:.+]] = torch.constant.none
-// CHECK: %[[FALSE:.+]] = torch.constant.bool false
+// CHECK-DAG: %[[INT1_0:.+]] = torch.constant.int 1
+// CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
+// CHECK-DAG: %[[ITEM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
+// CHECK-DAG: %[[NONE_0:.+]] = torch.constant.none
+// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[INT0]], %[[ITEM]], %[[INT1_0]], %[[NONE_0]], %[[NONE_0]], %[[NONE_0]], %[[NONE_0]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[ALPHA:.+]] = torch.constant.float 4.200000e-01
+// CHECK-DAG: %[[BETA:.+]] = torch.constant.float 8.000000e-02
+// CHECK-DAG: %[[NEGHALF:.+]] = torch.constant.float -5.000000e-01
+// CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862
+// CHECK-DAG: %[[FOURPI:.+]] = torch.constant.float 12.566370614359172
+// CHECK-DAG: %[[MULTWOPI:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[MULFOURPI:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[FOURPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[DIVFOURPI:.+]] = torch.aten.div.Tensor %[[MULFOURPI]], %arg0 : !torch.vtensor<[10],f32>, !torch.vtensor<[],si32> -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[DIVTWOPI:.+]] = torch.aten.div.Tensor %[[MULTWOPI]], %arg0 : !torch.vtensor<[10],f32>, !torch.vtensor<[],si32> -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[COS_0:.+]] = torch.aten.cos %[[DIVTWOPI]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[MULCOS_0:.+]] = torch.aten.mul.Scalar %[[COS_0]], %[[NEGHALF]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[COS_1:.+]] = torch.aten.cos %[[DIVFOURPI]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[MULCOS_1:.+]] = torch.aten.mul.Scalar %[[COS_1]], %[[BETA]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[INT1_1:.+]] = torch.constant.int 1
+// CHECK-DAG: %[[ADDCOS:.+]] = torch.aten.add.Tensor %[[MULCOS_0]], %[[MULCOS_1]], %[[INT1_1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[ADDALPHA:.+]] = torch.aten.add.Scalar %[[ADDCOS]], %[[ALPHA]], %[[INT1_1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.int -> !torch.vtensor<[10],f32>
+// CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6
+// CHECK-DAG: %[[NONE_1:.+]] = torch.constant.none
+// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
 // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[ADDALPHA]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE_1]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
 // CHECK: return %[[CAST]] : !torch.vtensor<[10],f32>
 %0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
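A note on the test change: converting these lines from CHECK: to CHECK-DAG: lets FileCheck match the constant materializations and intermediate ops in any relative order, so the test no longer depends on the exact order in which the rewriter emits them.

As a standalone sanity check of the two pi-derived literals the tests match (an illustrative sketch, not part of this patch), the values can be reproduced in plain C++:

// Illustrative sketch, not part of the patch: reproduces the float
// literals matched by the TWOPI and FOURPI captures above.
#include <cstdio>

int main() {
  // Same value that llvm::numbers::pi holds (pi rounded to double precision).
  constexpr double pi = 3.141592653589793238462643383279502884;
  std::printf("%.17g\n", 2.0 * pi); // prints 6.2831853071795862
  std::printf("%.17g\n", 4.0 * pi); // prints 12.566370614359172
  return 0;
}

Compiled with any C++11 compiler, this prints the constants expected by the TWOPI and FOURPI check lines in both tests.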