Skip to content

[clang][RISCV] Handle target features correctly in CheckBuiltinFunctionCall #141548

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 1, 2025

Conversation

4vtomat
Copy link
Member

@4vtomat 4vtomat commented May 27, 2025

Currently we only check the required features passed by command line arguments.
We also need to check the features passed by using target features.

…onCall

Currently we only check the required features passed by command line arguments.
We also need to check the features passed by using target features.
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:RISC-V clang:frontend Language frontend issues, e.g. anything involving "Sema" labels May 27, 2025
@llvmbot
Copy link
Member

llvmbot commented May 27, 2025

@llvm/pr-subscribers-backend-risc-v

Author: Brandon Wu (4vtomat)

Changes

Currently we only check the required features passed by command line arguments.
We also need to check the features passed by using target features.


Full diff: https://github.com/llvm/llvm-project/pull/141548.diff

2 Files Affected:

  • (modified) clang/lib/Sema/SemaRISCV.cpp (+29-20)
  • (added) clang/test/Sema/zvk-target-attributes.c (+11)
diff --git a/clang/lib/Sema/SemaRISCV.cpp b/clang/lib/Sema/SemaRISCV.cpp
index 481bf8bd22cc1..ca8d849b40a2a 100644
--- a/clang/lib/Sema/SemaRISCV.cpp
+++ b/clang/lib/Sema/SemaRISCV.cpp
@@ -544,8 +544,10 @@ bool SemaRISCV::CheckLMUL(CallExpr *TheCall, unsigned ArgNum) {
          << Arg->getSourceRange();
 }
 
-static bool CheckInvalidVLENandLMUL(const TargetInfo &TI, CallExpr *TheCall,
-                                    Sema &S, QualType Type, int EGW) {
+static bool CheckInvalidVLENandLMUL(const TargetInfo &TI,
+                                    llvm::StringMap<bool> &FunctionFeatureMap,
+                                    CallExpr *TheCall, Sema &S, QualType Type,
+                                    int EGW) {
   assert((EGW == 128 || EGW == 256) && "EGW can only be 128 or 256 bits");
 
   // LMUL * VLEN >= EGW
@@ -566,7 +568,7 @@ static bool CheckInvalidVLENandLMUL(const TargetInfo &TI, CallExpr *TheCall,
   // Vscale is VLEN/RVVBitsPerBlock.
   unsigned MinRequiredVLEN = VScaleFactor * llvm::RISCV::RVVBitsPerBlock;
   std::string RequiredExt = "zvl" + std::to_string(MinRequiredVLEN) + "b";
-  if (!TI.hasFeature(RequiredExt))
+  if (!TI.hasFeature(RequiredExt) && !FunctionFeatureMap.lookup(RequiredExt))
     return S.Diag(TheCall->getBeginLoc(),
                   diag::err_riscv_type_requires_extension)
            << Type << RequiredExt;
@@ -578,6 +580,10 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
                                          unsigned BuiltinID,
                                          CallExpr *TheCall) {
   ASTContext &Context = getASTContext();
+  const FunctionDecl *FD = SemaRef.getCurFunctionDecl();
+  llvm::StringMap<bool> FunctionFeatureMap;
+  Context.getFunctionFeatureMap(FunctionFeatureMap, FD);
+
   // vmulh.vv, vmulh.vx, vmulhu.vv, vmulhu.vx, vmulhsu.vv, vmulhsu.vx,
   // vsmul.vv, vsmul.vx are not included for EEW=64 in Zve64*.
   switch (BuiltinID) {
@@ -634,10 +640,6 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
     ASTContext::BuiltinVectorTypeInfo Info = Context.getBuiltinVectorTypeInfo(
         TheCall->getType()->castAs<BuiltinType>());
 
-    const FunctionDecl *FD = SemaRef.getCurFunctionDecl();
-    llvm::StringMap<bool> FunctionFeatureMap;
-    Context.getFunctionFeatureMap(FunctionFeatureMap, FD);
-
     if (Context.getTypeSize(Info.ElementType) == 64 && !TI.hasFeature("v") &&
         !FunctionFeatureMap.lookup("v"))
       return Diag(TheCall->getBeginLoc(),
@@ -713,20 +715,24 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
   case RISCVVector::BI__builtin_rvv_vsm4k_vi_tu: {
     QualType Arg0Type = TheCall->getArg(0)->getType();
     QualType Arg1Type = TheCall->getArg(1)->getType();
-    return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type, 128) ||
-           CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg1Type, 128) ||
+    return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg0Type, 128) ||
+           CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg1Type, 128) ||
            SemaRef.BuiltinConstantArgRange(TheCall, 2, 0, 31);
   }
   case RISCVVector::BI__builtin_rvv_vsm3c_vi_tu:
   case RISCVVector::BI__builtin_rvv_vsm3c_vi: {
     QualType Arg0Type = TheCall->getArg(0)->getType();
-    return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type, 256) ||
+    return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg0Type, 256) ||
            SemaRef.BuiltinConstantArgRange(TheCall, 2, 0, 31);
   }
   case RISCVVector::BI__builtin_rvv_vaeskf1_vi:
   case RISCVVector::BI__builtin_rvv_vsm4k_vi: {
     QualType Arg0Type = TheCall->getArg(0)->getType();
-    return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type, 128) ||
+    return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg0Type, 128) ||
            SemaRef.BuiltinConstantArgRange(TheCall, 1, 0, 31);
   }
   case RISCVVector::BI__builtin_rvv_vaesdf_vv:
