Skip to content

Commit d5a7f20

Browse files
authored
[SYCL] Fix specialization constants struct members (#2232)
* [SYCL] Fix specialization constants struct members (#2232) FE crashed on attempt to create initializer for struct with spec constant members because there was no initializers for spec const fields. Added default initialization for spec constants.
1 parent 0b8947a commit d5a7f20

File tree

4 files changed

+94
-20
lines changed

4 files changed

+94
-20
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1455,13 +1455,8 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
14551455
InitExprs.push_back(ILE);
14561456
}
14571457

1458-
void createSpecialMethodCall(const CXXRecordDecl *SpecialClass, Expr *Base,
1459-
const std::string &MethodName,
1460-
FieldDecl *Field) {
1461-
CXXMethodDecl *Method = getMethodByName(SpecialClass, MethodName);
1462-
assert(Method &&
1463-
"The accessor/sampler/stream must have the __init method. Stream"
1464-
" must also have __finalize method");
1458+
CXXMemberCallExpr *createSpecialMethodCall(Expr *Base, CXXMethodDecl *Method,
1459+
FieldDecl *Field) {
14651460
unsigned NumParams = Method->getNumParams();
14661461
llvm::SmallVector<Expr *, 4> ParamDREs(NumParams);
14671462
llvm::ArrayRef<ParmVarDecl *> KernelParameters =
@@ -1485,10 +1480,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
14851480
CXXMemberCallExpr *Call = CXXMemberCallExpr::Create(
14861481
SemaRef.Context, MethodME, ParamStmts, ResultTy, VK, SourceLocation(),
14871482
FPOptionsOverride());
1488-
if (MethodName == FinalizeMethodName)
1489-
FinalizeStmts.push_back(Call);
1490-
else
1491-
BodyStmts.push_back(Call);
1483+
return Call;
14921484
}
14931485

14941486
// FIXME Avoid creation of kernel obj clone.
@@ -1517,8 +1509,12 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
15171509
ExprResult MemberInit = InitSeq.Perform(SemaRef, Entity, InitKind, None);
15181510
InitExprs.push_back(MemberInit.get());
15191511

1520-
createSpecialMethodCall(RecordDecl, MemberExprBases.back(), InitMethodName,
1521-
FD);
1512+
CXXMethodDecl *InitMethod = getMethodByName(RecordDecl, InitMethodName);
1513+
if (InitMethod) {
1514+
CXXMemberCallExpr *InitCall =
1515+
createSpecialMethodCall(MemberExprBases.back(), InitMethod, FD);
1516+
BodyStmts.push_back(InitCall);
1517+
}
15221518
return true;
15231519
}
15241520

@@ -1535,8 +1531,12 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
15351531
ExprResult MemberInit = InitSeq.Perform(SemaRef, Entity, InitKind, None);
15361532
InitExprs.push_back(MemberInit.get());
15371533

1538-
createSpecialMethodCall(RecordDecl, MemberExprBases.back(), InitMethodName,
1539-
nullptr);
1534+
CXXMethodDecl *InitMethod = getMethodByName(RecordDecl, InitMethodName);
1535+
if (InitMethod) {
1536+
CXXMemberCallExpr *InitCall =
1537+
createSpecialMethodCall(MemberExprBases.back(), InitMethod, nullptr);
1538+
BodyStmts.push_back(InitCall);
1539+
}
15401540
return true;
15411541
}
15421542

@@ -1578,14 +1578,27 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
15781578
return handleSpecialType(FD, Ty);
15791579
}
15801580

1581+
bool handleSyclSpecConstantType(FieldDecl *FD, QualType Ty) final {
1582+
return handleSpecialType(FD, Ty);
1583+
}
1584+
15811585
bool handleSyclStreamType(FieldDecl *FD, QualType Ty) final {
15821586
const auto *StreamDecl = Ty->getAsCXXRecordDecl();
15831587
createExprForStructOrScalar(FD);
15841588
size_t NumBases = MemberExprBases.size();
1585-
createSpecialMethodCall(StreamDecl, MemberExprBases[NumBases - 2],
1586-
InitMethodName, FD);
1587-
createSpecialMethodCall(StreamDecl, MemberExprBases[NumBases - 2],
1588-
FinalizeMethodName, FD);
1589+
CXXMethodDecl *InitMethod = getMethodByName(StreamDecl, InitMethodName);
1590+
if (InitMethod) {
1591+
CXXMemberCallExpr *InitCall =
1592+
createSpecialMethodCall(MemberExprBases.back(), InitMethod, FD);
1593+
BodyStmts.push_back(InitCall);
1594+
}
1595+
CXXMethodDecl *FinalizeMethod =
1596+
getMethodByName(StreamDecl, FinalizeMethodName);
1597+
if (FinalizeMethod) {
1598+
CXXMemberCallExpr *FinalizeCall = createSpecialMethodCall(
1599+
MemberExprBases[NumBases - 2], FinalizeMethod, FD);
1600+
FinalizeStmts.push_back(FinalizeCall);
1601+
}
15891602
return true;
15901603
}
15911604

