@@ -1450,6 +1450,33 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar

// -----

+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @push_unpack_in_padded_domain_multiple_inputs(%arg0: tensor<1x4x16x16xf32>, %arg1: tensor<8x64xf32>, %arg2: tensor<8x64xf32>) -> tensor<8x64xf32> {
+  %0 = tensor.empty() : tensor<8x64xf32>
+  %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<1x4x16x16xf32> -> tensor<8x64xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg1, %unpack : tensor<8x64xf32>, tensor<8x64xf32>) outs(%arg2 : tensor<8x64xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.addf %in, %in_0 : f32
+    linalg.yield %2 : f32
+  } -> tensor<8x64xf32>
+  return %1 : tensor<8x64xf32>
+}
+// CHECK-LABEL: func.func @push_unpack_in_padded_domain_multiple_inputs
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[POISON:.+]] = ub.poison : f32
+// CHECK: %[[PACK:.+]] = linalg.pack %[[ARG1]] padding_value(%[[POISON]] : f32)
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [16, 16]
+// CHECK: %[[ELEM:.+]] = linalg.generic
+// CHECK: ins(%[[PACK]], %[[ARG0]]
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ELEM]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [16, 16]
+// CHECK-SAME: into %[[ARG2]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
module {
  func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
    %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
@@ -1473,7 +1500,7 @@ module {
// CHECK: } : tensor<?x5x3x128xf32> to tensor<?x5x3x128xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16>
// CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]], %[[PADDED]]
+// CHECK-SAME: ins(%[[ARG0]], %[[PADDED]]
// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %3[%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor<?x5x128xbf16>
// CHECK: return %[[EXTRACT]]
@@ -1492,7 +1519,7 @@ func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32

// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1
// CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK: return %[[GENERIC]]
+// CHECK: return %[[GENERIC]]

// -----

@@ -1508,7 +1535,7 @@ func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32

// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2
// CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK: return %[[GENERIC]]
+// CHECK: return %[[GENERIC]]

// -----

@@ -1575,7 +1602,7 @@ func.func @push_extract_through_generic_rank0_operand(%arg0: tensor<128x128xf32>

// CHECK-LABEL: func.func @push_extract_through_generic_rank0_operand
// CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]]
// CHECK: return %[[EXTRACT]]

// -----