@@ -600,6 +600,42 @@ func.func @mbarrier_txcount() {
600
600
func.return
601
601
}
602
602
603
// CHECK-LABEL: func @mbarrier_txcount_pred
func.func @mbarrier_txcount_pred() {
  %mine = arith.constant 1 : index
  // CHECK: %[[c0:.+]] = arith.constant 0 : index
  // CHECK: %[[mid:.+]] = builtin.unrealized_conversion_cast %[[c0]] : index to i64
  // CHECK: %[[S2:.+]] = gpu.thread_id x
  // CHECK: %[[P:.+]] = arith.cmpi eq, %[[S2]], %[[c0]] : index
  %c0 = arith.constant 0 : index
  %tidx = gpu.thread_id x
  // Predicate: only thread 0 performs the barrier init/arrive below.
  %pred = arith.cmpi eq, %tidx, %c0 : index

  // CHECK: %[[barMemref:.+]] = memref.get_global @__mbarrier{{.*}} : memref<1xi64, 3>
  %barrier = nvgpu.mbarrier.create -> !barrierType

  // CHECK: %[[barStr:.+]] = builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
  // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
  // CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
  // CHECK: nvvm.mbarrier.init.shared %[[barPtr]], {{.*}}, predicate = %[[P]]
  nvgpu.mbarrier.init %barrier[%c0], %mine, predicate = %pred : !barrierType

  %txcount = arith.constant 256 : index
  // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
  // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
  // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
  nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType

  %phase = arith.constant 0 : index
  %ticks = arith.constant 10000000 : index
  // The wait is unpredicated: every thread polls the barrier.
  // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
  // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
  // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
  nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase, %ticks : !barrierType

  func.return
}

603
639
// CHECK-LABEL: func @async_tma_load
604
640
!tensorMap1d = !nvgpu.tensormap.descriptor <tensor = memref <128 xf32 ,3 >, swizzle =none , l2promo = none , oob = nan , interleave = none >
605
641
!tensorMap2d = !nvgpu.tensormap.descriptor <tensor = memref <32 x32 xf32 ,3 >, swizzle =swizzle_32b , l2promo = none , oob = zero , interleave = none >
@@ -630,6 +666,32 @@ func.func @async_tma_load(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d
630
666
func.return
631
667
}
632
668
669
// CHECK-LABEL: func @async_tma_load_pred
func.func @async_tma_load_pred(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
                               %buffer1d: memref<128xf32,3>,
                               %buffer2d: memref<32x32xf32,3>,
                               %buffer3d: memref<2x32x32xf32,3>,
                               %buffer4d: memref<2x2x32x32xf32,3>,
                               %buffer5d: memref<2x2x2x32x32xf32,3>,
                               %mbarrier: !mbarrier,
                               %p: i1) {
  %c0 = arith.constant 0 : index
  %crd0 = arith.constant 0 : index
  %crd1 = arith.constant 0 : index
  // Each rank from 1-D to 5-D lowers to a predicated
  // nvvm.cp.async.bulk.tensor.* with one box coordinate per dimension.
  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}], predicate = %{{.*}}
  nvgpu.tma.async.load %tensorMap1d[%crd0], %mbarrier[%c0] to %buffer1d, predicate = %p : !tensorMap1d, !mbarrier -> memref<128xf32,3>
  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}], predicate = %{{.*}}
  nvgpu.tma.async.load %tensorMap2d[%crd0, %crd1], %mbarrier[%c0] to %buffer2d, predicate = %p : !tensorMap2d, !mbarrier -> memref<32x32xf32,3>
  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
  nvgpu.tma.async.load %tensorMap3d[%crd0, %crd1, %crd0], %mbarrier[%c0] to %buffer3d, predicate = %p : !tensorMap3d, !mbarrier -> memref<2x32x32xf32,3>
  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
  nvgpu.tma.async.load %tensorMap4d[%crd0, %crd1, %crd1, %crd0], %mbarrier[%c0] to %buffer4d, predicate = %p : !tensorMap4d, !mbarrier -> memref<2x2x32x32xf32,3>
  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
  nvgpu.tma.async.load %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0], %mbarrier[%c0] to %buffer5d, predicate = %p : !tensorMap5d, !mbarrier -> memref<2x2x2x32x32xf32,3>
  func.return
}

633
695
func.func @create_tensor_map (%devicePtr2d : memref <64 x128 xf32 >, %devicePtr1d : memref <128 xf32 >) {
634
696
%crd0 = arith.constant 64 : index
635
697
%crd1 = arith.constant 128 : index
@@ -650,7 +712,7 @@ func.func @tma_prefetch(%tensorMap1d: !tensorMap1d, %p : i1) {
650
712
// CHECK: nvvm.prefetch.tensormap %[[S0]] : !llvm.ptr
651
713
nvgpu.tma.prefetch.descriptor %tensorMap1d: !tensorMap1d
652
714
// CHECK: nvvm.prefetch.tensormap %[[S0]], predicate = %[[arg1]] : !llvm.ptr, i1
653
- nvgpu.tma.prefetch.descriptor %tensorMap1d , %p: !tensorMap1d
715
+ nvgpu.tma.prefetch.descriptor %tensorMap1d , predicate = %p: !tensorMap1d
654
716
func.return
655
717
}
656
718
0 commit comments