[Backend] Tiny cleanup/refactor (NFC) (triton-lang#5340)
Introduce an `emitHardwareTuple` helper that emits the code to compute the
blockId, warpId, and laneId for the current thread and returns them. This PR
uses the helper in a few places.
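
For illustration only (not part of the diff), a minimal sketch of what a call site looks like with the helper; `loc`, `rewriter`, and `targetInfo` are assumed to come from the enclosing ConversionPattern, and a warp size of 32 is assumed for the example:

// Hypothetical call site (assumes loc, rewriter, and targetInfo are in scope
// from the surrounding ConversionPattern; the warp size of 32 is an assumed
// example value, not something this commit prescribes).
auto [blockId, warpId, laneId] = emitHardwareTuple(
    loc, rewriter, targetInfo, /*withCTAOffset=*/false, /*threadsPerWarp=*/32);
// Per the helper's definition: laneId = threadId % threadsPerWarp,
// warpId = threadId / threadsPerWarp, and blockId is the cluster CTA id when
// withCTAOffset is true, otherwise i32_val(0).

Call sites keep passing only the warp size they already know (a constant 32 or a layout query such as `ll->getInDimSize(kLane)`), so the thread-id decomposition itself lives in one place.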
Mogball authored Dec 4, 2024
1 parent 6e24f72 commit 147d332
Showing 3 changed files with 27 additions and 17 deletions.
6 changes: 6 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1123,6 +1123,12 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,
  return idx;
}

// Emit code to compute the (blockId, warpId, laneId) for the current thread.
std::tuple</*blockId=*/Value, /*warpId=*/Value, /*laneId=*/Value>
emitHardwareTuple(Location loc, RewriterBase &rewriter,
                  const TargetInfoBase &target, bool withCTAOffset,
                  unsigned threadsPerWarp);

// Emit indices calculation within each ConversionPattern, and returns a
// [elemsPerThread X rank] index matrix.
//
35 changes: 21 additions & 14 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -99,6 +99,20 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
  return outIndices;
}

std::tuple<Value, Value, Value> emitHardwareTuple(Location loc,
                                                  RewriterBase &rewriter,
                                                  const TargetInfoBase &target,
                                                  bool withCTAOffset,
                                                  unsigned threadsPerWarpCst) {
  Value threadId = getThreadId(rewriter, loc);
  Value threadsPerWarp = i32_val(threadsPerWarpCst);
  Value laneId = urem(threadId, threadsPerWarp);
  Value warpId = udiv(threadId, threadsPerWarp);
  Value blockId =
      withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0);
  return {blockId, warpId, laneId};
}

SmallVector<SmallVector<Value>>
emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
            Attribute layout, RankedTensorType type, bool withCTAOffset) {
@@ -116,12 +130,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
  StringAttr kWarp = str_attr("warp");
  StringAttr kBlock = str_attr("block");

  Value threadId = getThreadId(rewriter, loc);
  Value threadsPerWarp = i32_val(ll->getInDimSize(kLane));
  Value laneId = urem(threadId, threadsPerWarp);
  Value warpId = udiv(threadId, threadsPerWarp);
  Value blockId =
      withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0);
  auto [blockId, warpId, laneId] = emitHardwareTuple(
      loc, rewriter, target, withCTAOffset, ll->getInDimSize(kLane));
  unsigned rank = shape.size();
  SmallVector<SmallVector<Value>> ret;
  // Linear layout function is split in two parts below:
@@ -214,10 +224,9 @@ bool emitTransferBetweenRegistersAndShared(
      std::min(regToSharedLayout->getNumConsecutiveInOut(),
               maxVecElems.value_or(std::numeric_limits<int>::max()));

  Value threadId = getThreadId(rewriter, loc);
  Value threadsPerWarp = i32_val(regToSharedLayout->getInDimSize(kLane));
  Value laneId = urem(threadId, threadsPerWarp);
  Value warpId = udiv(threadId, threadsPerWarp);
  auto [blockId, warpId, laneId] =
      emitHardwareTuple(loc, rewriter, target, /*withCTAOffset=*/false,
                        regToSharedLayout->getInDimSize(kLane));

  int numElems = regToSharedLayout->getInDimSize(kRegister);
  auto vecTy = vec_ty(elemLlvmTy, vecElems);
@@ -625,10 +634,8 @@ SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
    auto instrShape = mmaLayout.getInstrShape();
    SmallVector<Value> mmaColIdx(2);
    SmallVector<Value> mmaRowIdx(2);
    Value threadId = getThreadId(rewriter, loc);
    Value warpSize = i32_val(32);
    Value laneId = urem(threadId, warpSize);
    Value warpId = udiv(threadId, warpSize);
    auto [blockId, warpId, laneId] = emitHardwareTuple(
        loc, rewriter, targetInfo, /*withCTAOffset=*/false, 32);
    // TODO: fix the bug in MMAEncodingAttr document
    SmallVector<Value> multiDimWarpId(2);
    auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
3 changes: 0 additions & 3 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -524,10 +524,7 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
std::optional<LinearLayout>
BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
  assert(shape.size() == getOrder().size());

  int rank = shape.size();
  MLIRContext *ctx = getContext();
  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);

  const auto &order = getOrder();
  LinearLayout ctaLayout =
