Skip to content

Aref Automatic Warp Specialization [AutoWS] Implementation #6689

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: aref_auto_ws
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions bin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ target_link_libraries(triton-llvm-opt PRIVATE
LLVMSupport
LLVMOption
LLVMCodeGen
${dialect_libs}
${conversion_libs}
${triton_libs}
)
export_executable_symbols_for_plugins(triton-llvm-opt)

Expand Down
1 change: 1 addition & 0 deletions include/triton/Analysis/Membar.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class MembarAnalysis {
SmallVector<VirtualBlock> &successors);

void insertBarrier(Operation *operation, OpBuilder *builder);
bool isBarrier(Operation *op);

private:
Allocation *allocation = nullptr;
Expand Down
7 changes: 6 additions & 1 deletion include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,10 @@ Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale,
// Hardware Indices
// -----------------------------------------------------------------------

// If an operation is contained within a warp specialize region, this returns
// the starting warp of that warpgroup.
std::optional<int> getWarpGroupStart(Block *block);

// If an operation is contained within a warp specialize region, this returns
// the thread ID offset of that warpgroup.
std::optional<int> getWarpGroupStartThreadId(Block *block);
Expand Down Expand Up @@ -679,7 +683,8 @@ SmallVector<SmallVector<unsigned>> emitOffsetForLayout(Attribute layout,
// access the corresponding element, starting from the inner dimension.
SmallVector<SmallVector<Value>>
emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
Attribute layout, RankedTensorType type, bool withCTAOffset);
Attribute layout, RankedTensorType type, bool withCTAOffset,
std::optional<int> warpGroupStart = std::nullopt);

// Emits IR to load data from shared memory into registers, or to store data
// from registers into shared memory.
Expand Down
19 changes: 18 additions & 1 deletion include/triton/Conversion/TritonToTritonGPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,36 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
Option<"numWarps", "num-warps",
"int32_t", /*default*/"4",
"number of warps">,

Option<"numStages", "num-stages",
"int32_t", /*default*/"1",
"number of stages">,
Option<"threadsPerWarp", "threads-per-warp",
"int32_t", /*default*/"32",
"number of threads per warp">,
Option<"numCTAs", "num-ctas",
"int32_t", /*default*/"1",
"number of ctas in a cga">,
Option<"warpSpecialized", "warp-specialized",
"bool", /*default*/"false",
"whether it is warp specialized">,
// Selects the ttg.warp_specialize-based lowering. The command-line name is
// hyphenated like every other option in this list: option names in a textual
// pass pipeline are separated by whitespace, so a name containing a space
// could never be matched.
Option<"useTtgWs", "use-ttg-warp-specialize",
"bool", /*default*/"false",
"whether to use ttg.warp_specialized">,
Option<"mathWGPipe", "math-wg-pipe",
"bool", /*default*/"false",
"whether to pipeline math WG">,
Option<"target", "target",
"std::string", /*default*/"\"\"",
"the GPU target, e.g., cuda:80, hip:gfx942">,
Option<"enableSourceRemat", "enable-source-remat",
"bool", /*default*/"false",
"enable trivial source rematerialization">,
ListOption<"wgName", "wg-name",
"std::string", "warp-group name">,
ListOption<"wgStartWarp", "wg-start-warp",
"int32_t", "warp-group start">,
ListOption<"wgNumWarps", "wg-num-warps",
"int32_t", "warp-group number of warps">
];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <vector>

