Skip to content

Commit 6ca49d1

Browse files
committed
[Backend] Optimize membar insertion on hopper
- `mbarrier.try_wait` has the same effect as a bar. - Don't insert a bar between mbarrier arrive/expect-tx/etc. - Distribute `mbarrier.arrive`'s arrive-count across as many warps as possible. - When all warps participate in `mbarrier.arrive`, don't insert a bar between it and a previous `wgmma.mma_async` or `stmatrix`.
1 parent 134515d commit 6ca49d1

File tree

9 files changed

+118
-41
lines changed

9 files changed

+118
-41
lines changed

include/triton/Analysis/Membar.h

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ struct BlockInfo {
1919

2020
IntervalMapT syncReadIntervals;
2121
IntervalMapT syncWriteIntervals;
22+
IntervalMapT syncAtomicIntervals;
2223

2324
BlockInfo() = default;
2425

@@ -30,6 +31,9 @@ struct BlockInfo {
3031
for (auto &interval : other.syncWriteIntervals)
3132
syncWriteIntervals[interval.first].insert(interval.second.begin(),
3233
interval.second.end());
34+
for (auto &interval : other.syncAtomicIntervals)
35+
syncAtomicIntervals[interval.first].insert(interval.second.begin(),
36+
interval.second.end());
3337
return *this;
3438
}
3539

@@ -39,39 +43,67 @@ struct BlockInfo {
3943
err << " Read Intervals:\n";
4044
for (auto &[interval, ops] : syncReadIntervals) {
4145
err << " [" << interval.start() << ", " << interval.end() << "] ";
42-
for (auto &op : ops)
43-
err << op->getName() << " ";
46+
for (auto &op : ops) {
47+
op->dump();
48+
err << "\n";
49+
}
4450
err << "\n";
4551
}
4652
err << " Write Intervals:\n";
4753
for (auto &[interval, ops] : syncWriteIntervals) {
4854
err << " [" << interval.start() << ", " << interval.end() << "] ";
49-
for (auto &op : ops)
50-
err << op->getName() << " ";
55+
for (auto &op : ops) {
56+
op->dump();
57+
err << "\n";
58+
}
59+
err << "\n";
60+
}
61+
err << " Atomic Intervals:\n";
62+
for (auto &[interval, ops] : syncAtomicIntervals) {
63+
err << " [" << interval.start() << ", " << interval.end() << "] ";
64+
for (auto &op : ops) {
65+
op->dump();
66+
err << "\n";
67+
}
5168
err << "\n";
5269
}
5370
}
5471

5572
/// Returns true if intervals in two BlockInfo objects are intersected.
56-
bool isIntersected(const BlockInfo &other, MembarFilterFn filter) const {
57-
return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals,
58-
filter) ||
59-
/*WAR*/
60-
isIntersected(syncReadIntervals, other.syncWriteIntervals, filter) ||
61-
/*WAW*/
62-
isIntersected(syncWriteIntervals, other.syncWriteIntervals, filter);
73+
bool isIntersected(const BlockInfo &afterInfo, MembarFilterFn filter) const {
74+
// * Atomic, Write, Read
75+
// Atomic F, T, T
76+
// Write T, T, T
77+
// Read T, T, F
78+
const auto &a0 = syncAtomicIntervals;
79+
const auto &r0 = syncReadIntervals;
80+
const auto &w0 = syncWriteIntervals;
81+
82+
// Note `*this` comes before `afterInfo`.
83+
const auto &a1 = afterInfo.syncAtomicIntervals;
84+
const auto &r1 = afterInfo.syncReadIntervals;
85+
const auto &w1 = afterInfo.syncWriteIntervals;
86+
87+
auto intersects = [&](const IntervalMapT &s0, const auto &...ss) {
88+
return (... || isIntersected(s0, ss, filter));
89+
};
90+
91+
return intersects(a0, w1, r1) || intersects(w0, a1, w1, r1) ||
92+
intersects(r0, a1, w1);
6393
}
6494

6595
/// Clears the intervals because a barrier is inserted.
6696
void sync() {
6797
syncReadIntervals.clear();
6898
syncWriteIntervals.clear();
99+
syncAtomicIntervals.clear();
69100
}
70101

71102
/// Compares two BlockInfo objects.
72103
bool operator==(const BlockInfo &other) const {
73104
return syncReadIntervals == other.syncReadIntervals &&
74-
syncWriteIntervals == other.syncWriteIntervals;
105+
syncWriteIntervals == other.syncWriteIntervals &&
106+
syncAtomicIntervals == other.syncAtomicIntervals;
75107
}
76108

77109
bool operator!=(const BlockInfo &other) const { return !(*this == other); }

include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,23 @@
22
#define TRITON_GPU_DIALECT_INTERFACES_H
33

44
#include "mlir/IR/OpDefinition.h"
5+
#include "mlir/Interfaces/SideEffectInterfaces.h"
56

67
// clang-format off
78
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
89
#include "triton/Dialect/TritonGPU/IR/OpInterfaces.h.inc"
910
#include "triton/Dialect/TritonGPU/IR/AttrInterfaces.h.inc"
1011
// clang-format on
1112

13+
namespace mlir::MemoryEffects {
14+
// An atomic read or write on mbarrier:
15+
// - atomic rmw:
16+
// * mbarrier.arrive
17+
// * mbarrier.expect_tx
18+
// * cp.async.bulk.tensor
19+
// - atomic cas: mbarrier.try_wait
20+
// We don't need to insert a `__syncthreads()` between atomic effects, but we
21+
// would need one if they were write effects.
22+
struct MBarAtomic : public Effect::Base<MBarAtomic> {};
23+
} // namespace mlir::MemoryEffects
1224
#endif // TRITON_GPU_DIALECT_INTERFACES_H

include/triton/Dialect/TritonGPU/IR/TritonGPUOpInterfaces.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define TRITONGPU_OP_INTERFACES
33

44
include "mlir/IR/OpBase.td"
5+
include "mlir/Interfaces/SideEffectInterfaces.td"
56

67
def UpcastFpOpInterface : OpInterface<"UpcastFpOpInterface"> {
78
let description = [{
@@ -26,4 +27,9 @@ def UpcastFpOpInterface : OpInterface<"UpcastFpOpInterface"> {
2627
];
2728
}
2829

30+
def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
31+
def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">;
32+
33+
def MBarAtomic: MemoryEffect<"::mlir::MemoryEffects::MBarAtomic", SharedMemory, 0, PartialEffect>;
34+
2935
#endif // TRITONGPU_OP_INTERFACES

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
55
include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
66
include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td"
77
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
8+
include "triton/Dialect/TritonGPU/IR/TritonGPUOpInterfaces.td"
89
include "mlir/Dialect/Arith/IR/ArithBase.td"
910
include "triton/Dialect/Triton/IR/TritonTypes.td"
1011
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
@@ -19,9 +20,6 @@ include "mlir/Interfaces/ViewLikeInterface.td"
1920
//
2021
// Interfaces
2122
//
22-
def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
23-
def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">;
24-
2523
class TTG_Op<string mnemonic, list<Trait> traits = []> :
2624
Op<TritonGPU_Dialect, mnemonic,
2725
!listconcat(traits, [VerifyTensorLayoutsTrait])> {

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
3030
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
3131
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
3232
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
33+
include "triton/Dialect/TritonGPU/IR/TritonGPUOpInterfaces.td"
3334
include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
3435
include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td"
3536
include "mlir/IR/OpBase.td"
@@ -38,8 +39,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
3839
include "mlir/Interfaces/DestinationStyleOpInterface.td"
3940
include "mlir/Interfaces/ViewLikeInterface.td"
4041

41-
def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
42-
def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">;
4342
def TensorMemory : Resource<"::mlir::triton::nvidia_gpu::TensorMemory">;
4443

4544
class TTNG_Op<string mnemonic, list<Trait> traits = []> :
@@ -170,7 +169,7 @@ def TTNG_BarrierExpectOp : TTNG_Op<"barrier_expect"> {
170169

171170
let hasVerifier = 1;
172171
let arguments = (ins
173-
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$alloc,
172+
Arg<TTG_MemDescType, "", [MBarAtomic]>:$alloc,
174173
I32Attr:$size,
175174
I1:$pred
176175
);
@@ -198,7 +197,7 @@ def TTNG_WaitBarrierOp : TTNG_Op<"wait_barrier", [AttrSizedOperandSegments]> {
198197
}];
199198

200199
let arguments = (ins
201-
Arg<TTG_MemDescType, "", [MemRead<SharedMemory>, MemWrite<SharedMemory>]>:$alloc,
200+
Arg<TTG_MemDescType, "", [MBarAtomic]>:$alloc,
202201
I32:$phase,
203202
Optional<I1>:$pred,
204203
Variadic<TTG_MemDescType>:$deps
@@ -245,7 +244,7 @@ def TTNG_ArriveBarrierOp : TTNG_Op<"arrive_barrier"> {
245244
}];
246245

247246
let arguments = (ins
248-
Arg<TTG_MemDescType, "", [MemRead<SharedMemory>, MemWrite<SharedMemory>]>:$alloc,
247+
Arg<TTG_MemDescType, "", [MBarAtomic]>:$alloc,
249248
I32Attr:$count,
250249
Optional<I1>:$pred
251250
);
@@ -266,7 +265,7 @@ def TTNG_ArriveBarrierOp : TTNG_Op<"arrive_barrier"> {
266265
def TTNG_AsyncCopyMbarrierArriveOp : TTNG_Op<"async_copy_mbarrier_arrive"> {
267266
let summary = "arrive on mbarrier once all previously issued copies are completed";
268267
let arguments = (ins
269-
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$barrier,
268+
Arg<TTG_MemDescType, "", [MBarAtomic]>:$barrier,
270269
UnitAttr:$noIncrement
271270
);
272271
let assemblyFormat = "$barrier attr-dict `:` qualified(type($barrier))";
@@ -288,7 +287,7 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local">
288287
let arguments = (ins
289288
Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>]>:$desc,
290289
Variadic<I32>:$coord,
291-
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$barrier,
290+
Arg<TTG_MemDescType, "", [MBarAtomic]>:$barrier,
292291
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$result,
293292
I1:$pred,
294293
DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
@@ -362,7 +361,7 @@ def TTNG_AsyncTMAGatherOp : TTNG_Op<"async_tma_gather"> {
362361
Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>]>:$desc,
363362
RankedTensorOf<[I32]>:$x_offsets,
364363
I32:$y_offset,
365-
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$barrier,
364+
Arg<TTG_MemDescType, "", [MBarAtomic]>:$barrier,
366365
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$result,
367366
I1:$pred
368367
);

lib/Analysis/Membar.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "triton/Analysis/Membar.h"
22
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
33
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
4+
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
45
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
56

67
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -165,7 +166,8 @@ void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) {
165166
void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
166167
FuncBlockInfoMapT *funcBlockInfoMap,
167168
OpBuilder *builder) {
168-
if (isa<gpu::BarrierOp, triton::gpu::LocalBarrierOp>(op)) {
169+
if (isa<gpu::BarrierOp, triton::gpu::LocalBarrierOp,
170+
triton::nvidia_gpu::WaitBarrierOp>(op)) {
169171
// If the current op is a barrier, we sync previous reads and writes
170172
blockInfo->sync();
171173
return;
@@ -210,6 +212,12 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
210212
.syncReadIntervals[allocation->getAllocatedInterval(
211213
bufferId)]
212214
.insert(op);
215+
else if (isa<MemoryEffects::MBarAtomic>(
216+
effectInstance.getEffect()))
217+
curBlockInfo
218+
.syncAtomicIntervals[allocation->getAllocatedInterval(
219+
bufferId)]
220+
.insert(op);
213221
}
214222
}
215223
}
@@ -244,7 +252,8 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
244252
}
245253

246254
if (!curBlockInfo.syncReadIntervals.empty() ||
247-
!curBlockInfo.syncWriteIntervals.empty()) {
255+
!curBlockInfo.syncWriteIntervals.empty() ||
256+
!curBlockInfo.syncAtomicIntervals.empty()) {
248257
llvm::report_fatal_error(
249258
"scratch buffer operations should not have any shared memory "
250259
"dependencies");

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,6 @@ include "TritonAMDGPUAttrDefs.td"
4646
class TT_AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
4747
Op<TritonAMDGPU_Dialect, mnemonic, !listconcat(traits, [])>;
4848

49-
//
50-
// Interfaces
51-
//
52-
def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
53-
def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">;
54-
5549
//===----------------------------------------------------------------------===//
5650
// ExtractSliceOp
5751
//===----------------------------------------------------------------------===//

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -233,24 +233,38 @@ struct ArriveBarrierOpConversion
233233
LogicalResult
234234
matchAndRewrite(triton::nvidia_gpu::ArriveBarrierOp op, OpAdaptor adaptor,
235235
ConversionPatternRewriter &rewriter) const override {
236-
// TODO: Add phase result as needed.
237236
std::stringstream ptxAsm;
238-
ptxAsm << "@$0 mbarrier.arrive.shared::cta.b64 _, [$1]";
239-
if (op.getCount() > 1) {
240-
ptxAsm << ", " << op.getCount();
241-
}
242-
ptxAsm << ";";
237+
ptxAsm << "@$0 mbarrier.arrive.shared::cta.b64 _, [$1], $2;";
243238

244239
TritonLLVMOpBuilder b(op.getLoc(), rewriter);
245-
Value id = getThreadId(rewriter, op.getLoc());
246-
Value pred = b.icmp_eq(id, b.i32_val(0));
240+
Value pred = LLVM::NVIDIA::createElectPredicate(op.getLoc(), rewriter);
247241
if (op.getPred())
248242
pred = b.and_(pred, adaptor.getPred());
249243

244+
// Distribute arrive-count equally among participating warps.
245+
int count = op.getCount();
246+
int numWarps = triton::gpu::lookupNumWarps(op);
247+
int countPerWarp = count / numWarps;
248+
int remainderCount = count % numWarps;
249+
auto [_, warpId] = getLaneAndWarpId(rewriter, op.getLoc());
250+
Value remPred = b.icmp_ult(warpId, b.i32_val(remainderCount));
251+
if (countPerWarp < 1) {
252+
pred = b.and_(pred, remPred);
253+
}
254+
Value countVal;
255+
if (remainderCount) {
256+
countVal = b.select(remPred, b.i32_val(countPerWarp + 1),
257+
b.i32_val(countPerWarp));
258+
} else {
259+
countVal = b.i32_val(countPerWarp);
260+
}
261+
250262
PTXBuilder ptxBuilder;
251263
SmallVector<PTXBuilder::Operand *, 2> operands = {
252264
ptxBuilder.newOperand(pred, "b"),
253-
ptxBuilder.newOperand(adaptor.getAlloc(), "r")};
265+
ptxBuilder.newOperand(adaptor.getAlloc(), "r"),
266+
ptxBuilder.newOperand(countVal, "r"),
267+
};
254268

255269
auto arriveOp = *ptxBuilder.create<>(ptxAsm.str());
256270
arriveOp(operands, /*onlyAttachMLIRArgs=*/true);

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,19 @@ class TritonLLVMConversionTarget : public ConversionTarget {
6868
}
6969
};
7070

71+
bool membarFilter(Operation *beforeOp, Operation *afterOp) {
72+
if (isa<triton::nvidia_gpu::WarpGroupDotOp, triton::gpu::LocalStoreOp>(
73+
beforeOp)) {
74+
if (auto mbarArriveOp =
75+
dyn_cast<triton::nvidia_gpu::ArriveBarrierOp>(afterOp)) {
76+
auto numWarps = triton::gpu::lookupNumWarps(afterOp);
77+
auto numArrive = mbarArriveOp.getCount();
78+
return numArrive >= numWarps;
79+
}
80+
}
81+
return false;
82+
}
83+
7184
struct ConvertTritonGPUToLLVM
7285
: public triton::impl::ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
7386
using ConvertTritonGPUToLLVMBase::ConvertTritonGPUToLLVMBase;
@@ -86,7 +99,7 @@ struct ConvertTritonGPUToLLVM
8699
ModuleAllocation allocation(
87100
mod, mlir::triton::nvidia_gpu::getNvidiaAllocationAnalysisScratchSizeFn(
88101
targetInfo));
89-
ModuleMembarAnalysis membarPass(&allocation);
102+
ModuleMembarAnalysis membarPass(&allocation, membarFilter);
90103
membarPass.run();
91104

92105
mlir::LowerToLLVMOptions option(context);

0 commit comments

Comments
 (0)