Skip to content

Commit 3f34e97

Browse files
committed
Add a test
1 parent d7663a5 commit 3f34e97

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

test/gc/Transforms/test_constant_tensor_folding-1.mlir

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
// CHECK-LABEL: func.func @entry
44
module {
5-
func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> (tensor<128xf32>) attributes { llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32] } {
5+
func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32] } {
66
%c0 = arith.constant 0 : index
77
cpuruntime.printf "HI%zu\n" %c0 : index
88
%ax2 = tensor.empty() : tensor<128xf32>
@@ -11,39 +11,49 @@ module {
1111
%3 = linalg.add ins(%b, %b : tensor<128xf32>,tensor<128xf32>) outs(%bx2 : tensor<128xf32>) -> tensor<128xf32>
1212
%ax2pbx2 = tensor.empty() : tensor<128xf32>
1313
%4 = linalg.add ins(%2, %3 : tensor<128xf32>,tensor<128xf32>) outs(%ax2pbx2 : tensor<128xf32>) -> tensor<128xf32>
14+
%ax2mbx2 = tensor.empty() : tensor<128xf32>
15+
%5 = linalg.mul ins(%2, %3 : tensor<128xf32>,tensor<128xf32>) outs(%ax2mbx2 : tensor<128xf32>) -> tensor<128xf32>
1416
%ax2pbx2pc = tensor.empty() : tensor<128xf32>
15-
%d = linalg.add ins(%4, %c : tensor<128xf32>,tensor<128xf32>) outs(%ax2pbx2pc : tensor<128xf32>) -> tensor<128xf32>
16-
return %d : tensor<128xf32>
17+
%6 = linalg.add ins(%4, %c : tensor<128xf32>,tensor<128xf32>) outs(%ax2pbx2pc : tensor<128xf32>) -> tensor<128xf32>
18+
%ax2mbx2mc = tensor.empty() : tensor<128xf32>
19+
%7 = linalg.mul ins(%5, %c : tensor<128xf32>,tensor<128xf32>) outs(%ax2mbx2mc : tensor<128xf32>) -> tensor<128xf32>
20+
return %6, %7 : tensor<128xf32>, tensor<128xf32>
1721
}
1822
}
1923

2024
// CHECK: cpuruntime.printf
2125
// CHECK: linalg.add
26+
// CHECK: linalg.mul
2227
// CHECK: func.func @fold
2328
// CHECK: linalg.add
2429
// CHECK: linalg.add
2530
// CHECK: linalg.add
31+
// CHECK: linalg.mul
2632

2733
// COM: expected output:
2834
// COM: module {
2935
// COM: llvm.mlir.global external constant @__num_orig_num_args(3 : i32) {addr_space = 0 : i32} : i32
30-
// COM: llvm.mlir.global external constant @__compute_args(dense<[2, 2, 3]> : tensor<3xi32>) {addr_space = 0 : i32} : !llvm.array<3 x i32>
31-
// COM: llvm.mlir.global external constant @__fold_args(dense<[3, 0, 1, 3]> : tensor<4xi32>) {addr_space = 0 : i32} : !llvm.array<4 x i32>
32-
// COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[1, 0]> : tensor<2xi64>) {addr_space = 0 : i32} : !llvm.array<2 x i64>
33-
// COM: func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<128xf32> attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32]} {
36+
// COM: llvm.mlir.global external constant @__compute_args(dense<[3, 2, 3, 4]> : tensor<4xi32>) {addr_space = 0 : i32} : !llvm.array<4 x i32>
37+
// COM: llvm.mlir.global external constant @__fold_args(dense<[4, 0, 1, 3, 4]> : tensor<5xi32>) {addr_space = 0 : i32} : !llvm.array<5 x i32>
38+
// COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[2, 0, 1]> : tensor<3xi64>) {addr_space = 0 : i32} : !llvm.array<3 x i64>
39+
// COM: func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32]} {
3440
// COM: %c0 = arith.constant 0 : index
3541
// COM: cpuruntime.printf "HI%zu\0A" %c0 : index
3642
// COM: %0 = tensor.empty() : tensor<128xf32>
37-
// COM: %1 = linalg.add ins(%arg1, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%0 : tensor<128xf32>) -> tensor<128xf32>
38-
// COM: return %1 : tensor<128xf32>
43+
// COM: %1 = linalg.add ins(%arg2, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%0 : tensor<128xf32>) -> tensor<128xf32>
44+
// COM: %2 = tensor.empty() : tensor<128xf32>
45+
// COM: %3 = linalg.mul ins(%arg1, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32>
46+
// COM: return %1, %3 : tensor<128xf32>, tensor<128xf32>
3947
// COM: }
40-
// COM: func.func @fold(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<128xf32> attributes {llvm.emit_c_interface} {
48+
// COM: func.func @fold(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes {llvm.emit_c_interface} {
4149
// COM: %0 = tensor.empty() : tensor<128xf32>
4250
// COM: %1 = linalg.add ins(%arg0, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%0 : tensor<128xf32>) -> tensor<128xf32>
4351
// COM: %2 = tensor.empty() : tensor<128xf32>
4452
// COM: %3 = linalg.add ins(%arg1, %arg1 : tensor<128xf32>, tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32>
4553
// COM: %4 = tensor.empty() : tensor<128xf32>
4654
// COM: %5 = linalg.add ins(%1, %3 : tensor<128xf32>, tensor<128xf32>) outs(%4 : tensor<128xf32>) -> tensor<128xf32>
47-
// COM: return %5 : tensor<128xf32>
55+
// COM: %6 = tensor.empty() : tensor<128xf32>
56+
// COM: %7 = linalg.mul ins(%1, %3 : tensor<128xf32>, tensor<128xf32>) outs(%6 : tensor<128xf32>) -> tensor<128xf32>
57+
// COM: return %7, %5 : tensor<128xf32>, tensor<128xf32>
4858
// COM: }
4959
// COM: }

0 commit comments

Comments
 (0)