namespace mlir {

Expand All @@ -16,10 +18,13 @@ namespace triton {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();

// Create the pass with numWarps set explicitly.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonToTritonGPUPass(const std::string &target, int numWarps,
int threadsPerWarp = 32, int numCTAs = 1,
bool enableSourceRemat = false);
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass(
const std::string &target, int numWarps, int threadsPerWarp = 32,
int numCTAs = 1, bool enableSourceRemat = false,
int numStages = 1,
bool warpSpecialized = false, bool useTtgWs = false,
bool mathWGPipe = false,
const std::vector<std::tuple<std::string, int, int>> &wgSpec = {});

} // namespace triton
} // namespace mlir
Expand Down
9 changes: 9 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,13 @@ namespace mlir::triton::gpu {

constexpr static char AttrMaxRegistersName[] = "ttg.maxnreg";
constexpr static char AttrNumWarpsName[] = "ttg.num-warps";
constexpr static char AttrNumStagesName[] = "ttg.num-stages";
constexpr static char AttrNumCTAsName[] = "ttg.num-ctas";
constexpr static char AttrTargetName[] = "ttg.target";
constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp";
constexpr static char AttrWarpSpecializedName[] = "ttg.warp-specialized";
constexpr static char AttrUseTtgWsName[] = "ttg.use-ttg-ws";
constexpr static char AttrMathWGPipeName[] = "ttg.math-wg-pipe";

// Find the contextual number of warps on which this operation is executed.
int lookupNumWarps(Operation *op);
Expand All @@ -56,6 +60,11 @@ std::optional<int> maybeLookupNumWarps(Operation *op);
// Utility to find the number of threads per warp
int lookupThreadsPerWarp(OpBuilder &rewriter);

// Same behaviour as lookupNumWarps, except that for module/func ops these
// return the ttg.total-num-warps attribute rather than ttg.num-warps.
int lookupTotalNumWarps(Operation *op);
std::optional<int> maybeLookupTotalNumWarps(Operation *op);

template <typename Key, typename Value> class Cache {
public:
std::optional<Value> get(const Key &key) {
Expand Down
2 changes: 2 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -1362,4 +1362,6 @@ def TTG_SharedMemorySpace : AttrDef<TritonGPU_Dialect, "SharedMemorySpace"> {
}];
}

// TritonGPU attribute `TypeArray` (mnemonic `type_array`) whose elements are
// of C++ type `Type`. No extra fields or traits are declared here.
def TTG_TypeArray : ArrayOfAttr<TritonGPU_Dialect, "TypeArray", "type_array", "Type"> {}

#endif
4 changes: 4 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def TritonGPU_Dialect : Dialect {
let dependentDialects = [
"triton::TritonDialect",
"mlir::gpu::GPUDialect",
"tensor::TensorDialect",
];

let extraClassDeclaration = [{
Expand All @@ -26,7 +27,10 @@ def TritonGPU_Dialect : Dialect {
LinearEncodingAttr toLinearEncoding(ArrayRef<int64_t> shape, Attribute layout);

static int getNumCTAs(ModuleOp mod);
static int getNumStages(ModuleOp mod);
static int getThreadsPerWarp(ModuleOp mod);
static bool isWarpSpecialized(ModuleOp mod);
static bool isMathWGPipe(ModuleOp mod);

private:
LinearLayoutCache llCache;
Expand Down
43 changes: 43 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,52 @@ def TritonGPUCoalesceAsyncCopy: Pass<"tritongpu-coalesce-async-copy", "mlir::Mod
let description = "For AsyncCopyGlobalToLocal ops where the shared encoding's vec is less than "
"the blocked encoding's sizePerThread, this pass improves coalescing by clipping the "
"sizePerThread value";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
}

// Module-level pass registered under the command-line argument
// `tritongpu-analyze-warp-specialization`. It declares no options and depends
// on the TritonGPU and Triton dialects.
// NOTE(review): the summary/description do not say what the analysis produces
// or where its results are stored -- worth expanding once confirmed against
// the implementation.
def TritonGPUAnalyzeWarpSpecialization : Pass</*cli-arg*/"tritongpu-analyze-warp-specialization", /*Op*/"mlir::ModuleOp"> {
let summary = "Analyzes the warp specialization pattern";
let description = [{
This pass analyzes the warp specialization pattern.
}];

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
}


// Module-level pass registered under the command-line argument
// `tritongpu-split-warp-group-loops`. Per its description it splits a loop
// into one loop per partition, using the warp specialization analysis.
// Depends on the TritonGPU, Triton and NVVM dialects.
//
// Options:
//   num-stages (int32_t, default 3): number of warp specialization pipeline
//     stages.
//   mma-depth  (int32_t, default 2): number of MMA pipeline stages.
def TritonGPUSplitWarpGroupLoops : Pass</*cli-arg*/"tritongpu-split-warp-group-loops", /*Op*/"mlir::ModuleOp"> {
let summary = "Split warp group loops for warp specialization";
let description = [{
This pass splits the loop into a loop for each partition, based on warp specialization analysis.
}];

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect",
"mlir::NVVM::NVVMDialect"];
let options = [
Option<"numStages", "num-stages",
"int32_t", /*default*/"3",
"number of warp specialization pipeline stages">,
Option<"mmaDepth", "mma-depth",
"int32_t", /*default*/"2",
"number of MMA pipeline stages">
];
}

// Module-level pass registered under the command-line argument
// `triton-nvidia-fmha-math-loop-pipeline`. It declares no options and depends
// on the TritonGPU and Triton dialects.
// NOTE(review): the record name uses the `TritonGPU` prefix while the
// command-line argument uses `triton-nvidia-` and the description mentions
// TritonNvidiaGPUDialect, which is not listed in dependentDialects -- confirm
// whether that dialect must be added or the naming aligned.
def TritonGPUFMHAMathLoopPipeline: Pass<"triton-nvidia-fmha-math-loop-pipeline", "mlir::ModuleOp"> {
let summary = "fmha math loop pipelining";

let description = [{
Pipeline the FMHA math loop in TritonNvidiaGPUDialect.
}];

let dependentDialects = [
"mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect",
];
}


#endif
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ createSingleBufferView(TBuilder &builder, Value alloc, int idx) {
builder.template create<arith::ConstantIntOp>(alloc.getLoc(), idx, 32));
}

namespace gpu {
// Declaration only; the definition is not in this header. Takes an scf.for
// loop and returns an scf.for loop.
// NOTE(review): presumably lowers TMA descriptor handling inside `forOp` and
// returns the rewritten loop, with `maxStage` being the highest pipeline
// stage index -- confirm against the implementation.
scf::ForOp lowerTMADescriptors(scf::ForOp forOp, int maxStage);
}

} // namespace triton
} // namespace mlir