@@ -1796,7 +1809,7 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
17961809
cast<ClassTemplateSpecializationDecl>(FieldTy->getAsRecordDecl())
17971810
->getTemplateInstantiationArgs();
17981811
assert(TemplateArgs.size() == 2 &&
1799-
"Incorrect template args for Accessor Type");
1812+
"Incorrect template args for spec constant type");
18001813
// Get specialization constant ID type, which is the second template
18011814
// argument.
18021815
QualType SpecConstIDTy = TemplateArgs.get(1).getAsType().getCanonicalType();
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: %clang_cc1 -I %S/Inputs -fsycl -fsycl-is-device -ast-dump %s | FileCheck %s
2+
3+
// This test checks that compiler generates correct initialization for spec
4+
// constants
5+
6+
#include <sycl.hpp>
7+
8+
struct SpecConstantsWrapper {
9+
cl::sycl::experimental::spec_constant<int, class sc_name1> SC1;
10+
cl::sycl::experimental::spec_constant<int, class sc_name2> SC2;
11+
};
12+
13+
int main() {
14+
cl::sycl::experimental::spec_constant<char, class MyInt32Const> SC;
15+
SpecConstantsWrapper W;
16+
cl::sycl::kernel_single_task<class kernel_sc>(
17+
[=]() {
18+
(void)SC;
19+
(void)W;
20+
});
21+
}
22+
23+
// CHECK: FunctionDecl {{.*}}kernel_sc{{.*}} 'void ()'
24+
// CHECK: VarDecl {{.*}}'(lambda at {{.*}}'
25+
// CHECK-NEXT: InitListExpr {{.*}}'(lambda at {{.*}}'
26+
// CHECK-NEXT: CXXConstructExpr {{.*}}'cl::sycl::experimental::spec_constant<char, class MyInt32Const>':'cl::sycl::experimental::spec_constant<char, MyInt32Const>'
27+
// CHECK-NEXT: InitListExpr {{.*}} 'SpecConstantsWrapper'
28+
// CHECK-NEXT: CXXConstructExpr {{.*}} 'cl::sycl::experimental::spec_constant<int, class sc_name1>':'cl::sycl::experimental::spec_constant<int, sc_name1>'
29+
// CHECK-NEXT: CXXConstructExpr {{.*}} 'cl::sycl::experimental::spec_constant<int, class sc_name2>':'cl::sycl::experimental::spec_constant<int, sc_name2>'

sycl/include/CL/sycl/experimental/spec_constant.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ template <typename T, typename ID = T> class spec_constant {
3232
private:
3333
// Implementation defined constructor.
3434
#ifdef __SYCL_DEVICE_ONLY__
35+
public:
3536
spec_constant() {}
37+
38+
private:
3639
#else
3740
spec_constant(T Cst) : Val(Cst) {}
3841
#endif

sycl/test/spec_const/spec_const_hw.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,15 @@ float foo(
3939
return f32;
4040
}
4141

42+
struct SCWrapper {
43+
SCWrapper(cl::sycl::program &p)
44+
: SC1(p.set_spec_constant<class sc_name1, int>(4)),
45+
SC2(p.set_spec_constant<class sc_name2, int>(2)) {}
46+
47+
cl::sycl::experimental::spec_constant<int, class sc_name1> SC1;
48+
cl::sycl::experimental::spec_constant<int, class sc_name2> SC2;
49+
};
50+
4251
int main(int argc, char **argv) {
4352
val = argc + 16;
4453

@@ -61,6 +70,7 @@ int main(int argc, char **argv) {
6170
std::cout << "val = " << val << "\n";
6271
cl::sycl::program program1(q.get_context());
6372
cl::sycl::program program2(q.get_context());
73+
cl::sycl::program program3(q.get_context());
6474

6575
int goldi = (int)get_value();
6676
// TODO make this floating point once supported by the compiler
@@ -77,11 +87,17 @@ int main(int argc, char **argv) {
7787
// SYCL RT execution path
7888
program2.build_with_kernel_type<KernelBBBf>("-cl-fast-relaxed-math");
7989

90+
SCWrapper W(program3);
91+
program3.build_with_kernel_type<class KernelWrappedSC>();
92+
int goldw = 6;
93+
8094
std::vector<int> veci(1);
8195
std::vector<float> vecf(1);
96+
std::vector<int> vecw(1);
8297
try {
8398
cl::sycl::buffer<int, 1> bufi(veci.data(), veci.size());
8499
cl::sycl::buffer<float, 1> buff(vecf.data(), vecf.size());
100+
cl::sycl::buffer<int, 1> bufw(vecw.data(), vecw.size());
85101

86102
q.submit([&](cl::sycl::handler &cgh) {
87103
auto acci = bufi.get_access<cl::sycl::access::mode::write>(cgh);
@@ -99,6 +115,13 @@ int main(int argc, char **argv) {
99115
accf[0] = foo(f32);
100116
});
101117
});
118+
119+
q.submit([&](cl::sycl::handler &cgh) {
120+
auto accw = bufw.get_access<cl::sycl::access::mode::write>(cgh);
121+
cgh.single_task<KernelWrappedSC>(
122+
program3.get_kernel<KernelWrappedSC>(),
123+
[=]() { accw[0] = W.SC1.get() + W.SC2.get(); });
124+
});
102125
} catch (cl::sycl::exception &e) {
103126
std::cout << "*** Exception caught: " << e.what() << "\n";
104127
return 1;
@@ -116,6 +139,12 @@ int main(int argc, char **argv) {
116139
std::cout << "*** ERROR: " << valf << " != " << goldf << "(gold)\n";
117140
passed = false;
118141
}
142+
int valw = vecw[0];
143+
144+
if (valw != goldw) {
145+
std::cout << "*** ERROR: " << valw << " != " << goldw << "(gold)\n";
146+
passed = false;
147+
}
119148
std::cout << (passed ? "passed\n" : "FAILED\n");
120149
return passed ? 0 : 1;
121150
}

0 commit comments

Comments
 (0)