Skip to content

Commit

Permalink
[BACKEND] Fix common mistake of missing checks for null pointer (#4532)
Browse files Browse the repository at this point in the history
When using `getDefiningOp()` we are often missing to check whether the
result is null. Using `getDefiningOp<OpTy>()` helps avoid those latent
bugs.
  • Loading branch information
ThomasRaoux authored Aug 17, 2024
1 parent 2abfaec commit 6a5638e
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 32 deletions.
10 changes: 4 additions & 6 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -1299,9 +1299,8 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
idxCol = urem(idxCol, numElemsPerSwizzlingRowVal);
strideRow = numElemsPerSwizzlingRowVal;
}
if (auto add = dyn_cast_or_null<LLVM::AddOp>(idxCol.getDefiningOp())) {
if (auto _cst = dyn_cast_or_null<LLVM::ConstantOp>(
add.getRhs().getDefiningOp())) {
if (auto add = idxCol.getDefiningOp<LLVM::AddOp>()) {
if (auto _cst = add.getRhs().getDefiningOp<LLVM::ConstantOp>()) {
unsigned cst =
cast<IntegerAttr>(_cst.getValue()).getValue().getSExtValue();
unsigned key = cst % (outVec * maxPhase);
Expand All @@ -1310,9 +1309,8 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
immedateOffCol = cst / (outVec * maxPhase) * (outVec * maxPhase);
}
}
if (auto add = dyn_cast_or_null<LLVM::AddOp>(idxRow.getDefiningOp())) {
if (auto _cst = dyn_cast_or_null<LLVM::ConstantOp>(
add.getRhs().getDefiningOp())) {
if (auto add = idxRow.getDefiningOp<LLVM::AddOp>()) {
if (auto _cst = add.getRhs().getDefiningOp<LLVM::ConstantOp>()) {
unsigned cst =
mlir::cast<IntegerAttr>(_cst.getValue()).getValue().getSExtValue();
unsigned key = cst % (perPhase * maxPhase);
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
// Manually constant-fold the layout where possible.
SmallVector<std::pair<StringAttr, int32_t>> constantIns;
for (auto [inDimName, idx] : indices) {
if (auto constant = dyn_cast<LLVM::ConstantOp>(idx.getDefiningOp())) {
if (auto constant = idx.getDefiningOp<LLVM::ConstantOp>()) {
constantIns.push_back(
{inDimName, cast<IntegerAttr>(constant.getValue()).getInt()});
} else {
Expand All @@ -184,7 +184,7 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
}

for (auto [inDimName, idx] : indices) {
if (isa<LLVM::ConstantOp>(idx.getDefiningOp())) {
if (idx.getDefiningOp<LLVM::ConstantOp>()) {
continue;
}

Expand Down
6 changes: 2 additions & 4 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ struct CanonicalizeMaskedLoadPattern : public OpRewritePattern<LoadOp> {
if (!mask)
return failure();

auto constantMask =
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
auto constantMask = mask.getDefiningOp<arith::ConstantOp>();
if (!constantMask)
return failure();

Expand Down Expand Up @@ -159,8 +158,7 @@ struct CanonicalizeMaskedStorePattern : public OpRewritePattern<StoreOp> {
if (!mask)
return failure();

auto constantMask =
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
auto constantMask = mask.getDefiningOp<arith::ConstantOp>();
if (!constantMask)
return failure();

Expand Down
21 changes: 7 additions & 14 deletions lib/Dialect/Triton/Transforms/Combine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,15 @@ class CombineSelectMaskedLoadPattern : public RewritePattern {
Value falseValue = selectOp.getFalseValue();
Value condSelect = selectOp.getCondition();

auto *loadOpCandidate = trueValue.getDefiningOp();
auto loadOp = llvm::dyn_cast_or_null<LoadOp>(loadOpCandidate);
auto loadOp = trueValue.getDefiningOp<LoadOp>();
if (!loadOp)
return failure();

Value mask = loadOp.getMask();
if (!mask)
return failure();

auto *splatOpCandidate = mask.getDefiningOp();
auto splatOp = llvm::dyn_cast_or_null<SplatOp>(splatOpCandidate);
auto splatOp = mask.getDefiningOp<SplatOp>();
if (!splatOp)
return failure();

Expand Down Expand Up @@ -175,26 +173,21 @@ class CombineBroadcastMulReducePattern : public RewritePattern {
if (!isReduceAdd)
return failure();
// operand of reduce has to be mul
auto mulOp = llvm::dyn_cast_or_null<arith::MulFOp>(
reduceOp.getOperand(0).getDefiningOp());
auto mulOp = reduceOp.getOperand(0).getDefiningOp<arith::MulFOp>();
if (!mulOp)
return failure();
// mul operand has to be broadcast
auto broadcastLhsOp = llvm::dyn_cast_or_null<BroadcastOp>(
mulOp.getOperand(0).getDefiningOp());
auto broadcastLhsOp = mulOp.getOperand(0).getDefiningOp<BroadcastOp>();
if (!broadcastLhsOp)
return failure();
auto broadcastRhsOp = llvm::dyn_cast_or_null<BroadcastOp>(
mulOp.getOperand(1).getDefiningOp());
auto broadcastRhsOp = mulOp.getOperand(1).getDefiningOp<BroadcastOp>();
if (!broadcastRhsOp)
return failure();
// broadcast operand is expand dims
auto expandLhsOp = llvm::dyn_cast_or_null<ExpandDimsOp>(
broadcastLhsOp.getSrc().getDefiningOp());
auto expandLhsOp = broadcastLhsOp.getSrc().getDefiningOp<ExpandDimsOp>();
if (!expandLhsOp)
return failure();
auto expandRhsOp = llvm::dyn_cast_or_null<ExpandDimsOp>(
broadcastRhsOp.getSrc().getDefiningOp());
auto expandRhsOp = broadcastRhsOp.getSrc().getDefiningOp<ExpandDimsOp>();
if (!expandRhsOp)
return failure();
// get not-broadcast dimensions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class TritonGPUOptimizeThreadLocalityPass
return;
auto argNum = yieldOpOperand.getOperandNumber();
auto oldAccum = forOp.getInitArgs()[argNum];
auto cstOp = dyn_cast<arith::ConstantOp>(oldAccum.getDefiningOp());
auto cstOp = oldAccum.getDefiningOp<arith::ConstantOp>();
if (!cstOp)
return;
reduceOps.insert(reduce);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1397,16 +1397,18 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
while (isa_and_nonnull<ttg::ConvertLayoutOp, tt::TransOp>(
transitiveOperand.getDefiningOp()) ||
isa<BlockArgument>(transitiveOperand)) {
if (auto blockArg = dyn_cast<BlockArgument>(transitiveOperand)) {
assert(blockArg.getOwner() == forOp.getBody());
auto blockArg = dyn_cast<BlockArgument>(transitiveOperand);
if (blockArg && blockArg.getOwner() == forOp.getBody()) {
transitiveOperand =
cast<scf::YieldOp>(blockArg.getOwner()->getTerminator())
.getOperand(blockArg.getArgNumber() - 1);
}
transitiveOperand = transitiveOperand.getDefiningOp()->getOperand(0);
if (Operation *def = transitiveOperand.getDefiningOp()) {
transitiveOperand = def->getOperand(0);
}
}
return forOp.isDefinedOutsideOfLoop(transitiveOperand) ||
isa<ttg::MemDescSubviewOp>(transitiveOperand.getDefiningOp());
transitiveOperand.getDefiningOp<ttg::MemDescSubviewOp>();
};

// We don't have to call checkOperand on getC() because it's always in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class BypassEpilogueSMEM : public mlir::RewritePattern {
chainedOps.push_back(chainedOp);
}

auto cvtOp = dyn_cast<triton::gpu::ConvertLayoutOp>(val.getDefiningOp());
auto cvtOp = val.getDefiningOp<triton::gpu::ConvertLayoutOp>();
if (!cvtOp)
return mlir::failure();

Expand Down

0 comments on commit 6a5638e

Please sign in to comment.