Rewrite ReduceOp to support arbitrary reduce operations #1305

Merged Apr 13, 2023 (23 commits)
23 commits
5f6e3de Add GenericReduceOp to ttir and add `tl.prod` using it (peterbell10, Mar 7, 2023)
786433e Lower tt.generic_reduce to LLVM IR (peterbell10, Mar 8, 2023)
ab461e8 Support simultaneous reduction of multiple tensors (peterbell10, Mar 13, 2023)
0aba718 Automatically build reduction combine op region from JITFunction (peterbell10, Mar 14, 2023)
08c2e35 Replace old ReduceOp entirely (peterbell10, Mar 14, 2023)
4b74ce3 Misc cleanup (peterbell10, Mar 14, 2023)
4b2b16a Add SameOperandsEncoding (peterbell10, Mar 16, 2023)
b3957a7 Run clang-format (peterbell10, Mar 16, 2023)
bbec2fe Fix lit tests (peterbell10, Mar 20, 2023)
7e195a6 Update to newer LLVM (peterbell10, Mar 20, 2023)
0f7a528 Merge remote-tracking branch 'upstream/main' into generic-reduction (peterbell10, Mar 23, 2023)
19d490b Lint (peterbell10, Mar 23, 2023)
c5d928b Merge remote-tracking branch 'upstream/main' into generic-reduction (peterbell10, Mar 30, 2023)
c6f777b Fix merge conflicts (peterbell10, Mar 30, 2023)
440a39d Respond to some review comments (peterbell10, Apr 4, 2023)
3db5241 Merge remote-tracking branch 'upstream/main' into generic-reduction (peterbell10, Apr 4, 2023)
c404afe Merge remote-tracking branch 'upstream/main' into HEAD (peterbell10, Apr 7, 2023)
19d31c6 Don't rematerialize ReduceOp (peterbell10, Apr 10, 2023)
a6ae9e7 Merge remote-tracking branch 'upstream/main' into generic-reduction (peterbell10, Apr 10, 2023)
0cd8f0f Merge remote-tracking branch 'upstream/main' into generic-reduction (peterbell10, Apr 11, 2023)
c7c8ac1 Revert "Don't rematerialize ReduceOp" (peterbell10, Apr 12, 2023)
768241d Merge remote-tracking branch 'upstream/main' into generic-reduction (peterbell10, Apr 12, 2023)
04c2164 Merge branch 'main' into generic-reduction (ptillet, Apr 13, 2023)
Respond to some review comments
peterbell10 committed Apr 4, 2023
commit 440a39d8a85e8c2d25d0ade21b6e83df0ff63183
4 changes: 2 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -100,8 +100,8 @@ def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
   let description = [{}];

   let arguments = (ins TT_BoolLike:$condition,
-                       TT_Type:$true_value,
-                       TT_Type:$false_value);
+                       TT_Tensor:$true_value,
+                       TT_Tensor:$false_value);

   let results = (outs TT_Type:$result);
 }
2 changes: 1 addition & 1 deletion lib/Analysis/Membar.cpp
@@ -76,7 +76,7 @@ void MembarAnalysis::visitTerminator(Operation *op,
   if (isa<triton::ReduceReturnOp>(op) || isa<func::ReturnOp>(op)) {
     return;
   }
-  op->emitOpError("Unknown terminator encountered in membar analysis");
+  llvm_unreachable("Unknown terminator encountered in membar analysis");
 }

 void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
Jokeren marked the review conversation on this hunk as resolved.
6 changes: 3 additions & 3 deletions lib/Analysis/Utility.cpp
@@ -71,11 +71,11 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() {
     elems = product<unsigned>(smemShape);
   }

-  unsigned bytes_per_elem = 0;
+  unsigned bytesPerElem = 0;
   for (const auto &ty : srcElementTypes) {
-    bytes_per_elem += ty.getIntOrFloatBitWidth() / 8;
+    bytesPerElem += ty.getIntOrFloatBitWidth() / 8;
   }
-  return bytes_per_elem * elems;
+  return bytesPerElem * elems;
 }

 bool isSharedEncoding(Value value) {
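For readers following the `getScratchSizeInBytes` hunk above, here is a minimal standalone sketch of the same arithmetic: the per-element byte widths of all reduced operands are summed, then multiplied by the element count. The function name `scratchSizeInBytes` and the use of plain bit widths in place of MLIR types are assumptions made for illustration; they are not part of the Triton code base.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for ReduceOpHelper::getScratchSizeInBytes().
// Each entry of bitWidths plays the role of the value that
// ty.getIntOrFloatBitWidth() would return for one source element type.
unsigned scratchSizeInBytes(const std::vector<unsigned> &bitWidths,
                            unsigned elems) {
  unsigned bytesPerElem = 0;
  for (unsigned bits : bitWidths) {
    assert(bits > 0 && "element types must have a nonzero bit width");
    // Integer division, as in the diff: a sub-byte type (for example a
    // 1-bit value) contributes 0 bytes to the total.
    bytesPerElem += bits / 8;
  }
  return bytesPerElem * elems;
}
```

Under these assumptions, reducing two 32-bit tensors together (as an argmin/argmax-style reduction would) over 128 scratch elements gives `scratchSizeInBytes({32, 32}, 128)`, that is (4 + 4) * 128 = 1024 bytes.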
@@ -106,8 +106,8 @@ class SimplifyReduceCvt : public mlir::RewritePattern {
         continue;
       }

-      // TODO: This always takes layout from the first argument which
-      // is fine for argmin/argmax but may not be optimal generally
+      // TODO: This only moves conversions from the first argument which is
+      // fine for argmin/argmax but may not be optimal generally
       if (convert.getResult() != owner.getOperands()[0]) {
         continue;
       }
@@ -123,8 +123,8 @@ class SimplifyReduceCvt : public mlir::RewritePattern {
     auto newEncoding =
         newOperands[0].getType().cast<RankedTensorType>().getEncoding();

-    if (!newEncoding.isa<triton::gpu::BlockedEncodingAttr>()) {
-      // ReduceOpToLLVM requires block encoding
+    // this may generate unsupported conversions in the LLVM codegen
+    if (newEncoding.isa<triton::gpu::MmaEncodingAttr>()) {
       return failure();
     }

2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
@@ -80,7 +80,7 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
   addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
                scf::ReduceReturnOp>();
   // We have custom versions of some arith operators
-  addIllegalOp<arith::CmpIOp, arith::CmpFOp, arith::SelectOp>();
+  addIllegalOp<arith::CmpIOp, arith::CmpFOp>();

   addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect,
                              func::FuncDialect, triton::TritonDialect,