@@ -753,8 +759,10 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
   case RISCVVector::BI__builtin_rvv_vsm4r_vs_tu: {
     QualType Arg0Type = TheCall->getArg(0)->getType();
     QualType Arg1Type = TheCall->getArg(1)->getType();
-    return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type, 128) ||
-           CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg1Type, 128);
+    return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg0Type, 128) ||
+           CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg1Type, 128);
   }
   case RISCVVector::BI__builtin_rvv_vsha2ch_vv:
   case RISCVVector::BI__builtin_rvv_vsha2cl_vv:
@@ -768,17 +776,18 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
     ASTContext::BuiltinVectorTypeInfo Info =
         Context.getBuiltinVectorTypeInfo(Arg0Type->castAs<BuiltinType>());
     uint64_t ElemSize = Context.getTypeSize(Info.ElementType);
-    if (ElemSize == 64 && !TI.hasFeature("zvknhb"))
+    if (ElemSize == 64 && !TI.hasFeature("zvknhb") &&
+        !FunctionFeatureMap.lookup("zvknhb"))
       return Diag(TheCall->getBeginLoc(),
                   diag::err_riscv_builtin_requires_extension)
              << /* IsExtension */ true << TheCall->getSourceRange() << "zvknhb";
 
-    return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type,
-                                   ElemSize * 4) ||
-           CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg1Type,
-                                   ElemSize * 4) ||
-           CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg2Type,
-                                   ElemSize * 4);
+    return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg0Type, ElemSize * 4) ||
+           CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg1Type, ElemSize * 4) ||
+           CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg2Type, ElemSize * 4);
   }
 
   case RISCVVector::BI__builtin_rvv_sf_vc_i_se:
diff --git a/clang/test/Sema/zvk-target-attributes.c b/clang/test/Sema/zvk-target-attributes.c
new file mode 100644
index 0000000000000..dad2e5b16ac87
--- /dev/null
+++ b/clang/test/Sema/zvk-target-attributes.c
@@ -0,0 +1,11 @@
+// REQUIRES: riscv-registered-target
+// RUN: %clang_cc1 -triple riscv64 -target-feature +zvknha %s -fsyntax-only -verify
+
+#include <riscv_vector.h>
+
+// expected-no-diagnostics
+
+__attribute__((target("arch=+zvl128b")))
+void test_zvk_features(vuint32m1_t vd, vuint32m1_t vs2, vuint32m1_t vs1, size_t vl) {
+  __riscv_vsha2ch_vv_u32m1(vd, vs2, vs1, vl);
+}

@llvmbot
Copy link
Member

llvmbot commented May 27, 2025

@llvm/pr-subscribers-clang

Author: Brandon Wu (4vtomat)

Changes

