Skip to content

Commit 192d332

Browse files
Authored commit 192d332 — [mlir][nvgpu] Add predicate argument to NVGPU Ops (#69322)
1 parent: fea55db

File tree

4 files changed: +88 additions, -19 deletions

4 files changed: +88 additions, -19 deletions

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -522,8 +522,8 @@ def NVGPU_MBarrierInitOp : NVGPU_Op<"mbarrier.init", []> {
522522
nvgpu.mbarrier.init %barrier, %num_threads : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
523523
```
524524
}];
525-
let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$count, Index:$mbarId);
526-
let assemblyFormat = "$barriers `[` $mbarId `]` `,` $count attr-dict `:` type($barriers)";
525+
let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$count, Index:$mbarId, Optional<I1>:$predicate);
526+
let assemblyFormat = "$barriers `[` $mbarId `]` `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type($barriers)";
527527
}
528528

529529
def NVGPU_MBarrierTestWaitOp : NVGPU_Op<"mbarrier.test.wait", []> {
@@ -597,8 +597,8 @@ def NVGPU_MBarrierArriveExpectTxOp : NVGPU_Op<"mbarrier.arrive.expect_tx", []> {
597597
nvgpu.mbarrier.arrive.expect_tx %barrier, %ic0 : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
598598
```
599599
}];
600-
let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$txcount, Index:$mbarId);
601-
let assemblyFormat = "$barriers `[` $mbarId `]` `,` $txcount attr-dict `:` type($barriers)";
600+
let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$txcount, Index:$mbarId, Optional<I1>:$predicate);
601+
let assemblyFormat = "$barriers `[` $mbarId `]` `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type($barriers)";
602602
}
603603

604604
def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> {
@@ -627,11 +627,11 @@ def NVGPU_TmaPrefetchOp : NVGPU_Op<"tma.prefetch.descriptor", []> {
627627
}];
628628
let arguments = (ins NVGPU_TensorMapDescriptor:$tensorMapDescriptor, Optional<I1>:$predicate);
629629
let assemblyFormat = [{
630-
$tensorMapDescriptor (`,` $predicate^)? attr-dict `:` type($tensorMapDescriptor)
630+
$tensorMapDescriptor (`,` `predicate` `=` $predicate^)? attr-dict `:` type($tensorMapDescriptor)
631631
}];
632632
}
633633

634-
def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", []> {
634+
def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", [AttrSizedOperandSegments]> {
635635
let summary = "TMA asynchronous load";
636636
let description = [{
637637
The Op loads a tile memory region from global memory to shared memory by
@@ -646,10 +646,14 @@ def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", []> {
646646
NVGPU_MBarrierGroup:$barriers,
647647
NVGPU_TensorMapDescriptor:$tensorMapDescriptor,
648648
Variadic<Index>:$coordinates,
649-
Index:$mbarId);
649+
Index:$mbarId,
650+
Optional<I1>:$predicate);
650651
let assemblyFormat = [{
651-
$tensorMapDescriptor `[` $coordinates `]` `,` $barriers `[` $mbarId `]` `to` $dst
652-
attr-dict `:` type($tensorMapDescriptor) `,` type($barriers) `->` type($dst)
652+
$tensorMapDescriptor `[` $coordinates `]` `,` $barriers `[` $mbarId `]`
653+
`to` $dst
654+
(`,` `predicate` `=` $predicate^)?
655+
attr-dict `:` type($tensorMapDescriptor) `,` type($barriers)
656+
`->` type($dst)
653657
}];
654658
let hasVerifier = 1;
655659

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -830,11 +830,11 @@ struct NVGPUMBarrierInitLowering
830830
adaptor.getMbarId(), rewriter);
831831
Value count = truncToI32(b, adaptor.getCount());
832832
if (isMbarrierShared(mbarrierType)) {
833-
rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(op, barrier,
834-
count, Value());
833+
rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
834+
op, barrier, count, adaptor.getPredicate());
835835
} else {
836836
rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
837-
Value());
837+
adaptor.getPredicate());
838838
}
839839
return success();
840840
}
@@ -929,12 +929,12 @@ struct NVGPUMBarrierArriveExpectTxLowering
929929

930930
if (isMbarrierShared(op.getBarriers().getType())) {
931931
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
932-
op, barrier, txcount, Value());
932+
op, barrier, txcount, adaptor.getPredicate());
933933
return success();
934934
}
935935

936936
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
937-
op, barrier, txcount, Value());
937+
op, barrier, txcount, adaptor.getPredicate());
938938
return success();
939939
}
940940
};
@@ -985,7 +985,8 @@ struct NVGPUTmaAsyncLoadOpLowering
985985
}
986986

987987
rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
988-
op, dest, adaptor.getTensorMapDescriptor(), barrier, coords, Value());
988+
op, dest, adaptor.getTensorMapDescriptor(), barrier, coords,
989+
adaptor.getPredicate());
989990
return success();
990991
}
991992
};

mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -922,7 +922,7 @@ HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
922922
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
923923
rewriter.create<nvgpu::MBarrierInitOp>(
924924
loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads),
925-
zero);
925+
zero, Value());
926926
rewriter.create<gpu::BarrierOp>(loc);
927927
return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
928928
}
@@ -964,7 +964,8 @@ OpFoldResult HopperBuilder::buildTmaAsyncLoad(
964964
MLIRContext *ctx = rewriter.getContext();
965965
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
966966
Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
967-
loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero);
967+
loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
968+
Value());
968969
loadOps.push_back(loadOp);
969970
auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
970971
SmallVector<AffineExpr> symbols(mixedSizes.size());
@@ -989,7 +990,8 @@ void HopperBuilder::buildBarrierArriveTx(
989990
affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
990991
Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
991992
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
992-
rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero);
993+
rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
994+
Value());
993995
}
994996

