52
52
#include " llvm/Support/Debug.h"
53
53
54
54
#include < algorithm>
55
+ #include < regex>
55
56
#include < set>
56
57
57
58
using namespace llvm ;
@@ -724,6 +725,13 @@ void OCLToSPIRVBase::visitCallBarrier(CallInst *CI) {
724
725
725
726
void OCLToSPIRVBase::visitCallConvert (CallInst *CI, StringRef MangledName,
726
727
StringRef DemangledName) {
728
+ // OpenCL Explicit Conversions (6.4.3) formed as below for scalars:
729
+ // destType convert_destType<_sat><_roundingMode>(sourceType)
730
+ // and for vector type:
731
+ // destTypeN convert_destTypeN<_sat><_roundingMode>(sourceTypeN)
732
+ // If the demangled name is not matching the suggested pattern and does not
733
+ // meet allowed destination type restrictions - this is not an OpenCL builtin,
734
+ // return from the function and translate such CallInst as a function call.
727
735
if (eraseUselessConvert (CI, MangledName, DemangledName))
728
736
return ;
729
737
Op OC = OpNop;
@@ -734,16 +742,56 @@ void OCLToSPIRVBase::visitCallConvert(CallInst *CI, StringRef MangledName,
734
742
if (auto *VecTy = dyn_cast<VectorType>(SrcTy))
735
743
SrcTy = VecTy->getElementType ();
736
744
auto IsTargetInt = isa<IntegerType>(TargetTy);
745
+ auto TargetSigned = DemangledName[8 ] != ' u' ;
737
746
738
747
std::string TargetTyName (
739
748
DemangledName.substr (strlen (kOCLBuiltinName ::ConvertPrefix)));
740
749
auto FirstUnderscoreLoc = TargetTyName.find (' _' );
741
750
if (FirstUnderscoreLoc != std::string::npos)
742
751
TargetTyName = TargetTyName.substr (0 , FirstUnderscoreLoc);
752
+
753
+ // Validate target type name
754
+ std::regex Expr (" ([a-z]+)([0-9]*)$" );
755
+ std::smatch DestTyMatch;
756
+ if (!std::regex_match (TargetTyName, DestTyMatch, Expr))
757
+ return ;
758
+
759
+ // The first sub_match is the whole string; the next
760
+ // sub_match is the first parenthesized expression.
761
+ std::string DestTy = DestTyMatch[1 ].str ();
762
+
763
+ // check it's valid type name
764
+ static std::unordered_set<std::string> ValidTypes = {
765
+ " float" , " double" , " half" , " char" , " uchar" , " short" ,
766
+ " ushort" , " int" , " uint" , " long" , " ulong" };
767
+
768
+ if (ValidTypes.find (DestTy) == ValidTypes.end ())
769
+ return ;
770
+
771
+ // check that it's allowed vector size
772
+ std::string VecSize = DestTyMatch[2 ].str ();
773
+ if (!VecSize.empty ()) {
774
+ int Size = stoi (VecSize);
775
+ switch (Size) {
776
+ case 2 :
777
+ case 3 :
778
+ case 4 :
779
+ case 8 :
780
+ case 16 :
781
+ break ;
782
+ default :
783
+ return ;
784
+ }
785
+ }
786
+ DemangledName = DemangledName.drop_front (
787
+ strlen (kOCLBuiltinName ::ConvertPrefix) + TargetTyName.size ());
743
788
TargetTyName = std::string (" _R" ) + TargetTyName;
744
789
790
+ if (!DemangledName.empty () && !DemangledName.starts_with (" _sat" ) &&
791
+ !DemangledName.starts_with (" _rt" ))
792
+ return ;
793
+
745
794
std::string Sat = DemangledName.find (" _sat" ) != StringRef::npos ? " _sat" : " " ;
746
- auto TargetSigned = DemangledName[8 ] != ' u' ;
747
795
if (isa<IntegerType>(SrcTy)) {
748
796
bool Signed = isLastFuncParamSigned (MangledName);
749
797
if (IsTargetInt) {
0 commit comments