Skip to content

[CIR] Upstream ShuffleOp for VectorType #142288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrConstraints.td
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
Expand Down Expand Up @@ -39,4 +38,11 @@ def CIR_AnyIntOrFloatAttr : AnyAttrOf<[CIR_AnyIntAttr, CIR_AnyFPAttr],
string cppType = "::mlir::TypedAttr";
}

#endif // CLANG_CIR_DIALECT_IR_CIRATTRCONSTRAINTS_TD
//===----------------------------------------------------------------------===//
// ArrayAttr constraints
//===----------------------------------------------------------------------===//

def CIR_IntArrayAttr : TypedArrayAttrBase<CIR_AnyIntAttr,
"integer array attribute">;

#endif // CLANG_CIR_DIALECT_IR_CIRATTRCONSTRAINTS_TD
47 changes: 47 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
include "clang/CIR/Dialect/IR/CIRDialect.td"
include "clang/CIR/Dialect/IR/CIRTypes.td"
include "clang/CIR/Dialect/IR/CIRAttrs.td"
include "clang/CIR/Dialect/IR/CIRAttrConstraints.td"

include "clang/CIR/Interfaces/CIROpInterfaces.td"
include "clang/CIR/Interfaces/CIRLoopOpInterface.td"
Expand Down Expand Up @@ -2155,6 +2156,52 @@ def VecCmpOp : CIR_Op<"vec.cmp", [Pure, SameTypeOperands]> {
}];
}

//===----------------------------------------------------------------------===//
// VecShuffleOp
//===----------------------------------------------------------------------===//

// TODO: Create an interface that both VecShuffleOp and VecShuffleDynamicOp
// implement. This could be useful for passes that don't care how the vector
// shuffle was specified.

def VecShuffleOp : CIR_Op<"vec.shuffle",
[Pure, AllTypesMatch<["vec1", "vec2"]>]> {
let summary = "Combine two vectors using indices passed as constant integers";
let description = [{
The `cir.vec.shuffle` operation implements the documented form of Clang's
`__builtin_shufflevector`, where the indices of the shuffled result are
integer constants.

The two input vectors, which must have the same type, are concatenated.
Each of the integer constant arguments is interpreted as an index into that
concatenated vector, with a value of -1 meaning that the result value
doesn't matter. The result vector, which must have the same element type as
the input vectors and the same number of elements as the list of integer
constant indices, is constructed by taking the elements at the given
indices from the concatenated vector. The size of the result vector does
not have to match the size of the individual input vectors or of the
concatenated vector.

```mlir
%new_vec = cir.vec.shuffle(%vec_1, %vec_2 : !cir.vector<2 x !s32i>)
[#cir.int<3> : !s64i, #cir.int<1> : !s64i] : !cir.vector<2 x !s32i>
```
}];

let arguments = (ins
CIR_VectorType:$vec1,
CIR_VectorType:$vec2,
CIR_IntArrayAttr:$indices
);

let results = (outs CIR_VectorType:$result);
let assemblyFormat = [{
`(` $vec1 `,` $vec2 `:` qualified(type($vec1)) `)` $indices `:`
qualified(type($result)) attr-dict
}];
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// VecShuffleDynamicOp
//===----------------------------------------------------------------------===//
Expand Down
21 changes: 18 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,24 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
cgf.getLoc(e->getSourceRange()), inputVec, indexVec);
}

cgf.getCIRGenModule().errorNYI(e->getSourceRange(),
"ShuffleVectorExpr with indices");
return {};
mlir::Value vec1 = Visit(e->getExpr(0));
mlir::Value vec2 = Visit(e->getExpr(1));

// The documented form of __builtin_shufflevector, where the indices are
// a variable number of integer constants. The constants will be stored
// in an ArrayAttr.
SmallVector<mlir::Attribute, 8> indices;
for (unsigned i = 2; i < e->getNumSubExprs(); ++i) {
indices.push_back(
cir::IntAttr::get(cgf.builder.getSInt64Ty(),
e->getExpr(i)
->EvaluateKnownConstInt(cgf.getContext())
.getSExtValue()));
}

return cgf.builder.create<cir::VecShuffleOp>(
cgf.getLoc(e->getSourceRange()), cgf.convertType(e->getType()), vec1,
vec2, cgf.builder.getArrayAttr(indices));
}

mlir::Value VisitConvertVectorExpr(ConvertVectorExpr *e) {
Expand Down
23 changes: 23 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1579,6 +1579,29 @@ OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
return elements[index];
}

//===----------------------------------------------------------------------===//
// VecShuffle
//===----------------------------------------------------------------------===//

LogicalResult cir::VecShuffleOp::verify() {
// The number of elements in the indices array must match the number of
// elements in the result type.
if (getIndices().size() != getResult().getType().getSize()) {
return emitOpError() << ": the number of elements in " << getIndices()
<< " and " << getResult().getType() << " don't match";
}

// The element types of the two input vectors and of the result type must
// match.
if (getVec1().getType().getElementType() !=
getResult().getType().getElementType()) {
return emitOpError() << ": element types of " << getVec1().getType()
<< " and " << getResult().getType() << " don't match";
}

return success();
}

//===----------------------------------------------------------------------===//
// VecShuffleDynamicOp
//===----------------------------------------------------------------------===//
Expand Down
18 changes: 18 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1770,6 +1770,7 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMVecExtractOpLowering,
CIRToLLVMVecInsertOpLowering,
CIRToLLVMVecCmpOpLowering,
CIRToLLVMVecShuffleOpLowering,
CIRToLLVMVecShuffleDynamicOpLowering,
CIRToLLVMVecTernaryOpLowering
// clang-format on
Expand Down Expand Up @@ -1922,6 +1923,23 @@ mlir::LogicalResult CIRToLLVMVecCmpOpLowering::matchAndRewrite(
return mlir::success();
}

