[BACKEND][AMD] Disable linear layout due to perf regression (#4126)
We have identified a 20% performance regression in our downstream flash
attention perf kernel after switching to linear layout. Initial analysis
shows that register pressure increased enough to cause spills. Further
analysis is ongoing.

So this commit introduces a minimal way to selectively disable linear
layout on the AMD backend only, which avoids affecting the NVIDIA backend
while we continue bringing it up on the AMD side.
antiagainst authored Jun 12, 2024
1 parent 6eecbd9 commit e8bc45d
Showing 3 changed files with 8 additions and 1 deletion.
5 changes: 5 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
@@ -56,6 +56,11 @@ class TargetInfoBase {
StringRef message, StringRef file, StringRef func,
int line) const = 0;

// Whether to enable linear layout. This is a per-backend temporary escape
// hatch to disable linear layout while figuring out issues. Eventually we
// want to enable linear layout everywhere and delete this control.
virtual bool enableLinearLayout() const { return true; }

virtual ~TargetInfoBase() {}
};
} // namespace mlir::triton
2 changes: 1 addition & 1 deletion include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1219,7 +1219,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
bool allowLL = true) {
// Eventually the LinearLayout path will be the only one. For now we allow
// both paths so we can test that they produce the same results.
- if (allowLL) {
+ if (allowLL && target.enableLinearLayout()) {
std::optional<SmallVector<SmallVector<Value>>> llOffsets =
emitIndicesUsingLinearLayouts(loc, rewriter, target, layout, type,
withCTAOffset);
2 changes: 2 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h
@@ -58,6 +58,8 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
StringRef message, StringRef file, StringRef func,
int line) const override;

bool enableLinearLayout() const override { return false; }

private:
void printfImpl(Value formatStrStart, int formatStrByteCount, ValueRange args,
ConversionPatternRewriter &rewriter, bool useStdErr) const;
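
The three hunks above work together as a virtual-method escape hatch: the base class defaults to enabled, the AMD target overrides it to disabled, and `emitIndices` consults the target before taking the linear layout path. A minimal standalone sketch of that pattern follows. The class names `TargetInfoBaseSketch`, `AmdTargetInfoSketch`, and `NvidiaTargetInfoSketch` and the function `chooseIndexPath` are hypothetical stand-ins for illustration, not Triton APIs. The sketch also ignores that the real linear layout path returns a `std::optional`.

```cpp
#include <string>

// Hypothetical stand-in for TargetInfoBase: linear layout is on by default.
class TargetInfoBaseSketch {
public:
  virtual bool enableLinearLayout() const { return true; }
  virtual ~TargetInfoBaseSketch() = default;
};

// Hypothetical stand-in for the AMD TargetInfo: opts out via an override.
class AmdTargetInfoSketch : public TargetInfoBaseSketch {
public:
  bool enableLinearLayout() const override { return false; }
};

// Hypothetical stand-in for a backend that keeps the inherited default.
class NvidiaTargetInfoSketch : public TargetInfoBaseSketch {};

// Mirrors the guard added to emitIndices: the linear layout path is taken
// only when the caller allows it AND the target has not disabled it.
std::string chooseIndexPath(const TargetInfoBaseSketch &target,
                            bool allowLL = true) {
  if (allowLL && target.enableLinearLayout())
    return "linear-layout";
  return "legacy";
}
```

Under these assumptions, `chooseIndexPath(AmdTargetInfoSketch{})` selects the legacy path, while a target that does not override the method keeps the linear layout path. Because the default lives in the base class, removing the escape hatch later only requires deleting the override and the virtual method.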