Skip to content

Commit

Permalink
[TorchToLinalg][test] Add test for ConvertAtenConvolutionOp
Browse files Browse the repository at this point in the history
This patch adds a test for 638ef14, which uses `linalg.broadcast`
instead of `generic` for the convolution bias.
  • Loading branch information
SenHaiG authored and CoTinker committed Aug 30, 2024
1 parent fd759e4 commit 2017c70
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions test/Conversion/TorchToLinalg/convolution.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,29 @@ func.func @q_conv_test(%arg0: !torch.vtensor<[?,?,?,?],si8>, %arg1: !torch.vtens
%11 = torch.aten.dequantize.tensor %10 : !torch.vtensor<[?,?,?,?],!torch.qint32> -> !torch.vtensor<[?,?,?,?],f32>
return %11 : !torch.vtensor<[?,?,?,?],f32>
}

// -----

// Checks that a 1-d convolution with a bias lowers to linalg.broadcast
// (bias expanded into the init tensor) feeding linalg.conv_1d_ncw_fcw,
// rather than a trailing linalg.generic bias-add.
// CHECK-LABEL: func.func @conv_broadcast(
// CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[1,80,3000],f32>,
// CHECK-SAME: %[[arg1:.*]]: !torch.vtensor<[1024,80,3],f32>,
// CHECK-SAME: %[[arg2:.*]]: !torch.vtensor<[1024],f32>) -> !torch.vtensor<[1,1024,3000],f32> {
// CHECK: %[[c0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[input:.*]] = torch_c.to_builtin_tensor %[[arg0]] : !torch.vtensor<[1,80,3000],f32> -> tensor<1x80x3000xf32>
// CHECK: %[[weight:.*]] = torch_c.to_builtin_tensor %[[arg1]] : !torch.vtensor<[1024,80,3],f32> -> tensor<1024x80x3xf32>
// CHECK: %[[bias:.*]] = torch_c.to_builtin_tensor %[[arg2]] : !torch.vtensor<[1024],f32> -> tensor<1024xf32>
// CHECK: %[[padInput:.*]] = tensor.pad %[[input]] low[0, 0, 1] high[0, 0, 1]
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1024x3000xf32>
// CHECK: %[[broadcastBias:.*]] = linalg.broadcast ins(%[[bias]] : tensor<1024xf32>) outs(%[[EMPTY]] : tensor<1x1024x3000xf32>) dimensions = [0, 2]
// CHECK: %[[conv:.*]] = linalg.conv_1d_ncw_fcw {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
// NOTE: use %[[padInput]] (a substitution of the value captured above), not
// %[[padInput:.*]], which would re-define the variable as a wildcard and let
// the test pass even if the conv did not consume the tensor.pad result.
// CHECK-SAME: ins(%[[padInput]], %[[weight]] : tensor<1x80x3002xf32>, tensor<1024x80x3xf32>)
// CHECK-SAME: outs(%[[broadcastBias]] : tensor<1x1024x3000xf32>) -> tensor<1x1024x3000xf32>
func.func @conv_broadcast(%arg0: !torch.vtensor<[1,80,3000],f32>, %arg1: !torch.vtensor<[1024,80,3],f32>, %arg2: !torch.vtensor<[1024],f32>) -> !torch.vtensor<[1,1024,3000],f32> {
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %false = torch.constant.bool false
  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
  %1 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
  %2 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %0, %0, %false, %1, %int1 : !torch.vtensor<[1,80,3000],f32>, !torch.vtensor<[1024,80,3],f32>, !torch.vtensor<[1024],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1024,3000],f32>
  return %2 : !torch.vtensor<[1,1024,3000],f32>
}

0 comments on commit 2017c70

Please sign in to comment.