Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
predList.emplace_back(pos, builder.getIsNotNull());

if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
if (auto attr = val.getDefiningOp<pdl::AttributeOp>()) {
// If the attribute has a type or value, add a constraint.
if (Value type = attr.getValueType())
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1322,7 +1322,7 @@ static bool isNeutralElementConst(arith::AtomicRMWKind reductionKind,
return false;
Attribute valueAttr = getIdentityValueAttr(reductionKind, scalarTy,
state.builder, value.getLoc());
if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(value.getDefiningOp()))
if (auto constOp = value.getDefiningOp<arith::ConstantOp>())
return constOp.getValue() == valueAttr;
return false;
}
Expand Down Expand Up @@ -1865,7 +1865,6 @@ verifyLoopNesting(const std::vector<SmallVector<AffineForOp, 2>> &loops) {
return success();
}


/// External utility to vectorize affine loops in 'loops' using the n-D
/// vectorization factors in 'vectorSizes'. By default, each vectorization
/// factor is applied inner-to-outer to the loops of each loop nest.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2498,7 +2498,7 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
matchPattern(adaptor.getFalseValue(), m_Zero()))
return condition;

if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
if (auto cmp = condition.getDefiningOp<arith::CmpIOp>()) {
auto pred = cmp.getPredicate();
if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
auto cmpLhs = cmp.getLhs();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ std::optional<Value> getExtOperand(Value v) {

// If the operand is not defined by an explicit extend operation of the
// accepted operation type allow for an implicit sign-extension.
auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
auto extOp = v.getDefiningOp<Op>();
if (!extOp) {
if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
auto eltTy = cast<VectorType>(v.getType()).getElementType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ std::optional<Value> getExtOperand(Value v) {

// If the operand is not defined by an explicit extend operation of the
// accepted operation type allow for an implicit sign-extension.
auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
auto extOp = v.getDefiningOp<Op>();
if (!extOp) {
if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
auto vTy = cast<VectorType>(v.getType());
Expand Down
4 changes: 1 addition & 3 deletions mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
continue;

for (Value operand : op.getOperands()) {
auto usedExpression =
dyn_cast_if_present<ExpressionOp>(operand.getDefiningOp());

auto usedExpression = operand.getDefiningOp<ExpressionOp>();
if (!usedExpression)
continue;

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2707,7 +2707,7 @@ LogicalResult IFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
while (alias) {
Block &initBlock = alias.getInitializerBlock();
auto returnOp = cast<ReturnOp>(initBlock.getTerminator());
auto addrOp = dyn_cast<AddressOfOp>(returnOp.getArg().getDefiningOp());
auto addrOp = returnOp.getArg().getDefiningOp<AddressOfOp>();
// FIXME: This is a best effort solution. The AliasOp body might be more
// complex and in that case we bail out with success. To completely match
// the LLVM IR logic it would be necessary to implement proper alias and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1851,7 +1851,7 @@ transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
assert(!packOp && "packOp must be null on entry when unPackOp is not null");
OpOperand *packUse = linalgOp.getDpsInitOperand(
cast<OpResult>(unPackOp.getSource()).getResultNumber());
packOp = dyn_cast_or_null<linalg::PackOp>(packUse->get().getDefiningOp());
packOp = packUse->get().getDefiningOp<linalg::PackOp>();
if (!packOp || !packOp.getResult().hasOneUse())
return emitSilenceableError() << "could not find matching pack op";
}
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -756,8 +756,7 @@ static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp,
Value source = extractSliceOp.getSource();
LLVM_DEBUG(DBGS() << "--with starting source: " << source << "\n");
while (source && source != expectedSource) {
auto destOp =
dyn_cast_or_null<DestinationStyleOpInterface>(source.getDefiningOp());
auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
if (!destOp)
break;
LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n");
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,7 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
Value source = transferRead.getBase();

// Skip view-like Ops and retrive the actual soruce Operation
while (auto srcOp =
dyn_cast_or_null<ViewLikeOpInterface>(source.getDefiningOp()))
while (auto srcOp = source.getDefiningOp<ViewLikeOpInterface>())
source = srcOp.getViewSource();

llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ bool MeshSharding::operator!=(const MeshSharding &rhs) const {
MeshSharding::MeshSharding(::mlir::FlatSymbolRefAttr mesh_) : mesh(mesh_) {}

MeshSharding::MeshSharding(Value rhs) {
auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
auto shardingOp = rhs.getDefiningOp<ShardingOp>();
assert(shardingOp && "expected sharding op");
auto splitAxes = shardingOp.getSplitAxes().getAxes();
// If splitAxes are empty, use "empty" constructor.
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ ReshardingRquirementKind getReshardingRquirementKind(

for (auto [operand, sharding] :
llvm::zip_equal(op->getOperands(), operandShardings)) {
ShardOp shardOp = llvm::dyn_cast_or_null<ShardOp>(operand.getDefiningOp());
ShardOp shardOp = operand.getDefiningOp<ShardOp>();
if (!shardOp) {
continue;
}
Expand Down Expand Up @@ -376,8 +376,7 @@ struct ShardingPropagation
LLVM_DEBUG(
DBGS() << "print all the ops' iterator types and indexing maps in the "
"block.\n";
for (Operation &op
: block.getOperations()) {
for (Operation &op : block.getOperations()) {
if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
});
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -660,8 +660,7 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,

// Check if 2 shard ops are chained. If not there is no need for resharding
// as the source and target shared the same sharding.
ShardOp srcShardOp =
dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
ShardOp srcShardOp = shardOp.getSrc().getDefiningOp<ShardOp>();
if (!srcShardOp) {
targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
} else {
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1730,8 +1730,7 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
if (!mapOp.getDefiningOp())
return emitError(op->getLoc(), "missing map operation");

if (auto mapInfoOp =
mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
uint64_t mapTypeBits = mapInfoOp.getMapType();

bool to = mapTypeToBitFlag(
Expand Down
11 changes: 5 additions & 6 deletions mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
Value ptrLike;
FromPtrOp fromPtr = *this;
while (fromPtr != nullptr) {
auto toPtr = dyn_cast_or_null<ToPtrOp>(fromPtr.getPtr().getDefiningOp());
auto toPtr = fromPtr.getPtr().getDefiningOp<ToPtrOp>();
// Cannot fold if it's not a `to_ptr` op or the initial and final types are
// different.
if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
Expand All @@ -64,13 +64,12 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
ptrLike = toPtr.getPtr();
} else if (md) {
// Fold if the metadata can be verified to be equal.
if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
if (auto mdOp = md.getDefiningOp<GetMetadataOp>();
mdOp && mdOp.getPtr() == toPtr.getPtr())
ptrLike = toPtr.getPtr();
}
// Check for a sequence of casts.
fromPtr = dyn_cast_or_null<FromPtrOp>(ptrLike ? ptrLike.getDefiningOp()
: nullptr);
fromPtr = ptrLike ? ptrLike.getDefiningOp<FromPtrOp>() : nullptr;
}
return ptrLike;
}
Expand Down Expand Up @@ -112,13 +111,13 @@ OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
Value ptr;
ToPtrOp toPtr = *this;
while (toPtr != nullptr) {
auto fromPtr = dyn_cast_or_null<FromPtrOp>(toPtr.getPtr().getDefiningOp());
auto fromPtr = toPtr.getPtr().getDefiningOp<FromPtrOp>();
// Cannot fold if it's not a `from_ptr` op.
if (!fromPtr)
return ptr;
ptr = fromPtr.getPtr();
// Check for chains of casts.
toPtr = dyn_cast_or_null<ToPtrOp>(ptr.getDefiningOp());
toPtr = ptr.getDefiningOp<ToPtrOp>();
}
return ptr;
}
Expand Down
7 changes: 3 additions & 4 deletions mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,10 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
op.getStep(), tileSizeConstants)) {
// Collect the statically known loop bounds
auto lowerBoundConstant =
dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
lowerBound.getDefiningOp<arith::ConstantIndexOp>();
auto upperBoundConstant =
dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
auto stepConstant =
dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
upperBound.getDefiningOp<arith::ConstantIndexOp>();
auto stepConstant = step.getDefiningOp<arith::ConstantIndexOp>();
auto tileSize =
cast<arith::ConstantIndexOp>(tileSizeConstant.getDefiningOp()).value();
// If the loop bounds and the loop step are constant and if the number of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1317,7 +1317,7 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {

Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
Value newSize = arith::AddIOp::create(rewriter, loc, size, n);
auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
auto nValue = n.getDefiningOp<arith::ConstantIndexOp>();
bool nIsOne = (nValue && nValue.value() == 1);

if (!op.getInbounds()) {
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
Value input = op.getInput();

// Check the input to the CLAMP op is itself a CLAMP.
auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp());
auto clampOp = input.getDefiningOp<tosa::ClampOp>();
if (!clampOp)
return failure();

Expand Down Expand Up @@ -1634,7 +1634,7 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
for (Value operand : getOperands()) {
concatOperands.emplace_back(operand);

auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
auto producer = operand.getDefiningOp<ConcatOp>();
if (!producer)
continue;

Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2591,8 +2591,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
llvm::enumerate(fromElements.getElements())) {

// Check that the element is from a vector.extract operation.
auto extractOp =
dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
auto extractOp = element.getDefiningOp<vector::ExtractOp>();
if (!extractOp) {
return rewriter.notifyMatchFailure(fromElements,
"element not from vector.extract");
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -903,8 +903,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
// inlined, and as such should be wrapped in parentheses in order to guarantee
// its precedence and associativity.
auto requiresParentheses = [&](Value value) {
auto expressionOp =
dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
auto expressionOp = value.getDefiningOp<ExpressionOp>();
if (!expressionOp)
return false;
return shouldBeInlined(expressionOp);
Expand Down Expand Up @@ -1545,7 +1544,7 @@ LogicalResult CppEmitter::emitOperand(Value value) {
return success();
}

auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
auto expressionOp = value.getDefiningOp<ExpressionOp>();
if (expressionOp && shouldBeInlined(expressionOp))
return emitExpression(expressionOp);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ processDataOperands(llvm::IRBuilderBase &builder,
// Copyin operands are handled as `to` call.
llvm::SmallVector<mlir::Value> create, copyin;
for (mlir::Value dataOp : op.getDataClauseOperands()) {
if (auto createOp =
mlir::dyn_cast_or_null<acc::CreateOp>(dataOp.getDefiningOp())) {
if (auto createOp = dataOp.getDefiningOp<acc::CreateOp>()) {
create.push_back(createOp.getVarPtr());
} else if (auto copyinOp = mlir::dyn_cast_or_null<acc::CopyinOp>(
dataOp.getDefiningOp())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3541,8 +3541,7 @@ getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
}

static bool isDeclareTargetLink(mlir::Value value) {
if (auto addressOfOp =
llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
if (auto addressOfOp = value.getDefiningOp<LLVM::AddressOfOp>()) {
auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
if (auto declareTargetGlobal =
Expand Down Expand Up @@ -4502,8 +4501,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);

if (auto devId = dataOp.getDevice())
if (auto constOp =
dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();

Expand All @@ -4520,8 +4518,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);

if (auto devId = enterDataOp.getDevice())
if (auto constOp =
dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
RTLFn =
Expand All @@ -4540,8 +4537,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);

if (auto devId = exitDataOp.getDevice())
if (auto constOp =
dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();

Expand All @@ -4560,8 +4556,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);

if (auto devId = updateDataOp.getDevice())
if (auto constOp =
dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();

Expand Down Expand Up @@ -5202,8 +5197,7 @@ static std::optional<int64_t> extractConstInteger(Value value) {
if (!value)
return std::nullopt;

if (auto constOp =
dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp()))
if (auto constOp = value.getDefiningOp<LLVM::ConstantOp>())
if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
return constAttr.getInt();

Expand Down
3 changes: 1 addition & 2 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,7 @@ struct FolderCommutativeOp2WithConstant

LogicalResult matchAndRewrite(TestCommutative2Op op,
PatternRewriter &rewriter) const override {
auto operand =
dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
auto operand = op->getOperand(0).getDefiningOp<TestCommutative2Op>();
if (!operand)
return failure();
Attribute constInput;
Expand Down