Currently we only check the required features passed by command line arguments.
We also need to check the features passed by using target features.


Full diff: https://github.com/llvm/llvm-project/pull/141548.diff

2 Files Affected:

  • (modified) clang/lib/Sema/SemaRISCV.cpp (+29-20)
  • (added) clang/test/Sema/zvk-target-attributes.c (+11)
diff --git a/clang/lib/Sema/SemaRISCV.cpp b/clang/lib/Sema/SemaRISCV.cpp
index 481bf8bd22cc1..ca8d849b40a2a 100644
--- a/clang/lib/Sema/SemaRISCV.cpp
+++ b/clang/lib/Sema/SemaRISCV.cpp
@@ -544,8 +544,10 @@ bool SemaRISCV::CheckLMUL(CallExpr *TheCall, unsigned ArgNum) {
          << Arg->getSourceRange();
 }
 
-static bool CheckInvalidVLENandLMUL(const TargetInfo &TI, CallExpr *TheCall,
-                                    Sema &S, QualType Type, int EGW) {
+static bool CheckInvalidVLENandLMUL(const TargetInfo &TI,
+                                    llvm::StringMap<bool> &FunctionFeatureMap,
+                                    CallExpr *TheCall, Sema &S, QualType Type,
+                                    int EGW) {
   assert((EGW == 128 || EGW == 256) && "EGW can only be 128 or 256 bits");
 
   // LMUL * VLEN >= EGW
@@ -566,7 +568,7 @@ static bool CheckInvalidVLENandLMUL(const TargetInfo &TI, CallExpr *TheCall,
   // Vscale is VLEN/RVVBitsPerBlock.
   unsigned MinRequiredVLEN = VScaleFactor * llvm::RISCV::RVVBitsPerBlock;
   std::string RequiredExt = "zvl" + std::to_string(MinRequiredVLEN) + "b";
-  if (!TI.hasFeature(RequiredExt))
+  if (!TI.hasFeature(RequiredExt) && !FunctionFeatureMap.lookup(RequiredExt))
     return S.Diag(TheCall->getBeginLoc(),
                   diag::err_riscv_type_requires_extension)
            << Type << RequiredExt;
@@ -578,6 +580,10 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
                                          unsigned BuiltinID,
                                          CallExpr *TheCall) {
   ASTContext &Context = getASTContext();
+  const FunctionDecl *FD = SemaRef.getCurFunctionDecl();
+  llvm::StringMap<bool> FunctionFeatureMap;
+  Context.getFunctionFeatureMap(FunctionFeatureMap, FD);
+
   // vmulh.vv, vmulh.vx, vmulhu.vv, vmulhu.vx, vmulhsu.vv, vmulhsu.vx,
   // vsmul.vv, vsmul.vx are not included for EEW=64 in Zve64*.
   switch (BuiltinID) {
@@ -634,10 +640,6 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
     ASTContext::BuiltinVectorTypeInfo Info = Context.getBuiltinVectorTypeInfo(
         TheCall->getType()->castAs<BuiltinType>());
 
-    const FunctionDecl *FD = SemaRef.getCurFunctionDecl();
-    llvm::StringMap<bool> FunctionFeatureMap;
-    Context.getFunctionFeatureMap(FunctionFeatureMap, FD);
-
     if (Context.getTypeSize(Info.ElementType) == 64 && !TI.hasFeature("v") &&
         !FunctionFeatureMap.lookup("v"))
       return Diag(TheCall->getBeginLoc(),
@@ -713,20 +715,24 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
   case RISCVVector::BI__builtin_rvv_vsm4k_vi_tu: {
     QualType Arg0Type = TheCall->getArg(0)->getType();
     QualType Arg1Type = TheCall->getArg(1)->getType();
-    return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type, 128) ||
-           CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg1Type, 128) ||
+    return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg0Type, 128) ||
+           CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg1Type, 128) ||
            SemaRef.BuiltinConstantArgRange(TheCall, 2, 0, 31);
   }
   case RISCVVector::BI__builtin_rvv_vsm3c_vi_tu:
   case RISCVVector::BI__builtin_rvv_vsm3c_vi: {
     QualType Arg0Type = TheCall->getArg(0)->getType();
-    return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type, 256) ||
+    return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg0Type, 256) ||
            SemaRef.BuiltinConstantArgRange(TheCall, 2, 0, 31);
   }
   case RISCVVector::BI__builtin_rvv_vaeskf1_vi:
   case RISCVVector::BI__builtin_rvv_vsm4k_vi: {
     QualType Arg0Type = TheCall->getArg(0)->getType();
-    return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type, 128) ||
+    return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg0Type, 128) ||
            SemaRef.BuiltinConstantArgRange(TheCall, 1, 0, 31);
   }
   case RISCVVector::BI__builtin_rvv_vaesdf_vv:
