Skip to content

Commit b5d694b

Browse files
authored
[mlir][nvvm] Introduce nvvm.barrier OP (#81487)
This PR that introduces the `nvvm.barrier` OP to the NVVM dialect. Currently, NVVM only supports the `nvvm.barrier0`, which synchronizes all threads using barrier resource 0. The new `nvvm.barrier` has two essential arguments: the barrier resource and the number of threads. This added flexibility allows for selective synchronization of threads within a CTA, aligning with the capabilities provided by LLVM intrinsics or the PTX model. I think we can deprecate `nvvm.barrier0` in favor of the more generic `nvvm.barrier`. ``` // Equivalent to nvvm.barrier0 (or __syncthreads() in CUDA) nvvm.barrier // Synchronize all threads using the 3rd barrier resource. nvvm.barrier id = 3 // Synchronize %numberOfThreads threads using the 3rd barrier resource. nvvm.barrier id = 3 number_of_threads = %numberOfThreads ```
1 parent 86ce491 commit b5d694b

File tree

4 files changed

+57
-0
lines changed

4 files changed

+57
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,25 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
390390
let assemblyFormat = "attr-dict";
391391
}
392392

393+
def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
394+
let arguments = (ins
395+
Optional<I32>:$barrierId,
396+
Optional<I32>:$numberOfThreads);
397+
string llvmBuilder = [{
398+
if ($numberOfThreads && $barrierId) {
399+
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier,
400+
{$barrierId, $numberOfThreads});
401+
} else if($barrierId) {
402+
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier_n,
403+
{$barrierId});
404+
} else {
405+
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier0);
406+
}
407+
}];
408+
let hasVerifier = 1;
409+
let assemblyFormat = "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? attr-dict";
410+
}
411+
393412
def NVVM_ClusterArriveOp : NVVM_Op<"cluster.arrive"> {
394413
let arguments = (ins OptionalAttr<UnitAttr>:$aligned);
395414

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,6 +1022,13 @@ LogicalResult NVVM::SetMaxRegisterOp::verify() {
10221022
return success();
10231023
}
10241024

1025+
LogicalResult NVVM::BarrierOp::verify() {
1026+
if (getNumberOfThreads() && !getBarrierId())
1027+
return emitOpError(
1028+
"barrier id is missing, it should be set between 0 to 15");
1029+
return success();
1030+
}
1031+
10251032
//===----------------------------------------------------------------------===//
10261033
// NVVMDialect initialization, type parsing, and registration.
10271034
//===----------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,18 @@ func.func @llvm_nvvm_barrier0() {
4343
llvm.return
4444
}
4545

46+
// CHECK-LABEL: @llvm_nvvm_barrier
47+
// CHECK-SAME: (%[[barId:.*]]: i32, %[[numberOfThreads:.*]]: i32)
48+
llvm.func @llvm_nvvm_barrier(%barId : i32, %numberOfThreads : i32) {
49+
// CHECK: nvvm.barrier
50+
nvvm.barrier
51+
// CHECK: nvvm.barrier id = %[[barId]]
52+
nvvm.barrier id = %barId
53+
// CHECK: nvvm.barrier id = %[[barId]] number_of_threads = %[[numberOfThreads]]
54+
nvvm.barrier id = %barId number_of_threads = %numberOfThreads
55+
llvm.return
56+
}
57+
4658
// CHECK-LABEL: @llvm_nvvm_cluster_arrive
4759
func.func @llvm_nvvm_cluster_arrive() {
4860
// CHECK: nvvm.cluster.arrive

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,18 @@ llvm.func @llvm_nvvm_barrier0() {
8080
llvm.return
8181
}
8282

83+
// CHECK-LABEL: @llvm_nvvm_barrier(
84+
// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]])
85+
llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32) {
86+
// CHECK: call void @llvm.nvvm.barrier0()
87+
nvvm.barrier
88+
// CHECK: call void @llvm.nvvm.barrier.n(i32 %[[barId]])
89+
nvvm.barrier id = %barID
90+
// CHECK: call void @llvm.nvvm.barrier(i32 %[[barId]], i32 %[[numThreads]])
91+
nvvm.barrier id = %barID number_of_threads = %numberOfThreads
92+
llvm.return
93+
}
94+
8395
// CHECK-LABEL: @llvm_nvvm_cluster_arrive
8496
llvm.func @llvm_nvvm_cluster_arrive() {
8597
// CHECK: call void @llvm.nvvm.barrier.cluster.arrive()
@@ -512,6 +524,13 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 1, 2
512524
// CHECK: {ptr @kernel_func, !"maxntidz", i32 32}
513525
// CHECK: {ptr @kernel_func, !"minctasm", i32 16}
514526

527+
// -----
528+
529+
llvm.func @kernel_func(%numberOfThreads : i32) {
530+
// expected-error @below {{'nvvm.barrier' op barrier id is missing, it should be set between 0 to 15}}
531+
nvvm.barrier number_of_threads = %numberOfThreads
532+
}
533+
515534
// -----
516535
// expected-error @below {{'"nvvm.minctasm"' attribute must be integer constant}}
517536
llvm.func @kernel_func() attributes {nvvm.kernel,

0 commit comments

Comments
 (0)