-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[clang][RISCV] Handle target features correctly in CheckBuiltinFunctionCall #141548
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…onCall Currently we only check the required features passed by command line arguments. We also need to check the features passed by using target features.
@llvm/pr-subscribers-backend-risc-v Author: Brandon Wu (4vtomat) ChangesCurrently we only check the required features passed by command line arguments. Full diff: https://github.com/llvm/llvm-project/pull/141548.diff 2 Files Affected:
diff --git a/clang/lib/Sema/SemaRISCV.cpp b/clang/lib/Sema/SemaRISCV.cpp
index 481bf8bd22cc1..ca8d849b40a2a 100644
--- a/clang/lib/Sema/SemaRISCV.cpp
+++ b/clang/lib/Sema/SemaRISCV.cpp
@@ -544,8 +544,10 @@ bool SemaRISCV::CheckLMUL(CallExpr *TheCall, unsigned ArgNum) {
<< Arg->getSourceRange();
}
-static bool CheckInvalidVLENandLMUL(const TargetInfo &TI, CallExpr *TheCall,
- Sema &S, QualType Type, int EGW) {
+static bool CheckInvalidVLENandLMUL(const TargetInfo &TI,
+ llvm::StringMap<bool> &FunctionFeatureMap,
+ CallExpr *TheCall, Sema &S, QualType Type,
+ int EGW) {
assert((EGW == 128 || EGW == 256) && "EGW can only be 128 or 256 bits");
// LMUL * VLEN >= EGW
@@ -566,7 +568,7 @@ static bool CheckInvalidVLENandLMUL(const TargetInfo &TI, CallExpr *TheCall,
// Vscale is VLEN/RVVBitsPerBlock.
unsigned MinRequiredVLEN = VScaleFactor * llvm::RISCV::RVVBitsPerBlock;
std::string RequiredExt = "zvl" + std::to_string(MinRequiredVLEN) + "b";
- if (!TI.hasFeature(RequiredExt))
+ if (!TI.hasFeature(RequiredExt) && !FunctionFeatureMap.lookup(RequiredExt))
return S.Diag(TheCall->getBeginLoc(),
diag::err_riscv_type_requires_extension)
<< Type << RequiredExt;
@@ -578,6 +580,10 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
unsigned BuiltinID,
CallExpr *TheCall) {
ASTContext &Context = getASTContext();
+ const FunctionDecl *FD = SemaRef.getCurFunctionDecl();
+ llvm::StringMap<bool> FunctionFeatureMap;
+ Context.getFunctionFeatureMap(FunctionFeatureMap, FD);
+
// vmulh.vv, vmulh.vx, vmulhu.vv, vmulhu.vx, vmulhsu.vv, vmulhsu.vx,
// vsmul.vv, vsmul.vx are not included for EEW=64 in Zve64*.
switch (BuiltinID) {
@@ -634,10 +640,6 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
ASTContext::BuiltinVectorTypeInfo Info = Context.getBuiltinVectorTypeInfo(
TheCall->getType()->castAs<BuiltinType>());
- const FunctionDecl *FD = SemaRef.getCurFunctionDecl();
- llvm::StringMap<bool> FunctionFeatureMap;
- Context.getFunctionFeatureMap(FunctionFeatureMap, FD);
-
if (Context.getTypeSize(Info.ElementType) == 64 && !TI.hasFeature("v") &&
!FunctionFeatureMap.lookup("v"))
return Diag(TheCall->getBeginLoc(),
@@ -713,20 +715,24 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
case RISCVVector::BI__builtin_rvv_vsm4k_vi_tu: {
QualType Arg0Type = TheCall->getArg(0)->getType();
QualType Arg1Type = TheCall->getArg(1)->getType();
- return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type, 128) ||
- CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg1Type, 128) ||
+ return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg0Type, 128) ||
+ CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg1Type, 128) ||
SemaRef.BuiltinConstantArgRange(TheCall, 2, 0, 31);
}
case RISCVVector::BI__builtin_rvv_vsm3c_vi_tu:
case RISCVVector::BI__builtin_rvv_vsm3c_vi: {
QualType Arg0Type = TheCall->getArg(0)->getType();
- return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type, 256) ||
+ return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg0Type, 256) ||
SemaRef.BuiltinConstantArgRange(TheCall, 2, 0, 31);
}
case RISCVVector::BI__builtin_rvv_vaeskf1_vi:
case RISCVVector::BI__builtin_rvv_vsm4k_vi: {
QualType Arg0Type = TheCall->getArg(0)->getType();
- return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type, 128) ||
+ return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg0Type, 128) ||
SemaRef.BuiltinConstantArgRange(TheCall, 1, 0, 31);
}
case RISCVVector::BI__builtin_rvv_vaesdf_vv:
@@ -753,8 +759,10 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
case RISCVVector::BI__builtin_rvv_vsm4r_vs_tu: {
QualType Arg0Type = TheCall->getArg(0)->getType();
QualType Arg1Type = TheCall->getArg(1)->getType();
- return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type, 128) ||
- CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg1Type, 128);
+ return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg0Type, 128) ||
+ CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg1Type, 128);
}
case RISCVVector::BI__builtin_rvv_vsha2ch_vv:
case RISCVVector::BI__builtin_rvv_vsha2cl_vv:
@@ -768,17 +776,18 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
ASTContext::BuiltinVectorTypeInfo Info =
Context.getBuiltinVectorTypeInfo(Arg0Type->castAs<BuiltinType>());
uint64_t ElemSize = Context.getTypeSize(Info.ElementType);
- if (ElemSize == 64 && !TI.hasFeature("zvknhb"))
+ if (ElemSize == 64 && !TI.hasFeature("zvknhb") &&
+ !FunctionFeatureMap.lookup("zvknhb"))
return Diag(TheCall->getBeginLoc(),
diag::err_riscv_builtin_requires_extension)
<< /* IsExtension */ true << TheCall->getSourceRange() << "zvknhb";
- return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type,
- ElemSize * 4) ||
- CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg1Type,
- ElemSize * 4) ||
- CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg2Type,
- ElemSize * 4);
+ return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg0Type, ElemSize * 4) ||
+ CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg1Type, ElemSize * 4) ||
+ CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg2Type, ElemSize * 4);
}
case RISCVVector::BI__builtin_rvv_sf_vc_i_se:
diff --git a/clang/test/Sema/zvk-target-attributes.c b/clang/test/Sema/zvk-target-attributes.c
new file mode 100644
index 0000000000000..dad2e5b16ac87
--- /dev/null
+++ b/clang/test/Sema/zvk-target-attributes.c
@@ -0,0 +1,11 @@
+// REQUIRES: riscv-registered-target
+// RUN: %clang_cc1 -triple riscv64 -target-feature +zvknha %s -fsyntax-only -verify
+
+#include <riscv_vector.h>
+
+// expected-no-diagnostics
+
+__attribute__((target("arch=+zvl128b")))
+void test_zvk_features(vuint32m1_t vd, vuint32m1_t vs2, vuint32m1_t vs1, size_t vl) {
+ __riscv_vsha2ch_vv_u32m1(vd, vs2, vs1, vl);
+}
|
@llvm/pr-subscribers-clang Author: Brandon Wu (4vtomat) ChangesCurrently we only check the required features passed by command line arguments. Full diff: https://github.com/llvm/llvm-project/pull/141548.diff 2 Files Affected:
diff --git a/clang/lib/Sema/SemaRISCV.cpp b/clang/lib/Sema/SemaRISCV.cpp
index 481bf8bd22cc1..ca8d849b40a2a 100644
--- a/clang/lib/Sema/SemaRISCV.cpp
+++ b/clang/lib/Sema/SemaRISCV.cpp
@@ -544,8 +544,10 @@ bool SemaRISCV::CheckLMUL(CallExpr *TheCall, unsigned ArgNum) {
<< Arg->getSourceRange();
}
-static bool CheckInvalidVLENandLMUL(const TargetInfo &TI, CallExpr *TheCall,
- Sema &S, QualType Type, int EGW) {
+static bool CheckInvalidVLENandLMUL(const TargetInfo &TI,
+ llvm::StringMap<bool> &FunctionFeatureMap,
+ CallExpr *TheCall, Sema &S, QualType Type,
+ int EGW) {
assert((EGW == 128 || EGW == 256) && "EGW can only be 128 or 256 bits");
// LMUL * VLEN >= EGW
@@ -566,7 +568,7 @@ static bool CheckInvalidVLENandLMUL(const TargetInfo &TI, CallExpr *TheCall,
// Vscale is VLEN/RVVBitsPerBlock.
unsigned MinRequiredVLEN = VScaleFactor * llvm::RISCV::RVVBitsPerBlock;
std::string RequiredExt = "zvl" + std::to_string(MinRequiredVLEN) + "b";
- if (!TI.hasFeature(RequiredExt))
+ if (!TI.hasFeature(RequiredExt) && !FunctionFeatureMap.lookup(RequiredExt))
return S.Diag(TheCall->getBeginLoc(),
diag::err_riscv_type_requires_extension)
<< Type << RequiredExt;
@@ -578,6 +580,10 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
unsigned BuiltinID,
CallExpr *TheCall) {
ASTContext &Context = getASTContext();
+ const FunctionDecl *FD = SemaRef.getCurFunctionDecl();
+ llvm::StringMap<bool> FunctionFeatureMap;
+ Context.getFunctionFeatureMap(FunctionFeatureMap, FD);
+
// vmulh.vv, vmulh.vx, vmulhu.vv, vmulhu.vx, vmulhsu.vv, vmulhsu.vx,
// vsmul.vv, vsmul.vx are not included for EEW=64 in Zve64*.
switch (BuiltinID) {
@@ -634,10 +640,6 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
ASTContext::BuiltinVectorTypeInfo Info = Context.getBuiltinVectorTypeInfo(
TheCall->getType()->castAs<BuiltinType>());
- const FunctionDecl *FD = SemaRef.getCurFunctionDecl();
- llvm::StringMap<bool> FunctionFeatureMap;
- Context.getFunctionFeatureMap(FunctionFeatureMap, FD);
-
if (Context.getTypeSize(Info.ElementType) == 64 && !TI.hasFeature("v") &&
!FunctionFeatureMap.lookup("v"))
return Diag(TheCall->getBeginLoc(),
@@ -713,20 +715,24 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
case RISCVVector::BI__builtin_rvv_vsm4k_vi_tu: {
QualType Arg0Type = TheCall->getArg(0)->getType();
QualType Arg1Type = TheCall->getArg(1)->getType();
- return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type, 128) ||
- CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg1Type, 128) ||
+ return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg0Type, 128) ||
+ CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg1Type, 128) ||
SemaRef.BuiltinConstantArgRange(TheCall, 2, 0, 31);
}
case RISCVVector::BI__builtin_rvv_vsm3c_vi_tu:
case RISCVVector::BI__builtin_rvv_vsm3c_vi: {
QualType Arg0Type = TheCall->getArg(0)->getType();
- return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type, 256) ||
+ return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg0Type, 256) ||
SemaRef.BuiltinConstantArgRange(TheCall, 2, 0, 31);
}
case RISCVVector::BI__builtin_rvv_vaeskf1_vi:
case RISCVVector::BI__builtin_rvv_vsm4k_vi: {
QualType Arg0Type = TheCall->getArg(0)->getType();
- return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type, 128) ||
+ return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg0Type, 128) ||
SemaRef.BuiltinConstantArgRange(TheCall, 1, 0, 31);
}
case RISCVVector::BI__builtin_rvv_vaesdf_vv:
@@ -753,8 +759,10 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
case RISCVVector::BI__builtin_rvv_vsm4r_vs_tu: {
QualType Arg0Type = TheCall->getArg(0)->getType();
QualType Arg1Type = TheCall->getArg(1)->getType();
- return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type, 128) ||
- CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg1Type, 128);
+ return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg0Type, 128) ||
+ CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg1Type, 128);
}
case RISCVVector::BI__builtin_rvv_vsha2ch_vv:
case RISCVVector::BI__builtin_rvv_vsha2cl_vv:
@@ -768,17 +776,18 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
ASTContext::BuiltinVectorTypeInfo Info =
Context.getBuiltinVectorTypeInfo(Arg0Type->castAs<BuiltinType>());
uint64_t ElemSize = Context.getTypeSize(Info.ElementType);
- if (ElemSize == 64 && !TI.hasFeature("zvknhb"))
+ if (ElemSize == 64 && !TI.hasFeature("zvknhb") &&
+ !FunctionFeatureMap.lookup("zvknhb"))
return Diag(TheCall->getBeginLoc(),
diag::err_riscv_builtin_requires_extension)
<< /* IsExtension */ true << TheCall->getSourceRange() << "zvknhb";
- return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type,
- ElemSize * 4) ||
- CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg1Type,
- ElemSize * 4) ||
- CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg2Type,
- ElemSize * 4);
+ return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg0Type, ElemSize * 4) ||
+ CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg1Type, ElemSize * 4) ||
+ CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+ Arg2Type, ElemSize * 4);
}
case RISCVVector::BI__builtin_rvv_sf_vc_i_se:
diff --git a/clang/test/Sema/zvk-target-attributes.c b/clang/test/Sema/zvk-target-attributes.c
new file mode 100644
index 0000000000000..dad2e5b16ac87
--- /dev/null
+++ b/clang/test/Sema/zvk-target-attributes.c
@@ -0,0 +1,11 @@
+// REQUIRES: riscv-registered-target
+// RUN: %clang_cc1 -triple riscv64 -target-feature +zvknha %s -fsyntax-only -verify
+
+#include <riscv_vector.h>
+
+// expected-no-diagnostics
+
+__attribute__((target("arch=+zvl128b")))
+void test_zvk_features(vuint32m1_t vd, vuint32m1_t vs2, vuint32m1_t vs1, size_t vl) {
+ __riscv_vsha2ch_vv_u32m1(vd, vs2, vs1, vl);
+}
|
// expected-no-diagnostics | ||
|
||
__attribute__((target("arch=+zvl128b"))) | ||
void test_zvk_features(vuint32m1_t vd, vuint32m1_t vs2, vuint32m1_t vs1, size_t vl) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do we get if without this PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'll have following error: RISC-V type 'vuint32m1_t' (aka '__rvv_uint32m1_t') requires the 'zvl128b' extension
That's the extra check for vector crypto extensions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Currently we only check the required features passed by command line arguments.
We also need to check the features passed by using target features.