@@ -753,8 +759,10 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
   case RISCVVector::BI__builtin_rvv_vsm4r_vs_tu: {
     QualType Arg0Type = TheCall->getArg(0)->getType();
     QualType Arg1Type = TheCall->getArg(1)->getType();
-    return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type, 128) ||
-           CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg1Type, 128);
+    return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg0Type, 128) ||
+           CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg1Type, 128);
   }
   case RISCVVector::BI__builtin_rvv_vsha2ch_vv:
   case RISCVVector::BI__builtin_rvv_vsha2cl_vv:
@@ -768,17 +776,18 @@ bool SemaRISCV::CheckBuiltinFunctionCall(const TargetInfo &TI,
     ASTContext::BuiltinVectorTypeInfo Info =
         Context.getBuiltinVectorTypeInfo(Arg0Type->castAs<BuiltinType>());
     uint64_t ElemSize = Context.getTypeSize(Info.ElementType);
-    if (ElemSize == 64 && !TI.hasFeature("zvknhb"))
+    if (ElemSize == 64 && !TI.hasFeature("zvknhb") &&
+        !FunctionFeatureMap.lookup("zvknhb"))
       return Diag(TheCall->getBeginLoc(),
                   diag::err_riscv_builtin_requires_extension)
              << /* IsExtension */ true << TheCall->getSourceRange() << "zvknhb";
 
-    return CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg0Type,
-                                   ElemSize * 4) ||
-           CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg1Type,
-                                   ElemSize * 4) ||
-           CheckInvalidVLENandLMUL(TI, TheCall, SemaRef, Arg2Type,
-                                   ElemSize * 4);
+    return CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg0Type, ElemSize * 4) ||
+           CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg1Type, ElemSize * 4) ||
+           CheckInvalidVLENandLMUL(TI, FunctionFeatureMap, TheCall, SemaRef,
+                                   Arg2Type, ElemSize * 4);
   }
 
   case RISCVVector::BI__builtin_rvv_sf_vc_i_se:
diff --git a/clang/test/Sema/zvk-target-attributes.c b/clang/test/Sema/zvk-target-attributes.c
new file mode 100644
index 0000000000000..dad2e5b16ac87
--- /dev/null
+++ b/clang/test/Sema/zvk-target-attributes.c
@@ -0,0 +1,11 @@
+// REQUIRES: riscv-registered-target
+// RUN: %clang_cc1 -triple riscv64 -target-feature +zvknha %s -fsyntax-only -verify
+
+#include <riscv_vector.h>
+
+// expected-no-diagnostics
+
+__attribute__((target("arch=+zvl128b")))
+void test_zvk_features(vuint32m1_t vd, vuint32m1_t vs2, vuint32m1_t vs1, size_t vl) {
+  __riscv_vsha2ch_vv_u32m1(vd, vs2, vs1, vl);
+}

// expected-no-diagnostics

__attribute__((target("arch=+zvl128b")))
void test_zvk_features(vuint32m1_t vd, vuint32m1_t vs2, vuint32m1_t vs1, size_t vl) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do we get if without this PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll have following error: RISC-V type 'vuint32m1_t' (aka '__rvv_uint32m1_t') requires the 'zvl128b' extension
That's the extra check for vector crypto extensions

Copy link
Contributor

@wangpc-pp wangpc-pp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@4vtomat 4vtomat merged commit 88aa5cb into llvm:main Jun 1, 2025
15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:RISC-V clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants