-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Proton][Dialect] Add Initial Frontend and Target Backend Infrastruct…
…ure For Proton Dialect (#5506) Implement initial basic infrastructure for the Proton Dialect added in #5119 This PR extends the initial boilerplate MLIR Dialect code to the Triton frontend and target backends - currently just lowered to a no-op.
- Loading branch information
Showing
13 changed files
with
157 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,4 +29,5 @@ add_triton_library(TritonAMDGPUToLLVM | |
LINK_LIBS PUBLIC | ||
TritonGPUToLLVM | ||
TritonAMDGPUIR | ||
TritonProtonToLLVM | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,4 +25,5 @@ add_triton_library(TritonNVIDIAGPUToLLVM | |
|
||
LINK_LIBS PUBLIC | ||
TritonGPUToLLVM | ||
TritonProtonToLLVM | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
20 changes: 20 additions & 0 deletions
20
third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#ifndef TRITON_CONVERSION_TRITONPROTON_TO_LLVM_PATTERNS_TRITON_PROTON_OP_TO_LLVM_H | ||
#define TRITON_CONVERSION_TRITONPROTON_TO_LLVM_PATTERNS_TRITON_PROTON_OP_TO_LLVM_H | ||
|
||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::triton; | ||
|
||
namespace mlir { | ||
namespace triton { | ||
namespace proton { | ||
void populateRecordOpToLLVMPattern(LLVMTypeConverter &typeConverter, | ||
RewritePatternSet &patterns, | ||
const TargetInfoBase &targetInfo, | ||
PatternBenefit benefit); | ||
} // namespace proton | ||
} // namespace triton | ||
} // namespace mlir | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
add_subdirectory(Dialect) | ||
add_subdirectory(TritonProtonToLLVM) |
6 changes: 6 additions & 0 deletions
6
third_party/proton/dialect/lib/TritonProtonToLLVM/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
add_triton_library(TritonProtonToLLVM | ||
RecordOpToLLVM.cpp | ||
|
||
LINK_LIBS PUBLIC | ||
ProtonIR | ||
) |
41 changes: 41 additions & 0 deletions
41
third_party/proton/dialect/lib/TritonProtonToLLVM/RecordOpToLLVM.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#include "mlir/Conversion/LLVMCommon/Pattern.h" | ||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" | ||
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" | ||
#include "triton/Conversion/TritonGPUToLLVM/Utility.h" | ||
#include "triton/Dialect/Triton/IR/Dialect.h" | ||
|
||
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" | ||
#include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h" | ||
|
||
namespace { | ||
|
||
struct RecordOpConversion | ||
: public ConvertOpToLLVMPattern<mlir::triton::proton::RecordOp> { | ||
explicit RecordOpConversion(LLVMTypeConverter &typeConverter, | ||
const TargetInfoBase &targetInfo, | ||
PatternBenefit benefit) | ||
: mlir::ConvertOpToLLVMPattern<mlir::triton::proton::RecordOp>( | ||
typeConverter, benefit), | ||
targetInfo(targetInfo) {} | ||
|
||
LogicalResult | ||
matchAndRewrite(mlir::triton::proton::RecordOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
|
||
rewriter.eraseOp(op); | ||
return success(); | ||
} | ||
|
||
protected: | ||
const TargetInfoBase &targetInfo; | ||
}; | ||
|
||
} // namespace | ||
|
||
void mlir::triton::proton::populateRecordOpToLLVMPattern( | ||
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, | ||
const TargetInfoBase &targetInfo, PatternBenefit benefit) { | ||
patterns.add<RecordOpConversion>(typeConverter, targetInfo, benefit); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from triton._C.libtriton import ir | ||
from triton.language import core as tl | ||
from triton.language.core import builtin | ||
import warnings | ||
|
||
|
||
@builtin | ||
def record(isStart: bool, regionId: int, _builder=None): | ||
warnings.warn( | ||
"\nWarning the proton language module within Proton contains under development features that are not intended to be used outside of the core development team" | ||
) | ||
return tl.tensor(_builder.create_proton_record(isStart, regionId), tl.void) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import torch | ||
import pytest | ||
import pathlib | ||
|
||
import triton | ||
import triton.language as tl | ||
import triton.profiler.language as pl | ||
|
||
|
||
def test_proton_record(tmp_path: pathlib.Path): | ||
|
||
@triton.jit | ||
def add_kernel( | ||
x_ptr, | ||
y_ptr, | ||
output_ptr, | ||
n_elements, | ||
BLOCK_SIZE: tl.constexpr, | ||
): | ||
pid = tl.program_id(axis=0) | ||
block_start = pid * BLOCK_SIZE | ||
offsets = block_start + tl.arange(0, BLOCK_SIZE) | ||
mask = offsets < n_elements | ||
x = tl.load(x_ptr + offsets, mask=mask) | ||
pl.record(True, 0) | ||
y = tl.load(y_ptr + offsets, mask=mask) | ||
pl.record(False, 0) | ||
output = x + y | ||
tl.store(output_ptr + offsets, output, mask=mask) | ||
|
||
torch.manual_seed(0) | ||
size = 2**12 | ||
x = torch.rand(size, device='cuda') | ||
y = torch.rand(size, device='cuda') | ||
output = torch.empty_like(x) | ||
n_elements = output.numel() | ||
grid = (1, 1, 1) | ||
pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) | ||
ttir = pgm.asm['ttir'] | ||
assert "proton.record() {isStart = true, regionId = 0 : i32}" in ttir | ||
assert "proton.record() {isStart = false, regionId = 0 : i32}" in ttir |