Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 1 addition & 27 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -792,37 +792,11 @@ void populateCFPatterns(TritonGPUTypeConverter &typeConverter,
patterns.add<CFCondBranchPattern, CFBranchPattern>(typeConverter, context);
}

// begin flagtree tle
// Conversion pattern that rebuilds a tle::DSLRegionOp with converted types
// during dialect conversion: operand types come from the adaptor, result
// types and the body region's block-argument types are converted via the
// pattern's type converter.
class TleDSLRegionOpPattern : public OpConversionPattern<tle::DSLRegionOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tle::DSLRegionOp op, tle::DSLRegionOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Clone the op without its regions, then move (not copy) the original
    // body into the clone, so the region's existing IR is preserved.
    auto newOp = rewriter.cloneWithoutRegions<tle::DSLRegionOp>(op);
    Region &body = op.getBody(), &newBody = newOp.getBody();
    rewriter.inlineRegionBefore(body, newBody, newBody.end());

    // Convert the moved region's block-argument types; report a match
    // failure (rather than erroring) if any type cannot be converted.
    if (failed(rewriter.convertRegionTypes(&newBody, *getTypeConverter()))) {
      return rewriter.notifyMatchFailure(op, "could not convert body types");
    }
    // Forward the already-converted operands from the adaptor, and convert
    // each result type in place on the cloned op.
    newOp->setOperands(adaptor.getOperands());
    for (OpResult result : newOp->getResults()) {
      result.setType(getTypeConverter()->convertType(result.getType()));
    }

    // Replace all uses of the old op with the rebuilt op's results.
    rewriter.replaceOp(op, newOp->getResults());
    return success();
  }
};

void populateTleRawPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns
.add<TleDSLRegionOpPattern, GenericOpPattern<tle::LocalPointersOp>,
GenericOpPattern<tle::YieldOp>,
.add<GenericOpPattern<tle::LocalPointersOp>,
GenericOpPattern<tle::ExtractAllocatedPtrOp>,
GenericOpPattern<tle::ExtractAlignedPtrOp>,
GenericOpPattern<tle::ExtractOffsetOp>,
Expand Down
2 changes: 0 additions & 2 deletions third_party/nvidia/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,6 @@ def make_llir(self, src, metadata, options, capability):
passes.llvmir.add_di_scope(pm)
if CUDABackend.instrumentation:
CUDABackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
# flagtree tle raw
tle.raw_passes.add_tle_dsl_region_inline(pm)

pm.run(mod)
# LLVM-IR (MLIR) -> LLVM-IR (LLVM)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "tle/dialect/include/Conversion/TleToLLVM/DSLRegionOpToLLVM.h"
#include "tle/dialect/include/Conversion/TleToLLVM/ExtractOpToLLVM.h"
// begin flagtree tle
#include "tle/dialect/include/Conversion/TleToLLVM/LocalPointersOpToLLVM.h"
Expand Down Expand Up @@ -97,14 +96,6 @@ class TleLLVMConversionTarget : public ConversionTarget {
// begin flagtree tle
addLegalOp<mlir::UnrealizedConversionCastOp>();
// end flagtree tle
addDynamicallyLegalOp<tle::DSLRegionOp, tle::YieldOp>(
[&](Operation *op) -> bool {
bool hasLegalRegions = true;
for (auto &region : op->getRegions()) {
hasLegalRegions = hasLegalRegions && typeConverter.isLegal(&region);
}
return hasLegalRegions && typeConverter.isLegal(op);
});
// Allow non-TLE ops to remain during this partial conversion.
markUnknownOpDynamicallyLegal([](Operation *) -> bool { return true; });
}
Expand Down Expand Up @@ -156,8 +147,6 @@ struct ConvertTritonGPUToLLVM
{
TleLLVMConversionTarget target(*context, typeConverter);
RewritePatternSet patterns(context);
mlir::triton::tle::populateDSLRegionOpToLLVMPatterns(typeConverter,
patterns, benefit);
mlir::triton::tle::populateExtractOpToLLVMPatterns(typeConverter,
patterns, benefit);
mlir::triton::tle::populatePackOpToLLVMPatterns(typeConverter, patterns,
Expand Down

This file was deleted.

12 changes: 0 additions & 12 deletions third_party/tle/dialect/include/IR/TleOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,6 @@ def Tle_LocalPointersOp : Tle_Op<"local_pointers",
let hasVerifier = 1;
}

// `tle.dsl_region`: an isolated-from-above op holding a single body region,
// taking variadic `$inputs` and producing variadic `$outputs` (both of
// Tle_ArgType). Declares MemDescViewTrait and RecursiveMemoryEffects;
// structural checks live in the C++ verifier (hasVerifier = 1).
def Tle_DSLRegionOp : Tle_Op<"dsl_region", [IsolatedFromAbove, MemDescViewTrait, RecursiveMemoryEffects]> {
  let arguments = (ins Variadic<Tle_ArgType>:$inputs);
  let results = (outs Variadic<Tle_ArgType>:$outputs);
  let regions = (region AnyRegion:$body);
  let hasVerifier = 1;
}

// `tle.yield`: the pure, return-like terminator of a DSLRegionOp body;
// forwards variadic `$inputs` as the region's yielded values. The
// HasParent constraint restricts it to appearing inside DSLRegionOp only.
def Tle_YieldOp : Tle_Op<"yield", [Pure, Terminator, ReturnLike, HasParent<"DSLRegionOp">]> {
  let arguments = (ins Variadic<Tle_ArgType>:$inputs);
  let assemblyFormat = "attr-dict ($inputs^ `:` type($inputs))?";
}

def Tle_ExtractAllocatedPtrOp : Tle_Op<"extract_allocated_ptr", [Pure]> {
let arguments = (ins Tle_ArgType:$input);
let results = (outs LLVMPointerType:$ptr);
Expand Down
10 changes: 0 additions & 10 deletions third_party/tle/dialect/include/Transforms/DSLRegionInline.h

This file was deleted.

2 changes: 0 additions & 2 deletions third_party/tle/dialect/include/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,4 @@ def TritonTleLowerTmaCopy : Pass</*cli-arg*/"triton-tle-lower-tma-copy", /*Op*/"

// Module-level pass registered under the CLI flag
// "tle-convert-arg-to-memdesc"; presumably rewrites TLE arguments into
// memdesc form — confirm against the pass implementation. No options,
// summary, or dependent dialects are declared here.
def TleConvertArgToMemDesc : Pass<"tle-convert-arg-to-memdesc", "mlir::ModuleOp"> {}

// Module-level pass registered under the CLI flag "tle-dslregion-inline";
// presumably inlines tle.dsl_region bodies into their parent — confirm
// against the pass implementation. No options or summary declared here.
def TleDSLRegionInline : Pass<"tle-dslregion-inline", "mlir::ModuleOp"> {}

#endif // TRITON_TLE_PASSES
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
add_triton_library(TleToLLVM
DSLRegionOpToLLVM.cpp
ExtractOpToLLVM.cpp
LocalPointersOpToLLVM.cpp
PackOpToLLVM.cpp
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ PackOpConversion::PackOpConversion(LLVMTypeConverter &typeConverter,
LogicalResult
PackOpConversion::matchAndRewrite(tle::PackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto regionOp = op->getParentOfType<tle::DSLRegionOp>();
if (ttg::MemDescType memdesc =
dyn_cast<ttg::MemDescType>(op.getOutput().getType())) {
LLVM::LLVMStructType llvmStructType =
Expand Down
19 changes: 0 additions & 19 deletions third_party/tle/dialect/lib/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,6 @@ namespace {
constexpr int kSharedMemoryAddressSpace = 3;
} // namespace

// Verifier for tle.dsl_region: the op's operands must match the body
// region's block arguments 1:1 — same count, and pairwise-identical types.
LogicalResult DSLRegionOp::verify() {
  Region &body = getBody();
  const uint32_t numArguments = body.getNumArguments(),
                 numOperands = getNumOperands();
  if (numArguments != numOperands) {
    // Bug fix: the original diagnostic printed the two counts swapped —
    // numArguments was labeled "number of operands" and vice versa.
    return emitOpError() << "expects number of operands (" << numOperands
                         << ") to match number of region arguments ("
                         << numArguments << ")";
  }
  // Counts are equal here, so zip walks every (argument, operand) pair.
  for (auto [arg, operand] : llvm::zip(body.getArguments(), getOperands())) {
    if (arg.getType() != operand.getType()) {
      return emitOpError() << "expects region argument type (" << arg.getType()
                           << ") to match operand type (" << operand.getType()
                           << ")";
    }
  }
  return success();
}

void ExtractSizesOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState, size_t num,
Value tensor) {
Expand Down
1 change: 0 additions & 1 deletion third_party/tle/dialect/lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ add_triton_library(TritonTLETransforms
TleLowerAsyncLoad.cpp
TleLowerTmaCopy.cpp
ConvertArgToMemDesc.cpp
DSLRegionInline.cpp

DEPENDS
TritonTLETransformsIncGen
Expand Down
31 changes: 14 additions & 17 deletions third_party/tle/dialect/lib/Transforms/ConvertArgToMemDesc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@ bool rewriteOne(Operation *toReplace, mlir::IRMapping &mapper,
mapper.lookup(ex.getInput()));
rewriter.replaceOp(ex, newEx->getResults());
return true;
} else {
return false;
}
return false;
}
template <typename ExtractOpT>
bool mapInputTensorOnce(Operation *toReplace,
llvm::SmallDenseSet<Value> &mappedValues) {
if (auto ex = llvm::dyn_cast<ExtractOpT>(toReplace)) {
if (auto tensorTy =
llvm::dyn_cast<RankedTensorType>(ex.getInput().getType())) {
mappedValues.insert(ex.getInput());
if (auto input = ex.getInput(); isa<RankedTensorType>(input.getType())) {
mappedValues.insert(input);
return true;
}
}
Expand Down Expand Up @@ -76,11 +76,11 @@ ttg::MemDescType getPlainMemDesc(RankedTensorType ty) {
true);
}

struct TleArgConversion : public OpRewritePattern<tle::DSLRegionOp> {
struct TleArgConversion : public OpRewritePattern<LLVM::CallOp> {
using OpRewritePattern::OpRewritePattern;

TleArgConversion(MLIRContext *context);
LogicalResult matchAndRewrite(tle::DSLRegionOp op,
LogicalResult matchAndRewrite(LLVM::CallOp op,
PatternRewriter &rewriter) const override;
};

Expand All @@ -95,27 +95,24 @@ TleArgConversion::TleArgConversion(MLIRContext *context)
: OpRewritePattern(context) {}

LogicalResult
TleArgConversion::matchAndRewrite(tle::DSLRegionOp op,
TleArgConversion::matchAndRewrite(LLVM::CallOp op,
PatternRewriter &rewriter) const {
SmallVector<Value> inputs(op.getInputs().begin(), op.getInputs().end());
SmallVector<Value> outputs(op.getOutputs().begin(), op.getOutputs().end());
SmallVector<Value> operands = op.getOperands();
PatternRewriter::InsertionGuard guard(rewriter);
SmallVector<Value> operands =
llvm::to_vector(llvm::concat<Value>(outputs, inputs));
bool hasConversion = false;
IRMapping mapper;
SmallVector<Operation *> targets;
SmallVector<ttg::LocalAllocOp> toDeallocOps;
llvm::SmallDenseSet<Value> mappedValues;
for (Value dslValue : operands) {
Operation *defOp = dslValue.getDefiningOp();
if (!defOp)
for (Value operand : operands) {
Operation *defOp = operand.getDefiningOp();
if (!defOp) {
continue;
hasConversion =
}
hasConversion |=
mapInputTensors<tle::ExtractAllocatedPtrOp, tle::ExtractSizesOp,
tle::ExtractStridesOp, tle::ExtractOffsetOp,
tle::ExtractAlignedPtrOp>(defOp, mappedValues) ||
hasConversion;
tle::ExtractAlignedPtrOp>(defOp, mappedValues);
if (isa<tle::ExtractAllocatedPtrOp, tle::ExtractSizesOp,
tle::ExtractStridesOp, tle::ExtractOffsetOp,
tle::ExtractAlignedPtrOp>(defOp)) {
Expand Down
Loading