[Proton][Dialect] Add Initial Frontend and Target Backend Infrastructure For Proton Dialect (#5506)

Implement the initial basic infrastructure for the Proton Dialect added in #5119.

This PR extends the initial boilerplate MLIR Dialect code to the Triton frontend and target backends; the new op is currently just lowered to a no-op.
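A rough usage sketch of the new frontend hook (illustrative only; the kernel below is not part of this PR, while the `pl.record(isStart, regionId)` signature and the `triton.profiler.language` import path follow the test added here):

import triton
import triton.language as tl
import triton.profiler.language as pl  # Proton frontend bindings


@triton.jit
def probed_kernel(x_ptr, BLOCK: tl.constexpr):  # hypothetical kernel for illustration
    offs = tl.arange(0, BLOCK)
    pl.record(True, 0)   # opens region 0 -> proton.record() {isStart = true, regionId = 0 : i32}
    x = tl.load(x_ptr + offs)
    pl.record(False, 0)  # closes region 0 -> proton.record() {isStart = false, regionId = 0 : i32}
    tl.store(x_ptr + offs, x + 1.0)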
CRobeck authored Jan 6, 2025
1 parent 2b06b2c commit dcad5ac
Showing 13 changed files with 157 additions and 2 deletions.
1 change: 1 addition & 0 deletions lib/Conversion/TritonToTritonGPU/CMakeLists.txt
@@ -10,6 +10,7 @@ add_triton_library(TritonToTritonGPU
MLIRPass
MLIRTransforms
TritonIR
ProtonIR
TritonGPUIR
TritonGPUTransforms
)
15 changes: 14 additions & 1 deletion lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -16,6 +16,8 @@
#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"

namespace {

using namespace mlir;
@@ -555,7 +557,17 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
GenericOpPattern<triton::DotScaledOp>, GenericOpPattern<triton::CallOp>,
TritonFuncOpPattern>(typeConverter, context);
}

// Proton patterns
// NOTE: Because Proton's inputs are scalars rather than tensors, this
// conversion isn't strictly necessary. However, one could envision passing
// tensors in for Triton-object-specific tracing operations, in which case we
// would need to fill in the OpConversionPattern.
void populateProtonPatterns(TritonGPUTypeConverter &typeConverter,
                            RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<GenericOpPattern<triton::proton::RecordOp>>(typeConverter,
                                                           context);
}
//
// SCF patterns
//
@@ -770,6 +782,7 @@ class ConvertTritonToTritonGPU
populateArithPatternsAndLegality(typeConverter, patterns, target);
populateMathPatternsAndLegality(typeConverter, patterns, target);
populateTritonPatterns(typeConverter, patterns, numCTAs);
populateProtonPatterns(typeConverter, patterns);
// TODO: can we use
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
populateSCFPatterns(typeConverter, patterns);
10 changes: 9 additions & 1 deletion python/src/ir.cc
@@ -31,6 +31,8 @@
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/Support/SourceMgr.h"

#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"

namespace {

namespace py = pybind11;
@@ -235,7 +237,8 @@ void init_triton_ir(py::module &&m) {
registry.insert<TritonDialect, ::mlir::triton::gpu::TritonGPUDialect,
math::MathDialect, arith::ArithDialect, scf::SCFDialect,
::mlir::gpu::GPUDialect, cf::ControlFlowDialect,
LLVM::LLVMDialect, mlir::ub::UBDialect>();
::mlir::triton::proton::ProtonDialect, LLVM::LLVMDialect,
mlir::ub::UBDialect>();
mlir::LLVM::registerInlinerInterface(registry);
registerBuiltinDialectTranslation(registry);
registerLLVMDialectTranslation(registry);
@@ -1654,6 +1657,11 @@ void init_triton_ir(py::module &&m) {
std::vector<int32_t> &tensorShape) -> Value {
return self.create<MakeTensorDescOp>(base, shape, strides,
tensorShape);
})
// Proton Ops
.def("create_proton_record",
[](TritonOpBuilder &self, bool isStart, int32_t regionId) -> void {
self.create<mlir::triton::proton::RecordOp>(isStart, regionId);
});

py::class_<PassManager>(m, "pass_manager", py::module_local())
1 change: 1 addition & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
@@ -29,4 +29,5 @@ add_triton_library(TritonAMDGPUToLLVM
LINK_LIBS PUBLIC
TritonGPUToLLVM
TritonAMDGPUIR
TritonProtonToLLVM
)
6 changes: 6 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -24,6 +24,8 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

#include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h"

namespace mlir::triton {
#define GEN_PASS_DEF_CONVERTTRITONAMDGPUTOLLVM
#include "TritonAMDGPUToLLVM/Passes.h.inc"
@@ -228,6 +230,10 @@ struct ConvertTritonAMDGPUToLLVM
patterns);
mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
targetInfo, commonBenefit);

mlir::triton::proton::populateRecordOpToLLVMPattern(
typeConverter, patterns, targetInfo, commonBenefit);

mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns);

if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) {
@@ -25,4 +25,5 @@ add_triton_library(TritonNVIDIAGPUToLLVM

LINK_LIBS PUBLIC
TritonGPUToLLVM
TritonProtonToLLVM
)
@@ -22,6 +22,8 @@
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"

#include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h"

namespace mlir {
namespace triton {
#define GEN_PASS_DEF_CONVERTTRITONGPUTOLLVM
@@ -149,6 +151,8 @@ struct ConvertTritonGPUToLLVM
targetInfo, benefit);
mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
targetInfo, benefit);
mlir::triton::proton::populateRecordOpToLLVMPattern(typeConverter, patterns,
targetInfo, benefit);
mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns,
targetInfo, benefit);
mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern(typeConverter, patterns,
@@ -0,0 +1,20 @@
#ifndef TRITON_CONVERSION_TRITONPROTON_TO_LLVM_PATTERNS_TRITON_PROTON_OP_TO_LLVM_H
#define TRITON_CONVERSION_TRITONPROTON_TO_LLVM_PATTERNS_TRITON_PROTON_OP_TO_LLVM_H

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"

using namespace mlir;
using namespace mlir::triton;

namespace mlir {
namespace triton {
namespace proton {
void populateRecordOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const TargetInfoBase &targetInfo,
PatternBenefit benefit);
} // namespace proton
} // namespace triton
} // namespace mlir

#endif
1 change: 1 addition & 0 deletions third_party/proton/dialect/lib/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(Dialect)
add_subdirectory(TritonProtonToLLVM)
@@ -0,0 +1,6 @@
add_triton_library(TritonProtonToLLVM
RecordOpToLLVM.cpp

LINK_LIBS PUBLIC
ProtonIR
)
@@ -0,0 +1,41 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/IR/PatternMatch.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
#include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h"

namespace {

struct RecordOpConversion
    : public ConvertOpToLLVMPattern<mlir::triton::proton::RecordOp> {
  explicit RecordOpConversion(LLVMTypeConverter &typeConverter,
                              const TargetInfoBase &targetInfo,
                              PatternBenefit benefit)
      : mlir::ConvertOpToLLVMPattern<mlir::triton::proton::RecordOp>(
            typeConverter, benefit),
        targetInfo(targetInfo) {}

  LogicalResult
  matchAndRewrite(mlir::triton::proton::RecordOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.eraseOp(op);
    return success();
  }

protected:
  const TargetInfoBase &targetInfo;
};

} // namespace

void mlir::triton::proton::populateRecordOpToLLVMPattern(
    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    const TargetInfoBase &targetInfo, PatternBenefit benefit) {
  patterns.add<RecordOpConversion>(typeConverter, targetInfo, benefit);
}
12 changes: 12 additions & 0 deletions third_party/proton/proton/language.py
@@ -0,0 +1,12 @@
from triton._C.libtriton import ir
from triton.language import core as tl
from triton.language.core import builtin
import warnings


@builtin
def record(isStart: bool, regionId: int, _builder=None):
    warnings.warn(
        "The proton language module contains under-development features that are "
        "not intended to be used outside of the core development team.")
    return tl.tensor(_builder.create_proton_record(isStart, regionId), tl.void)
41 changes: 41 additions & 0 deletions third_party/proton/test/test_record.py
@@ -0,0 +1,41 @@
import torch
import pytest
import pathlib

import triton
import triton.language as tl
import triton.profiler.language as pl


def test_proton_record(tmp_path: pathlib.Path):

    @triton.jit
    def add_kernel(
        x_ptr,
        y_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        pl.record(True, 0)
        y = tl.load(y_ptr + offsets, mask=mask)
        pl.record(False, 0)
        output = x + y
        tl.store(output_ptr + offsets, output, mask=mask)

    torch.manual_seed(0)
    size = 2**12
    x = torch.rand(size, device='cuda')
    y = torch.rand(size, device='cuda')
    output = torch.empty_like(x)
    n_elements = output.numel()
    grid = (1, 1, 1)
    pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    ttir = pgm.asm['ttir']
    assert "proton.record() {isStart = true, regionId = 0 : i32}" in ttir
    assert "proton.record() {isStart = false, regionId = 0 : i32}" in ttir
