Skip to content

Commit 848a6d7

Browse files
authored
Fix handling of OpenCL convert_ builtins (#2443)
This change fixes cases when the input function name does not match OpenCL spec but starts with `convert_` prefix, e.g. `convert_float_42`.
1 parent af06f03 commit 848a6d7

File tree

2 files changed

+104
-1
lines changed

2 files changed

+104
-1
lines changed

lib/SPIRV/OCLToSPIRV.cpp

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include "llvm/Support/Debug.h"
5353

5454
#include <algorithm>
55+
#include <regex>
5556
#include <set>
5657

5758
using namespace llvm;
@@ -724,6 +725,13 @@ void OCLToSPIRVBase::visitCallBarrier(CallInst *CI) {
724725

725726
void OCLToSPIRVBase::visitCallConvert(CallInst *CI, StringRef MangledName,
726727
StringRef DemangledName) {
728+
// OpenCL Explicit Conversions (6.4.3) formed as below for scalars:
729+
// destType convert_destType<_sat><_roundingMode>(sourceType)
730+
// and for vector type:
731+
// destTypeN convert_destTypeN<_sat><_roundingMode>(sourceTypeN)
732+
// If the demangled name is not matching the suggested pattern and does not
733+
// meet allowed destination type restrictions - this is not an OpenCL builtin,
734+
// return from the function and translate such CallInst as a function call.
727735
if (eraseUselessConvert(CI, MangledName, DemangledName))
728736
return;
729737
Op OC = OpNop;
@@ -734,16 +742,56 @@ void OCLToSPIRVBase::visitCallConvert(CallInst *CI, StringRef MangledName,
734742
if (auto *VecTy = dyn_cast<VectorType>(SrcTy))
735743
SrcTy = VecTy->getElementType();
736744
auto IsTargetInt = isa<IntegerType>(TargetTy);
745+
auto TargetSigned = DemangledName[8] != 'u';
737746

738747
std::string TargetTyName(
739748
DemangledName.substr(strlen(kOCLBuiltinName::ConvertPrefix)));
740749
auto FirstUnderscoreLoc = TargetTyName.find('_');
741750
if (FirstUnderscoreLoc != std::string::npos)
742751
TargetTyName = TargetTyName.substr(0, FirstUnderscoreLoc);
752+
753+
// Validate target type name
754+
std::regex Expr("([a-z]+)([0-9]*)$");
755+
std::smatch DestTyMatch;
756+
if (!std::regex_match(TargetTyName, DestTyMatch, Expr))
757+
return;
758+
759+
// The first sub_match is the whole string; the next
760+
// sub_match is the first parenthesized expression.
761+
std::string DestTy = DestTyMatch[1].str();
762+
763+
// check it's valid type name
764+
static std::unordered_set<std::string> ValidTypes = {
765+
"float", "double", "half", "char", "uchar", "short",
766+
"ushort", "int", "uint", "long", "ulong"};
767+
768+
if (ValidTypes.find(DestTy) == ValidTypes.end())
769+
return;
770+
771+
// check that it's allowed vector size
772+
std::string VecSize = DestTyMatch[2].str();
773+
if (!VecSize.empty()) {
774+
int Size = stoi(VecSize);
775+
switch (Size) {
776+
case 2:
777+
case 3:
778+
case 4:
779+
case 8:
780+
case 16:
781+
break;
782+
default:
783+
return;
784+
}
785+
}
786+
DemangledName = DemangledName.drop_front(
787+
strlen(kOCLBuiltinName::ConvertPrefix) + TargetTyName.size());
743788
TargetTyName = std::string("_R") + TargetTyName;
744789

790+
if (!DemangledName.empty() && !DemangledName.starts_with("_sat") &&
791+
!DemangledName.starts_with("_rt"))
792+
return;
793+
745794
std::string Sat = DemangledName.find("_sat") != StringRef::npos ? "_sat" : "";
746-
auto TargetSigned = DemangledName[8] != 'u';
747795
if (isa<IntegerType>(SrcTy)) {
748796
bool Signed = isLastFuncParamSigned(MangledName);
749797
if (IsTargetInt) {
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
; This test checks that functions with `convert_` prefix are translated as
2+
; OpenCL builtins only in case they match the specification. Otherwise, we
3+
; expect such functions to be translated to SPIR-V FunctionCall.
4+
5+
; RUN: llvm-as %s -o %t.bc
6+
; RUN: llvm-spirv %t.bc -o %t.spv
7+
; RUN: spirv-val %t.spv
8+
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
9+
; RUN: FileCheck < %t.spt %s -check-prefix=CHECK-SPIRV
10+
; RUN: llvm-spirv %t.spv -r -o - | llvm-dis -o %t.rev.ll
11+
; RUN: FileCheck < %t.rev.ll %s -check-prefix=CHECK-LLVM
12+
13+
; CHECK-SPIRV: Name [[#Func:]] "_Z18convert_float_func"
14+
; CHECK-SPIRV: TypeVoid [[#VoidTy:]]
15+
; CHECK-SPIRV: TypeFloat [[#FloatTy:]] 32
16+
17+
; CHECK-SPIRV: Function [[#VoidTy]] [[#Func]]
18+
; CHECK-SPIRV: ConvertSToF [[#FloatTy]] [[#ConvertId:]] [[#]]
19+
; CHECK-SPIRV: FunctionCall [[#VoidTy]] [[#]] [[#Func]] [[#ConvertId]]
20+
; CHECK-SPIRV-NOT: FConvert
21+
22+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
23+
target triple = "spir"
24+
25+
; Function Attrs: convergent noinline norecurse nounwind optnone
26+
define dso_local spir_func void @_Z18convert_float_func(float noundef %x) #0 {
27+
entry:
28+
%x.addr = alloca float, align 4
29+
store float %x, ptr %x.addr, align 4
30+
ret void
31+
}
32+
33+
; Function Attrs: convergent noinline norecurse nounwind optnone
34+
define dso_local spir_func void @convert_int_bf16(i32 noundef %x) #0 {
35+
entry:
36+
%x.addr = alloca i32, align 4
37+
store i32 %x, ptr %x.addr, align 4
38+
%0 = load i32, ptr %x.addr, align 4
39+
; CHECK-LLVM: %[[Call:[a-z]+]] = sitofp i32 %[[#]] to float
40+
%call = call spir_func float @_Z13convert_floati(i32 noundef %0) #1
41+
; CHECK-LLVM: call spir_func void @_Z18convert_float_func(float %[[Call]])
42+
call spir_func void @_Z18convert_float_func(float noundef %call) #0
43+
ret void
44+
}
45+
46+
; Function Attrs: convergent nounwind willreturn memory(none)
47+
declare spir_func float @_Z13convert_floati(i32 noundef) #1
48+
49+
attributes #0 = { convergent nounwind }
50+
attributes #1 = { convergent nounwind willreturn memory(none) }
51+
52+
!opencl.ocl.version = !{!0}
53+
!opencl.spir.version = !{!0}
54+
55+
!0 = !{i32 3, i32 0}

0 commit comments

Comments
 (0)