// CHECK-LABEL: func.func @entry
module {
- func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> (tensor<128xf32>) attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32]} {
+ func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32]} {
%c0 = arith.constant 0 : index
cpuruntime.printf "HI%zu\n" %c0 : index
%ax2 = tensor.empty() : tensor<128xf32>
@@ -11,39 +11,49 @@ module {
%3 = linalg.add ins(%b, %b : tensor<128xf32>, tensor<128xf32>) outs(%bx2 : tensor<128xf32>) -> tensor<128xf32>
%ax2pbx2 = tensor.empty() : tensor<128xf32>
%4 = linalg.add ins(%2, %3 : tensor<128xf32>, tensor<128xf32>) outs(%ax2pbx2 : tensor<128xf32>) -> tensor<128xf32>
+ %ax2mbx2 = tensor.empty() : tensor<128xf32>
+ %5 = linalg.mul ins(%2, %3 : tensor<128xf32>, tensor<128xf32>) outs(%ax2mbx2 : tensor<128xf32>) -> tensor<128xf32>
%ax2pbx2pc = tensor.empty() : tensor<128xf32>
- %d = linalg.add ins(%4, %c : tensor<128xf32>, tensor<128xf32>) outs(%ax2pbx2pc : tensor<128xf32>) -> tensor<128xf32>
- return %d : tensor<128xf32>
+ %6 = linalg.add ins(%4, %c : tensor<128xf32>, tensor<128xf32>) outs(%ax2pbx2pc : tensor<128xf32>) -> tensor<128xf32>
+ %ax2mbx2mc = tensor.empty() : tensor<128xf32>
+ %7 = linalg.mul ins(%5, %c : tensor<128xf32>, tensor<128xf32>) outs(%ax2mbx2mc : tensor<128xf32>) -> tensor<128xf32>
+ return %6, %7 : tensor<128xf32>, tensor<128xf32>
}
}
// CHECK: cpuruntime.printf
// CHECK: linalg.add
+ // CHECK: linalg.mul
// CHECK: func.func @fold
// CHECK: linalg.add
// CHECK: linalg.add
// CHECK: linalg.add
+ // CHECK: linalg.mul
// COM: expected output:
// COM: module {
// COM: llvm.mlir.global external constant @__num_orig_num_args(3 : i32) {addr_space = 0 : i32} : i32
- // COM: llvm.mlir.global external constant @__compute_args(dense<[2, 2, 3]> : tensor<3xi32>) {addr_space = 0 : i32} : !llvm.array<3 x i32>
- // COM: llvm.mlir.global external constant @__fold_args(dense<[3, 0, 1, 3]> : tensor<4xi32>) {addr_space = 0 : i32} : !llvm.array<4 x i32>
- // COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[1, 0]> : tensor<2xi64>) {addr_space = 0 : i32} : !llvm.array<2 x i64>
- // COM: func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<128xf32> attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32]} {
+ // COM: llvm.mlir.global external constant @__compute_args(dense<[3, 2, 3, 4]> : tensor<4xi32>) {addr_space = 0 : i32} : !llvm.array<4 x i32>
+ // COM: llvm.mlir.global external constant @__fold_args(dense<[4, 0, 1, 3, 4]> : tensor<5xi32>) {addr_space = 0 : i32} : !llvm.array<5 x i32>
+ // COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[2, 0, 1]> : tensor<3xi64>) {addr_space = 0 : i32} : !llvm.array<3 x i64>
+ // COM: func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32]} {
// COM: %c0 = arith.constant 0 : index
// COM: cpuruntime.printf "HI%zu\0A" %c0 : index
// COM: %0 = tensor.empty() : tensor<128xf32>
- // COM: %1 = linalg.add ins(%arg1, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%0 : tensor<128xf32>) -> tensor<128xf32>
- // COM: return %1 : tensor<128xf32>
+ // COM: %1 = linalg.add ins(%arg2, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%0 : tensor<128xf32>) -> tensor<128xf32>
+ // COM: %2 = tensor.empty() : tensor<128xf32>
+ // COM: %3 = linalg.mul ins(%arg1, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32>
+ // COM: return %1, %3 : tensor<128xf32>, tensor<128xf32>
// COM: }
- // COM: func.func @fold(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<128xf32> attributes {llvm.emit_c_interface} {
+ // COM: func.func @fold(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes {llvm.emit_c_interface} {
// COM: %0 = tensor.empty() : tensor<128xf32>
// COM: %1 = linalg.add ins(%arg0, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%0 : tensor<128xf32>) -> tensor<128xf32>
// COM: %2 = tensor.empty() : tensor<128xf32>
// COM: %3 = linalg.add ins(%arg1, %arg1 : tensor<128xf32>, tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32>
// COM: %4 = tensor.empty() : tensor<128xf32>
// COM: %5 = linalg.add ins(%1, %3 : tensor<128xf32>, tensor<128xf32>) outs(%4 : tensor<128xf32>) -> tensor<128xf32>
- // COM: return %5 : tensor<128xf32>
+ // COM: %6 = tensor.empty() : tensor<128xf32>
+ // COM: %7 = linalg.mul ins(%1, %3 : tensor<128xf32>, tensor<128xf32>) outs(%6 : tensor<128xf32>) -> tensor<128xf32>
+ // COM: return %7, %5 : tensor<128xf32>, tensor<128xf32>
// COM: }
// COM: }