995997
void HopperBuilder::buildTryWaitParity(

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,42 @@ func.func @mbarrier_txcount() {
600600
func.return
601601
}
602602

603+
// CHECK-LABEL: func @mbarrier_txcount_pred
604+
func.func @mbarrier_txcount_pred() {
605+
%mine = arith.constant 1 : index
606+
// CHECK: %[[c0:.+]] = arith.constant 0 : index
607+
// CHECK: %[[mid:.+]] = builtin.unrealized_conversion_cast %[[c0]] : index to i64
608+
// CHECK: %[[S2:.+]] = gpu.thread_id x
609+
// CHECK: %[[P:.+]] = arith.cmpi eq, %[[S2]], %[[c0]] : index
610+
%c0 = arith.constant 0 : index
611+
%tidx = gpu.thread_id x
612+
%pred = arith.cmpi eq, %tidx, %c0 : index
613+
614+
// CHECK: %[[barMemref:.+]] = memref.get_global @__mbarrier{{.*}} : memref<1xi64, 3>
615+
%barrier = nvgpu.mbarrier.create -> !barrierType
616+
617+
// CHECK: %[[barStr:.+]] = builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
618+
// CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
619+
// CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
620+
// CHECK: nvvm.mbarrier.init.shared %[[barPtr]], {{.*}}, predicate = %[[P]]
621+
nvgpu.mbarrier.init %barrier[%c0], %mine, predicate = %pred : !barrierType
622+
623+
%txcount = arith.constant 256 : index
624+
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
625+
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
626+
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
627+
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType
628+
629+
%phase = arith.constant 0 : index
630+
%ticks = arith.constant 10000000 : index
631+
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
632+
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
633+
// CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
634+
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase, %ticks : !barrierType
635+
636+
func.return
637+
}
638+
603639
// CHECK-LABEL: func @async_tma_load
604640
!tensorMap1d = !nvgpu.tensormap.descriptor<tensor = memref<128xf32,3>, swizzle=none, l2promo = none, oob = nan, interleave = none>
605641
!tensorMap2d = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
@@ -630,6 +666,32 @@ func.func @async_tma_load(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d
630666
func.return
631667
}
632668

669+
// CHECK-LABEL: func @async_tma_load_pred
670+
func.func @async_tma_load_pred(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
671+
%buffer1d: memref<128xf32,3>,
672+
%buffer2d: memref<32x32xf32,3>,
673+
%buffer3d: memref<2x32x32xf32,3>,
674+
%buffer4d: memref<2x2x32x32xf32,3>,
675+
%buffer5d: memref<2x2x2x32x32xf32,3>,
676+
%mbarrier: !mbarrier,
677+
%p: i1) {
678+
%c0 = arith.constant 0 : index
679+
%crd0 = arith.constant 0 : index
680+
%crd1 = arith.constant 0 : index
681+
// CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}], predicate = %{{.*}}
682+
nvgpu.tma.async.load %tensorMap1d[%crd0], %mbarrier[%c0] to %buffer1d, predicate = %p : !tensorMap1d, !mbarrier -> memref<128xf32,3>
683+
// CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}], predicate = %{{.*}}
684+
nvgpu.tma.async.load %tensorMap2d[%crd0, %crd1], %mbarrier[%c0] to %buffer2d, predicate = %p : !tensorMap2d, !mbarrier -> memref<32x32xf32,3>
685+
// CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
686+
nvgpu.tma.async.load %tensorMap3d[%crd0, %crd1, %crd0], %mbarrier[%c0] to %buffer3d, predicate = %p : !tensorMap3d, !mbarrier -> memref<2x32x32xf32,3>
687+
// CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
688+
nvgpu.tma.async.load %tensorMap4d[%crd0, %crd1, %crd1, %crd0], %mbarrier[%c0] to %buffer4d, predicate = %p : !tensorMap4d, !mbarrier -> memref<2x2x32x32xf32,3>
689+
// CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
690+
nvgpu.tma.async.load %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0], %mbarrier[%c0] to %buffer5d, predicate = %p : !tensorMap5d, !mbarrier -> memref<2x2x2x32x32xf32,3>
691+
func.return
692+
}
693+
694+
633695
func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : memref<128xf32>) {
634696
%crd0 = arith.constant 64 : index
635697
%crd1 = arith.constant 128 : index
@@ -650,7 +712,7 @@ func.func @tma_prefetch(%tensorMap1d: !tensorMap1d, %p : i1) {
650712
// CHECK: nvvm.prefetch.tensormap %[[S0]] : !llvm.ptr
651713
nvgpu.tma.prefetch.descriptor %tensorMap1d: !tensorMap1d
652714
// CHECK: nvvm.prefetch.tensormap %[[S0]], predicate = %[[arg1]] : !llvm.ptr, i1
653-
nvgpu.tma.prefetch.descriptor %tensorMap1d, %p: !tensorMap1d
715+
nvgpu.tma.prefetch.descriptor %tensorMap1d, predicate = %p: !tensorMap1d
654716
func.return
655717
}
656718

0 commit comments

Comments (0)