mlir::LogicalResult CIRToLLVMVecShuffleOpLowering::matchAndRewrite(
cir::VecShuffleOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
// LLVM::ShuffleVectorOp takes an ArrayRef of int for the list of indices.
// Convert the ClangIR ArrayAttr of IntAttr constants into a
// SmallVector<int>.
SmallVector<int, 8> indices;
std::transform(
op.getIndices().begin(), op.getIndices().end(),
std::back_inserter(indices), [](mlir::Attribute intAttr) {
return mlir::cast<cir::IntAttr>(intAttr).getValue().getSExtValue();
});
rewriter.replaceOpWithNewOp<mlir::LLVM::ShuffleVectorOp>(
op, adaptor.getVec1(), adaptor.getVec2(), indices);
return mlir::success();
}

mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite(
cir::VecShuffleDynamicOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,16 @@ class CIRToLLVMVecCmpOpLowering
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMVecShuffleOpLowering
: public mlir::OpConversionPattern<cir::VecShuffleOp> {
public:
using mlir::OpConversionPattern<cir::VecShuffleOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::VecShuffleOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMVecShuffleDynamicOpLowering
: public mlir::OpConversionPattern<cir::VecShuffleDynamicOp> {
public:
Expand Down
25 changes: 25 additions & 0 deletions clang/test/CIR/CodeGen/vector-ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1091,3 +1091,28 @@ void foo17() {
// OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>

void foo19() {
vi4 a;
vi4 b;
vi4 u = __builtin_shufflevector(a, b, 7, 5, 3, 1);
}

// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["b"]
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
// CIR: %[[SHUF:.*]] = cir.vec.shuffle(%[[TMP_A]], %[[TMP_B]] : !cir.vector<4 x !s32i>) [#cir.int<7> :
// CIR-SAME: !s64i, #cir.int<5> : !s64i, #cir.int<3> : !s64i, #cir.int<1> : !s64i] : !cir.vector<4 x !s32i>

// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
// LLVM: %[[SHUF:.*]] = shufflevector <4 x i32> %[[TMP_A]], <4 x i32> %[[TMP_B]], <4 x i32> <i32 7, i32 5, i32 3, i32 1>

// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
// OGCG: %[[SHUF:.*]] = shufflevector <4 x i32> %[[TMP_A]], <4 x i32> %[[TMP_B]], <4 x i32> <i32 7, i32 5, i32 3, i32 1>
25 changes: 25 additions & 0 deletions clang/test/CIR/CodeGen/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,31 @@ void foo17() {
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>

void foo19() {
vi4 a;
vi4 b;
vi4 u = __builtin_shufflevector(a, b, 7, 5, 3, 1);
}

// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["b"]
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
// CIR: %[[SHUF:.*]] = cir.vec.shuffle(%[[TMP_A]], %[[TMP_B]] : !cir.vector<4 x !s32i>) [#cir.int<7> :
// CIR-SAME: !s64i, #cir.int<5> : !s64i, #cir.int<3> : !s64i, #cir.int<1> : !s64i] : !cir.vector<4 x !s32i>

// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
// LLVM: %[[SHUF:.*]] = shufflevector <4 x i32> %[[TMP_A]], <4 x i32> %[[TMP_B]], <4 x i32> <i32 7, i32 5, i32 3, i32 1>

// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
// OGCG: %[[SHUF:.*]] = shufflevector <4 x i32> %[[TMP_A]], <4 x i32> %[[TMP_B]], <4 x i32> <i32 7, i32 5, i32 3, i32 1>

void foo20() {
vi4 a;
vi4 b;
Expand Down
38 changes: 38 additions & 0 deletions clang/test/CIR/IR/invalid-vector.cir
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,41 @@ module {
cir.global external @vec_b = #cir.zero : !cir.vector<4 x !cir.array<!s32i x 10>>

}

// -----

!s32i = !cir.int<s, 32>
!s64i = !cir.int<s, 64>

module {
cir.func @invalid_vector_shuffle() {
%1 = cir.const #cir.int<1> : !s32i
%2 = cir.const #cir.int<2> : !s32i
%3 = cir.const #cir.int<3> : !s32i
%4 = cir.const #cir.int<4> : !s32i
%vec_1 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
%vec_2 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// expected-error @below {{element types of '!cir.vector<4 x !cir.int<s, 32>>' and '!cir.vector<4 x !cir.int<s, 64>>' don't match}}
%new_vec = cir.vec.shuffle(%vec_1, %vec_2 : !cir.vector<4 x !s32i>) [#cir.int<7> : !s64i, #cir.int<5> : !s64i, #cir.int<3> : !s64i, #cir.int<1> : !s64i] : !cir.vector<4 x !s64i>
cir.return
}
}

// -----

!s32i = !cir.int<s, 32>
!s64i = !cir.int<s, 64>

module {
cir.func @invalid_vector_shuffle() {
%1 = cir.const #cir.int<1> : !s32i
%2 = cir.const #cir.int<2> : !s32i
%3 = cir.const #cir.int<3> : !s32i
%4 = cir.const #cir.int<4> : !s32i
%vec_1 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
%vec_2 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
// expected-error @below {{the number of elements in [#cir.int<7> : !cir.int<s, 64>, #cir.int<5> : !cir.int<s, 64>, #cir.int<3> : !cir.int<s, 64>] and '!cir.vector<4 x !cir.int<s, 64>>' don't match}}
%new_vec = cir.vec.shuffle(%vec_1, %vec_2 : !cir.vector<4 x !s32i>) [#cir.int<7> : !s64i, #cir.int<5> : !s64i, #cir.int<3> : !s64i] : !cir.vector<4 x !s64i>
cir.return
}
}
Loading