diff --git a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt
index 1b629ba1639f..fb5f7156f9aa 100644
--- a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt
+++ b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt
@@ -10,6 +10,7 @@ add_triton_library(TritonToTritonGPU
   MLIRPass
   MLIRTransforms
   TritonIR
+  ProtonIR
   TritonGPUIR
   TritonGPUTransforms
 )
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
index 67ab63beb736..5159890468a8 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -16,6 +16,8 @@
 #include "triton/Conversion/TritonToTritonGPU/Passes.h.inc"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 
+#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
+
 namespace {
 
 using namespace mlir;
@@ -555,7 +557,17 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       GenericOpPattern<triton::AtomicRMWOp>, GenericOpPattern<ReturnOp>,
       TritonFuncOpPattern>(typeConverter, context);
 }
-
+// Proton patterns
+// NOTE: Because Proton's inputs are scalars and not tensors, this conversion
+// isn't strictly necessary. However, one could envision passing in tensors
+// for Triton-object-specific tracing operations, in which case we would need
+// to fill in the OpConversionPattern.
+void populateProtonPatterns(TritonGPUTypeConverter &typeConverter,
+                            RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<GenericOpPattern<triton::proton::RecordOp>>(typeConverter,
+                                                           context);
+}
 //
 // SCF patterns
 //
@@ -770,6 +782,7 @@ class ConvertTritonToTritonGPU
     populateArithPatternsAndLegality(typeConverter, patterns, target);
     populateMathPatternsAndLegality(typeConverter, patterns, target);
     populateTritonPatterns(typeConverter, patterns, numCTAs);
+    populateProtonPatterns(typeConverter, patterns);
     // TODO: can we use
     // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
     populateSCFPatterns(typeConverter, patterns);
diff --git a/python/src/ir.cc b/python/src/ir.cc
index 53ba39ae1026..531f1444ec46 100644
--- a/python/src/ir.cc
+++ b/python/src/ir.cc
@@ -31,6 +31,8 @@
 #include "triton/Tools/Sys/GetEnv.hpp"
 #include "llvm/Support/SourceMgr.h"
 
+#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
+
 namespace {
 
 namespace py = pybind11;
@@ -235,7 +237,8 @@ void init_triton_ir(py::module &&m) {
   registry.insert<TritonDialect, ::mlir::triton::gpu::TritonGPUDialect,
                   math::MathDialect, arith::ArithDialect, scf::SCFDialect,
                   ::mlir::gpu::GPUDialect, cf::ControlFlowDialect,
-                  LLVM::LLVMDialect, mlir::ub::UBDialect>();
+                  ::mlir::triton::proton::ProtonDialect, LLVM::LLVMDialect,
+                  mlir::ub::UBDialect>();
   mlir::LLVM::registerInlinerInterface(registry);
   registerBuiltinDialectTranslation(registry);
   registerLLVMDialectTranslation(registry);
@@ -1654,6 +1657,11 @@ void init_triton_ir(py::module &&m) {
              std::vector<int32_t> &tensorShape) -> Value {
             return self.create<MakeTensorDescOp>(base, shape, strides,
                                                  tensorShape);
+           })
+      // Proton Ops
+      .def("create_proton_record",
+           [](TritonOpBuilder &self, bool isStart, int32_t regionId) -> void {
+             self.create<mlir::triton::proton::RecordOp>(isStart, regionId);
           });
 
   py::class_<mlir::PassManager>(m, "pass_manager", py::module_local())
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
index e2465f17b622..d101035a4ba4 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
@@ -29,4 +29,5 @@ add_triton_library(TritonAMDGPUToLLVM
   LINK_LIBS PUBLIC
     TritonGPUToLLVM
     TritonAMDGPUIR
+    TritonProtonToLLVM
 )
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
index 0e29b0c00d2b..10534115fd35 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -24,6 +24,8 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 
+#include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h"
+
 namespace mlir::triton {
 #define GEN_PASS_DEF_CONVERTTRITONAMDGPUTOLLVM
 #include "TritonAMDGPUToLLVM/Passes.h.inc"
@@ -228,6 +230,10 @@ struct ConvertTritonAMDGPUToLLVM
                                                   patterns);
     mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
                                                targetInfo, commonBenefit);
+
+    mlir::triton::proton::populateRecordOpToLLVMPattern(
+        typeConverter, patterns, targetInfo, commonBenefit);
+
     mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns);
 
     if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) {
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt
index 96727b357106..a3d8a87290fb 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt
@@ -25,4 +25,5 @@ add_triton_library(TritonNVIDIAGPUToLLVM
 
   LINK_LIBS PUBLIC
     TritonGPUToLLVM
+    TritonProtonToLLVM
 )
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
index 089e4aaebb2b..cb976e8ec4a0 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
@@ -22,6 +22,8 @@
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
 
+#include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h"
+
 namespace mlir {
 namespace triton {
 #define GEN_PASS_DEF_CONVERTTRITONGPUTOLLVM
@@ -149,6 +151,8 @@ struct ConvertTritonGPUToLLVM
                                                      targetInfo, benefit);
     mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
                                                targetInfo, benefit);
+    mlir::triton::proton::populateRecordOpToLLVMPattern(typeConverter, patterns,
+                                                        targetInfo, benefit);
     mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns,
                                                      targetInfo, benefit);
     mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern(typeConverter, patterns,
diff --git a/third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h b/third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h
new file mode 100644
index 000000000000..47d9f4bf5a6d
--- /dev/null
+++ b/third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h
@@ -0,0 +1,20 @@
+#ifndef TRITON_CONVERSION_TRITONPROTON_TO_LLVM_PATTERNS_TRITON_PROTON_OP_TO_LLVM_H
+#define TRITON_CONVERSION_TRITONPROTON_TO_LLVM_PATTERNS_TRITON_PROTON_OP_TO_LLVM_H
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+
+using namespace mlir;
+using namespace mlir::triton;
+
+namespace mlir {
+namespace triton {
+namespace proton {
+void populateRecordOpToLLVMPattern(LLVMTypeConverter &typeConverter,
+                                   RewritePatternSet &patterns,
+                                   const TargetInfoBase &targetInfo,
+                                   PatternBenefit benefit);
+} // namespace proton
+} // namespace triton
+} // namespace mlir
+
+#endif
diff --git a/third_party/proton/dialect/lib/CMakeLists.txt b/third_party/proton/dialect/lib/CMakeLists.txt
index 0ca0f41c5af4..a224fd6f21f4 100644
--- a/third_party/proton/dialect/lib/CMakeLists.txt
+++ b/third_party/proton/dialect/lib/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(Dialect)
+add_subdirectory(TritonProtonToLLVM)
diff --git a/third_party/proton/dialect/lib/TritonProtonToLLVM/CMakeLists.txt b/third_party/proton/dialect/lib/TritonProtonToLLVM/CMakeLists.txt
new file mode 100644
index 000000000000..84b134fda39d
--- /dev/null
+++ b/third_party/proton/dialect/lib/TritonProtonToLLVM/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_triton_library(TritonProtonToLLVM
+  RecordOpToLLVM.cpp
+
+  LINK_LIBS PUBLIC
+  ProtonIR
+)
diff --git a/third_party/proton/dialect/lib/TritonProtonToLLVM/RecordOpToLLVM.cpp b/third_party/proton/dialect/lib/TritonProtonToLLVM/RecordOpToLLVM.cpp
new file mode 100644
index 000000000000..9b0b08ed730b
--- /dev/null
+++ b/third_party/proton/dialect/lib/TritonProtonToLLVM/RecordOpToLLVM.cpp
@@ -0,0 +1,41 @@
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/IR/PatternMatch.h"
+#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
+#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
+#include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h"
+
+namespace {
+
+struct RecordOpConversion
+    : public ConvertOpToLLVMPattern<mlir::triton::proton::RecordOp> {
+  explicit RecordOpConversion(LLVMTypeConverter &typeConverter,
+                              const TargetInfoBase &targetInfo,
+                              PatternBenefit benefit)
+      : mlir::ConvertOpToLLVMPattern<mlir::triton::proton::RecordOp>(
+            typeConverter, benefit),
+        targetInfo(targetInfo) {}
+
+  LogicalResult
+  matchAndRewrite(mlir::triton::proton::RecordOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+protected:
+  const TargetInfoBase &targetInfo;
+};
+
+} // namespace
+
+void mlir::triton::proton::populateRecordOpToLLVMPattern(
+    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    const TargetInfoBase &targetInfo, PatternBenefit benefit) {
+  patterns.add<RecordOpConversion>(typeConverter, targetInfo, benefit);
+}
diff --git a/third_party/proton/proton/language.py b/third_party/proton/proton/language.py
new file mode 100644
index 000000000000..d923f60c6a01
--- /dev/null
+++ b/third_party/proton/proton/language.py
@@ -0,0 +1,12 @@
+from triton._C.libtriton import ir
+from triton.language import core as tl
+from triton.language.core import builtin
+import warnings
+
+
+@builtin
+def record(isStart: bool, regionId: int, _builder=None):
+    warnings.warn(
+        "\nWarning: the Proton language module contains features that are under active development and are not intended for use outside of the core development team."
+    )
+    return tl.tensor(_builder.create_proton_record(isStart, regionId), tl.void)
diff --git a/third_party/proton/test/test_record.py b/third_party/proton/test/test_record.py
new file mode 100644
index 000000000000..0c623c3784ed
--- /dev/null
+++ b/third_party/proton/test/test_record.py
@@ -0,0 +1,41 @@
+import torch
+import pytest
+import pathlib
+
+import triton
+import triton.language as tl
+import triton.profiler.language as pl
+
+
+def test_proton_record(tmp_path: pathlib.Path):
+
+    @triton.jit
+    def add_kernel(
+        x_ptr,
+        y_ptr,
+        output_ptr,
+        n_elements,
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(x_ptr + offsets, mask=mask)
+        pl.record(True, 0)
+        y = tl.load(y_ptr + offsets, mask=mask)
+        pl.record(False, 0)
+        output = x + y
+        tl.store(output_ptr + offsets, output, mask=mask)
+
+    torch.manual_seed(0)
+    size = 2**12
+    x = torch.rand(size, device='cuda')
+    y = torch.rand(size, device='cuda')
+    output = torch.empty_like(x)
+    n_elements = output.numel()
+    grid = (1, 1, 1)
+    pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
+    ttir = pgm.asm['ttir']
+    assert "proton.record() {isStart = true, regionId = 0 : i32}" in ttir
+    assert "proton.record() {isStart = false, regionId = 0 : i32}" in ttir
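
Note on the expected IR: a minimal sketch of the TTIR produced inside add_kernel above, assuming the usual lowering of the two masked loads (only the two proton.record lines are verbatim from the test assertions; the tt.load lines and SSA names are illustrative):

    %x = tt.load %x_ptrs, %mask : tensor<1024x!tt.ptr<f32>>
    proton.record() {isStart = true, regionId = 0 : i32}
    %y = tt.load %y_ptrs, %mask : tensor<1024x!tt.ptr<f32>>
    proton.record() {isStart = false, regionId = 0 : i32}

At the LLVM level the op is currently dropped: RecordOpConversion in RecordOpToLLVM.cpp simply erases proton.record, so this change lands the end-to-end plumbing (frontend builtin, builder binding, GPU conversion, and per-backend pattern registration) while the actual instrumentation lowering can be filled in later.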