[BW] Improve tmem allocation to handle row packing. (#110)
Support packing multiple allocations along rows.
This switches the memory tracking from intervals to a bitmap so that
allocations can be placed along both rows and columns.
ThomasRaoux authored Nov 2, 2024
1 parent 6bb2b6f commit a93111c
Showing 12 changed files with 263 additions and 89 deletions.
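To make the change summary above concrete, the following is a minimal standalone C++ sketch of the bitmap-based, two-dimensional first-fit strategy described in the commit message. It is illustrative only, not code from this commit: the names (BitMap2D, Chunk) are invented for the example, and the sizes (two 64-row blocks, 512 columns) simply mirror the numTmemRows, allocGranularity, kNumRows, and initial column count used in the patch. The real MemoryBitMap below also grows its column space on demand and supports row-id constraints, which this sketch omits.

#include <cassert>
#include <cstdio>
#include <vector>

struct Chunk {
  int startRow, startCol, numRows, numCols;
};

class BitMap2D {
public:
  BitMap2D(int numRows, int numCols)
      : rows(numRows), cols(numCols), used(numRows * numCols, false) {}

  // First-fit search over a 2D grid: try every (row, col) origin, scanning
  // columns outward, and return the first rectangle whose cells are all free.
  Chunk alloc(int numRows, int numCols) {
    for (int c = 0; c + numCols <= cols; ++c) {
      for (int r = 0; r + numRows <= rows; ++r) {
        if (fits(r, c, numRows, numCols)) {
          mark(r, c, numRows, numCols, true);
          return {r, c, numRows, numCols};
        }
      }
    }
    assert(false && "out of memory in this sketch");
    return {};
  }

  void free(const Chunk &chunk) {
    mark(chunk.startRow, chunk.startCol, chunk.numRows, chunk.numCols, false);
  }

private:
  bool fits(int r, int c, int nr, int nc) const {
    for (int i = 0; i < nr; ++i)
      for (int j = 0; j < nc; ++j)
        if (used[(r + i) * cols + (c + j)])
          return false;
    return true;
  }
  void mark(int r, int c, int nr, int nc, bool v) {
    for (int i = 0; i < nr; ++i)
      for (int j = 0; j < nc; ++j)
        used[(r + i) * cols + (c + j)] = v;
  }

  int rows, cols;
  std::vector<bool> used;
};

int main() {
  // Two row blocks of 64 rows each, 512 columns: two 64-row allocations can
  // now share the same columns instead of being packed only along columns.
  BitMap2D tmem(/*numRows=*/2, /*numCols=*/512);
  Chunk a = tmem.alloc(/*numRows=*/1, /*numCols=*/128); // row block 0 (rows 0-63)
  Chunk b = tmem.alloc(/*numRows=*/1, /*numCols=*/128); // row block 1, same columns
  std::printf("a: row %d col %d, b: row %d col %d\n", a.startRow, a.startCol,
              b.startRow, b.startCol);
  tmem.free(a);
  tmem.free(b);
  return 0;
}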
9 changes: 8 additions & 1 deletion include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h
@@ -49,7 +49,14 @@ struct TensorMemory : public SideEffects::Resource::Base<TensorMemory> {
StringRef getName() final { return "<TensorMemory>"; }
};

int getNumTMemColumns(MemDescType memDescType);
struct TMemAllocation {
TMemAllocation(int numCols, int numRows)
: numCols(numCols), numRows(numRows) {}
int numRows;
int numCols;
};

TMemAllocation getTmemAllocSizes(MemDescType memDescType);

Attribute getTmemCompatibleLayout(unsigned M, unsigned N,
ArrayRef<int64_t> shape, unsigned numWarps,
10 changes: 6 additions & 4 deletions lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp
@@ -43,7 +43,9 @@ namespace mlir {
namespace triton {
namespace nvidia_gpu {

int getNumTMemColumns(MemDescType memDescType) {
static constexpr int numTmemRows = 128;

TMemAllocation getTmemAllocSizes(MemDescType memDescType) {
const int rowSizeInBytes = 4;
auto shapePerCTA = triton::gpu::getShapePerCTA(memDescType);
if (isa<TensorMemoryScalesEncodingAttr>(memDescType.getEncoding())) {
@@ -54,7 +56,7 @@ int getNumTMemColumns(MemDescType memDescType) {
int k = shapePerCTA[1];
int m = shapePerCTA[0];
int numColumn = ceil<int>(m, 32) * ceil<int>(k, 4);
return numColumn;
return TMemAllocation(numColumn, numTmemRows);
}
assert(isa<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
memDescType.getEncoding()) &&
@@ -67,7 +69,7 @@ int getNumTMemColumns(MemDescType memDescType) {
isUnpacked ? rowSizeInBytes
: memDescType.getElementType().getIntOrFloatBitWidth() / 8;
int sizeInBytes = product(shapePerCTA) * elementSizeInBytes;
int numRows = 128;
int numRows = numTmemRows;
// BlockM of 64 is an interleaved format, where for a single message only the
// first 16 rows are used. For multiple blocks, the rows are interleaved, i.e.
// 0 N/2 N
@@ -90,7 +92,7 @@ int getNumTMemColumns(MemDescType memDescType) {
if (blockM == 64 && isSingleBlock)
numRows = 64;
int numColumn = ceil<int>(sizeInBytes, (numRows * rowSizeInBytes));
return numColumn;
return TMemAllocation(numColumn, numRows);
}

Attribute getTmemCompatibleLayout(unsigned M, unsigned N,
217 changes: 171 additions & 46 deletions lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp
@@ -6,6 +6,7 @@
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/MapVector.h"

#define GEN_PASS_CLASSES
@@ -18,6 +19,87 @@ using namespace triton::nvidia_gpu;

namespace {

// Granularity of row allocations.
static constexpr int allocGranularity = 64;
struct TMemChunk {
int startRow;
int startCol;
int numCols;
int numRows;
};

// Use a simple bitmap to track memory usage. This is slow, but it allows us to
// handle 2D memory without extra algorithmic complexity. The number of
// allocations is expected to be small so the compile time is unlikely to be a
// problem.
struct MemoryBitMap {
MemoryBitMap() : elements(512 * kNumRows, false) {}
void free(const TMemChunk &chunk) {
for (int i = 0; i < chunk.numCols; i++) {
for (int j = 0; j < chunk.numRows; j++) {
setUsed(chunk.startRow + j, chunk.startCol + i, false);
}
}
}
void alloc(const TMemChunk &chunk) {
for (int i = 0; i < chunk.numCols; i++) {
for (int j = 0; j < chunk.numRows; j++) {
setUsed(chunk.startRow + j, chunk.startCol + i, true);
}
}
}

TMemChunk findFirstFit(TMemAllocation allocSize,
std::optional<int> rowIdConstraint) {
int numRows = allocSize.numRows / allocGranularity;
assert(kNumRows - numRows >= 0);
assert(allocSize.numRows % allocGranularity == 0);
int startCol = 0;
while (1) {
if ((startCol + allocSize.numCols) * numRows >= elements.size())
elements.resize(2 * elements.size(), false);

// Iterate over possible starting rows
for (int startRow = 0; startRow <= kNumRows - numRows; ++startRow) {
if (rowIdConstraint && *rowIdConstraint != startRow)
continue;
bool fits = true;

// Check if the block starting at (startRow, startCol) is free
for (int i = 0; i < allocSize.numCols && fits; ++i) {
for (int j = 0; j < numRows; ++j) {
if (isUsed(startRow + j, startCol + i)) {
fits = false;
break;
}
}
}

// If a suitable block is found, return it
if (fits) {
TMemChunk chunk;
chunk.startRow = startRow;
chunk.startCol = startCol;
chunk.numRows = numRows;
chunk.numCols = allocSize.numCols;
return chunk;
}
}
startCol++;
}
return TMemChunk();
}

private:
bool isUsed(int row, int col) const { return elements[row + col * kNumRows]; }
void setUsed(int row, int col, bool used) {
elements[row + col * kNumRows] = used;
}

static constexpr int kNumRows = 2;
std::vector<bool> elements;
};

static Interval<int> getLiveIntervals(Value value, Liveness &liveness,
DenseMap<Operation *, int> &operationId) {
auto liveOperations = liveness.resolveLiveness(value);
@@ -48,86 +130,129 @@ static Interval<int> getLiveIntervals(Value value, Liveness &liveness,
return Interval(minId, maxId);
}

static void
updateFreeIntervals(std::set<Interval<int>> &freeIntervals,
Interval<int> liveInterval,
std::map<int, Interval<int>> &intervalLiverangeEnd) {
static void updateMap(MemoryBitMap &memoryMap, Interval<int> liveInterval,
std::map<int, TMemChunk> &intervalLiverangeEnd) {
int start = liveInterval.start();
// Add any dead liverange to the list of free intervals.
for (auto it = intervalLiverangeEnd.begin();
it != intervalLiverangeEnd.end();) {
if (it->first > start)
break;
freeIntervals.insert(it->second);
memoryMap.free(it->second);
it = intervalLiverangeEnd.erase(it);
}
// Coalesce free intervals.
auto it = freeIntervals.begin();
auto next = std::next(it);
while (next != freeIntervals.end()) {
if (it->end() == next->start()) {
// Merge the two intervals.
Interval<int> newInterval(it->start(), next->end());
freeIntervals.erase(it);
freeIntervals.erase(next);
it = freeIntervals.insert(newInterval).first;
} else {
it = next;
}
next = std::next(it);
}
}

static Interval<int> allocFirstFit(std::set<Interval<int>> &freeIntervals,
int allocSize) {
for (auto it : freeIntervals) {
if (it.size() >= allocSize) {
Interval<int> allocatedInterval(it.start(), it.start() + allocSize);
Interval<int> newInterval(it.start() + allocSize, it.end());
freeIntervals.erase(it);
if (newInterval.size() > 0)
freeIntervals.insert(newInterval);
return allocatedInterval;
}
static TMemChunk allocFirstFit(MemoryBitMap &memoryMap,
TMemAllocation allocSize,
std::optional<int> rowIdConstraint) {
TMemChunk chunk = memoryMap.findFirstFit(allocSize, rowIdConstraint);
memoryMap.alloc(chunk);
return chunk;
}

static Operation *getAlloc(Value value) {
Operation *op = value.getDefiningOp();
while (isa<triton::gpu::MemDescSubviewOp>(op)) {
op = op->getResult(0).getDefiningOp();
}
llvm::report_fatal_error("Failed to allocate memory.");
return Interval<int>();
assert(isa<triton::nvidia_gpu::TMEMAllocOp>(op) && "Expected a TMEMAllocOp");
return op;
}

class RowIdConstraints {
llvm::EquivalenceClasses<Operation *> dependentAllocs;
llvm::SmallDenseMap<Operation *, int> rowIndex;

public:
void joinOps(Operation *op1, Operation *op2) {
dependentAllocs.unionSets(op1, op2);
}

std::optional<int> getRowIdConstraint(Operation *op) {
auto it = dependentAllocs.findLeader(op);
if (it == dependentAllocs.member_end())
return std::nullopt;
auto rowIt = rowIndex.find(*it);
if (rowIt == rowIndex.end())
return std::nullopt;
return rowIt->second;
}

void addConstraints(Operation *op, int rowId) {
auto it = dependentAllocs.findLeader(op);
if (it == dependentAllocs.member_end())
return;
rowIndex[*it] = rowId;
}
};

static int
allocateTMem(Operation *parentOp,
DenseMap<triton::nvidia_gpu::TMEMAllocOp, int> &offsets) {
SmallVector<triton::nvidia_gpu::TMEMAllocOp> allocs;
DenseMap<Operation *, int> operationId;
llvm::EquivalenceClasses<Operation *> dependentAllocs;
RowIdConstraints rowIdConstraints;
parentOp->walk<WalkOrder::PostOrder>([&](Operation *op) {
operationId[op] = operationId.size();
if (auto alloc = dyn_cast<triton::nvidia_gpu::TMEMAllocOp>(op)) {
allocs.push_back(alloc);
}
if (auto mmaOp = dyn_cast<triton::nvidia_gpu::TCGen5MMAOp>(op)) {
if (isa<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
mmaOp.getA().getType().getEncoding())) {
TMemAllocation allocSize = getTmemAllocSizes(mmaOp.getA().getType());
if (allocSize.numRows == 64) {
// HW restriction: the A alloc and the accumulator need to be in the same
// rows.
rowIdConstraints.joinOps(getAlloc(mmaOp.getA()),
getAlloc(mmaOp.getD()));
} else {
// TODO: we need to handle cases where the format is blockM and we
// have multiple blocks.
assert((cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
mmaOp.getA().getType().getEncoding())
.getBlockM() != 64 &&
cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
mmaOp.getD().getType().getEncoding())
.getBlockM() != 64) &&
"interleaved layout with TMEM operand is not supported yet.");
}
}
}
});
int totalMemorySize = 0;
std::set<Interval<int>> freeIntervals;
MemoryBitMap memoryMap;
Liveness liveness(parentOp);
// Start with an infinite free interval.
freeIntervals.insert(Interval<int>(0, std::numeric_limits<int>::max()));
std::map<int, Interval<int>> intervalLiverangeEnd;
std::map<int, TMemChunk> intervalLiverangeEnd;
// Implement a linear scan first fit algorithm. We expect that fragmentation
// won't be a problem; if it is, this should be revisited.
for (triton::nvidia_gpu::TMEMAllocOp alloc : allocs) {
Interval<int> liveInterval = getLiveIntervals(alloc, liveness, operationId);
auto memDescType = alloc.getType();
int allocSize = getNumTMemColumns(memDescType);
updateFreeIntervals(freeIntervals, liveInterval, intervalLiverangeEnd);
TMemAllocation allocSize = getTmemAllocSizes(memDescType);
updateMap(memoryMap, liveInterval, intervalLiverangeEnd);

// Find first fit.
Interval<int> intervalAlloc = allocFirstFit(freeIntervals, allocSize);
intervalLiverangeEnd[liveInterval.end()] = intervalAlloc;
int offset = intervalAlloc.start();
std::optional<int> rowIdConstraint =
rowIdConstraints.getRowIdConstraint(alloc);
TMemChunk chunkAllocated =
allocFirstFit(memoryMap, allocSize, rowIdConstraint);
// Currently we naively constrain allocs based on the first one we find.
rowIdConstraints.addConstraints(alloc, chunkAllocated.startRow);
intervalLiverangeEnd[liveInterval.end()] = chunkAllocated;
int colOffset = chunkAllocated.startCol;
int rowOffset = chunkAllocated.startRow * 16;

alloc->setAttr(
"tensor_memory_offset",
IntegerAttr::get(IntegerType::get(parentOp->getContext(), 32), offset));
totalMemorySize = std::max(totalMemorySize, offset + allocSize);
"tensor_memory_col_offset",
IntegerAttr::get(IntegerType::get(parentOp->getContext(), 32),
colOffset));
alloc->setAttr(
"tensor_memory_row_offset",
IntegerAttr::get(IntegerType::get(parentOp->getContext(), 32),
rowOffset));
totalMemorySize = std::max(totalMemorySize, colOffset + allocSize.numCols);
}
return totalMemorySize;
}
4 changes: 4 additions & 0 deletions python/test/unit/language/blackwell_smoke.py
@@ -70,6 +70,10 @@ def matmul_kernel( #
@pytest.mark.parametrize("NUM_WARPS", [4, 8])
def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, NUM_WARPS, NUM_CTAS,
device):
if BLOCK_M == 64 and BLOCK_N == 16 and BLOCK_K == 16 and NUM_STAGES == 4 and dtype_src_str == "float16":
pytest.skip(
"Skipping tests failing due to suspected ptxas bug: https://triton-lang.slack.com/archives/C07FLUE9U8N/p1730443207543549"
)
if dtype_src_str == "float8e5" and BLOCK_K == 16:
pytest.skip("Skipping cases small K for float8")
if dtype_src_str == "float32" and dtype_dst_str == "float16":
8 changes: 4 additions & 4 deletions python/test/unit/language/test_compile_only.py
@@ -94,9 +94,11 @@ def k_loop(a_base, b_base, out, k_tiles):
target=GPUTarget("cuda", 100, 32))
ttgir = k.asm["ttgir"]

pattern = (r"%(?P<C0>\w+) = arith.constant dense<0.000000e\+00>"
pattern = (r"%(?P<TMEM_BASE>\w+) = arith.constant dense<0.000000e\+00>"
r"(.|\n)*?"
r"scf\.for.* iter_args\(%(?P<ACC>\w+) = %(?P=C0)"
r"%(?P<TMEM>\w+) = triton_nvidia_gpu\.tmem_alloc %(?P=TMEM_BASE)"
r"(.|\n)*?"
r"scf\.for"
r"(.|\n)*?"
r"%(?P<A>\w+) = tt\.load"
r"(.|\n)*?"
@@ -106,8 +108,6 @@ def k_loop(a_base, b_base, out, k_tiles):
r"(.|\n)*?"
r"%(?P<B_SHMEM>\w+) = triton_gpu\.local_alloc %(?P=B)"
r"(.|\n)*?"
r"%(?P<TMEM>\w+) = triton_nvidia_gpu\.tmem_alloc %(?P=ACC)"
r"(.|\n)*?"
r"triton_nvidia_gpu\.tc_gen5_mma %(?P=A_SHMEM), %(?P=B_SHMEM), %(?P=TMEM)"
r"(.|\n)*?"
r"scf\.yield")
3 changes: 3 additions & 0 deletions python/tutorials/06-fused-attention-blackwell-expt.py
@@ -453,6 +453,9 @@ def backward(ctx, do):
PRE_BLOCK = 128
NUM_WARPS, NUM_STAGES = 4, 1
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
if ctx.HEAD_DIM == 256:
BLOCK_N1 = 64
BLOCK_M2 = 64
BLK_SLICE_FACTOR = 2
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
arg_k = k.cpu()
2 changes: 1 addition & 1 deletion test/Conversion/tritongpu_to_llvm.mlir
@@ -1628,7 +1628,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: tcgen05.wait::st.sync.aligned
tt.func public @tensor_memory_st(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
%0 = triton_nvidia_gpu.tmem_alloc {tensor_memory_offset = 0 : i32} : () -> !tt.memdesc<128x128xf32, #tmem, #triton_nvidia_gpu.tensor_memory, mutable>
%0 = triton_nvidia_gpu.tmem_alloc {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : () -> !tt.memdesc<128x128xf32, #tmem, #triton_nvidia_gpu.tensor_memory, mutable>
%true = arith.constant true
triton_nvidia_gpu.tmem_store %cst_0, %0, %true : tensor<128x128xf32, #blocked1> -> !tt.memdesc<128x128xf32, #tmem, #triton_nvidia_gpu.tensor_memory, mutable>
tt.return