@@ -305,6 +305,86 @@ func.func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
 }
 
 // -----
+
+// CHECK: func @fold_linalg_index_tensor_static
+func.func @fold_linalg_index_tensor_static(%0: tensor<4x16xi32>, %1: tensor<1x16xi32>,
+                                           %2: tensor<4x1xi32>) -> tensor<4x1xi32> {
+  // CHECK-NEXT: linalg.generic
+  // CHECK: %[[IDX_0:.+]] = linalg.index 0 : index
+  // CHECK-NOT: linalg.index 1
+  // CHECK: %[[IDX_2:.+]] = linalg.index 2 : index
+  // CHECK: %[[ADD:.+]] = arith.addi %[[IDX_0]], %[[IDX_2]]
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[ADD]]
+  // CHECK: linalg.yield %[[CAST]]
+  %res = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d0, d1)>],
+                         iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%0, %1 : tensor<4x16xi32>, tensor<1x16xi32>)
+    outs(%2 : tensor<4x1xi32>) {
+  ^bb0(%lhs: i32, %rhs: i32, %out: i32):
+    %idx0 = linalg.index 0 : index
+    %idx1 = linalg.index 1 : index
+    %idx2 = linalg.index 2 : index
+    %add0 = arith.addi %idx0, %idx1 : index
+    %add1 = arith.addi %add0, %idx2 : index
+    %int = arith.index_cast %add1 : index to i32
+    linalg.yield %int : i32
+  } -> tensor<4x1xi32>
+  return %res : tensor<4x1xi32>
+}
+
+// -----
+
+// CHECK: func @fold_linalg_index_tensor_dynamic
+func.func @fold_linalg_index_tensor_dynamic(%0: tensor<?x1xi32>,
+                                            %1: tensor<?x1xi32>) -> tensor<?x1xi32> {
+  // CHECK-NEXT: linalg.generic
+  // CHECK: %[[IDX_0:.+]] = linalg.index 0 : index
+  // CHECK-NOT: linalg.index 1
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX_0]]
+  // CHECK: linalg.yield %[[CAST]]
+  %res = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d1, d1)>],
+                         iterator_types = ["parallel", "parallel"]}
+    ins(%0 : tensor<?x1xi32>)
+    outs(%1 : tensor<?x1xi32>) {
+  ^bb0(%lhs: i32, %out: i32):
+    %idx0 = linalg.index 0 : index
+    %idx1 = linalg.index 1 : index
+    %add = arith.addi %idx0, %idx1 : index
+    %int = arith.index_cast %add : index to i32
+    linalg.yield %int : i32
+  } -> tensor<?x1xi32>
+  return %res : tensor<?x1xi32>
+}
+
+// -----
+
+// CHECK: func @fold_linalg_index_memref
+func.func @fold_linalg_index_memref(%0: memref<1x?xi32>, %1: memref<1x?xi32>) {
+  // CHECK-NEXT: linalg.generic
+  // CHECK-NOT: linalg.index 0
+  // CHECK: %[[IDX_1:.+]] = linalg.index 1 : index
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX_1]]
+  // CHECK: linalg.yield %[[CAST]]
+  linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                   affine_map<(d0, d1) -> (d1, d1)>],
+                  iterator_types = ["parallel", "parallel"]}
+    ins(%0 : memref<1x?xi32>)
+    outs(%1 : memref<1x?xi32>) {
+  ^bb0(%lhs: i32, %out: i32):
+    %idx0 = linalg.index 0 : index
+    %idx1 = linalg.index 1 : index
+    %add = arith.addi %idx0, %idx1 : index
+    %int = arith.index_cast %add : index to i32
+    linalg.yield %int : i32
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_fill_reshape()
 func.func @fold_fill_reshape() -> tensor<6x4xf32> {
   %zero = arith.constant 0.0 : f32