Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2228,7 +2228,9 @@ def VecTernaryOp : CIR_Op<"vec.ternary",
`(` $cond `,` $lhs`,` $rhs `)` `:` qualified(type($cond)) `,`
qualified(type($lhs)) attr-dict
}];

let hasVerifier = 1;
let hasFolder = 1;
}

#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
35 changes: 35 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1638,6 +1638,41 @@ LogicalResult cir::VecTernaryOp::verify() {
return success();
}

OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
mlir::Attribute cond = adaptor.getCond();
mlir::Attribute lhs = adaptor.getLhs();
mlir::Attribute rhs = adaptor.getRhs();

if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) &&
mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) &&
mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) &&
mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) &&
mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs)) {
if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
!mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
!mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
return {};

auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);

mlir::ArrayAttr condElts = condVec.getElts();

SmallVector<mlir::Attribute, 16> elements;
elements.reserve(condElts.size());

for (const auto &[idx, condAttr] :
llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
if (condAttr.getSInt()) {
elements.push_back(lhsVec.getElts()[idx]);
continue;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this would be more natural as an if-else rather than using continue.

}

elements.push_back(rhsVec.getElts()[idx]);
}

cir::VectorType vecTy = getLhs().getType();
return cir::ConstVectorAttr::get(
vecTy, mlir::ArrayAttr::get(getContext(), elements));
}

return {};
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
6 changes: 3 additions & 3 deletions clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ void CIRCanonicalizePass::runOnOperation() {
assert(!cir::MissingFeatures::complexRealOp());
assert(!cir::MissingFeatures::complexImagOp());
assert(!cir::MissingFeatures::callOp());
// CastOp, UnaryOp, VecExtractOp and VecShuffleDynamicOp are here to perform
// a manual `fold` in applyOpPatternsGreedily.
// CastOp, UnaryOp, VecExtractOp, VecShuffleDynamicOp and VecTernaryOp are
// here to perform a manual `fold` in applyOpPatternsGreedily.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amongst the operations that are here to perform a manual fold are.... (This is Monty Python joke, but we really should just make this a general comment that many of the operations in this list are just here so we can fold them.)

Copy link
Member Author

@AmrDeveloper AmrDeveloper Jun 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can split them into two if statements, one for the ops that are here to perform manual fold and another for the other ops

if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, SelectOp>(op))
   ops.push_back(op);
   
// Operations to perform manual `fold` in applyOpPatternsGreedily.
if (isa<CastOp, UnaryOp, VecExtractOp, VecShuffleDynamicOp, VecTernaryOp>(
            op))
  ops.push_back(op);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not really necessary. I think we've passed the point where anyone is going to be checking each one to see why it's here. Also, some operations may be handled explicitly and also have a folder.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you are right, i will convert it to general comment

if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
VecExtractOp, VecShuffleDynamicOp>(op))
VecExtractOp, VecShuffleDynamicOp, VecTernaryOp>(op))
ops.push_back(op);
});

Expand Down
20 changes: 20 additions & 0 deletions clang/test/CIR/Transforms/vector-ternary-fold.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// RUN: cir-opt %s -cir-canonicalize -o - | FileCheck %s

!s32i = !cir.int<s, 32>

module {
cir.func @vector_ternary_fold_test() -> !cir.vector<4 x !s32i> {
%cond = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
%lhs = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
%rhs = cir.const #cir.const_vector<[#cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
%res = cir.vec.ternary(%cond, %lhs, %rhs) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
cir.return %res : !cir.vector<4 x !s32i>
}

// [1, 0, 1, 0] ? [1, 2, 3, 4] : [5, 6, 7, 8] Will be fold to [1, 6, 3, 8]
// CHECK: cir.func @vector_ternary_fold_test() -> !cir.vector<4 x !s32i> {
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<6> : !s32i, #cir.int<3> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}


Loading