Expand Down
4 changes: 4 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"

Expand Down Expand Up @@ -247,6 +248,9 @@ void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices);
// Return true if two value sets may refer to the same allocation.
bool mayAliasAllocations(const DenseSet<Value> &lhs,
const DenseSet<Value> &rhs);

// Inserts a barrier; the interface mirrors getThreadId(rewriter, loc).
Operation* insertBarrier(OpBuilder &rewriter, Location loc);
} // namespace mlir

namespace mlir::triton {
Expand Down
2 changes: 2 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttng)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttng)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=ttng)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=ttng)
add_mlir_doc(TritonNvidiaGPUDialect TritonNvidiaGPUDialect dialects/ -gen-dialect-doc)
add_mlir_doc(TritonNvidiaGPUOps TritonNvidiaGPUOps dialects/ -gen-op-doc)
add_public_tablegen_target(TritonNvidiaGPUTableGen)
Expand Down
1 change: 1 addition & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc"
#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h"

#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,13 @@ def TritonNvidiaGPU_Dialect : Dialect {
"triton::TritonDialect",
"triton::gpu::TritonGPUDialect",
"mlir::gpu::GPUDialect",
"tensor::TensorDialect",
];

let extraClassDeclaration = [{
void registerTypes();
}];
let useDefaultTypePrinterParser = 1;
let useDefaultAttributePrinterParser = 1;
let usePropertiesForAttributes = 1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> {
"void",
"addCompletionBarrier",
(ins "::mlir::Value":$barrier, "::mlir::Value":$pred)>,
InterfaceMethod<"Get mutable completion barriers of this MMAv5 op.",
"::mlir::MutableOperandRange",
"getBarriersMutable">,
InterfaceMethod<"Get mutable completion barriers predicates of this MMAv5 op.",
"::mlir::MutableOperandRange",
"getBarrierPredsMutable">,
InterfaceMethod<"Return the accumulator.",
"::mlir::TypedValue<::mlir::triton::gpu::MemDescType>",
"getAccumulator">,
Expand Down
Loading
Loading