Skip to content

[WIP][SYCL] Add support for union types as kernel arguments #2255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,13 @@ class KernelObjVisitor {
else if (ElementTy->isStructureOrClassType())
VisitRecord(Owner, ArrayField, ElementTy->getAsCXXRecordDecl(),
handlers...);
else if (ElementTy->isUnionType())
// TODO: This check is still necessary I think?! Array seems to handle
// this differently (see above) for structs I think.
//if (KF_FOR_EACH(handleUnionType, Field, FieldTy)) {
VisitUnion(Owner, ArrayField, ElementTy->getAsCXXRecordDecl(),
handlers...);
//}
else if (ElementTy->isArrayType())
VisitArrayElements(ArrayField, ElementTy, handlers...);
else if (ElementTy->isScalarType())
Expand Down Expand Up @@ -857,6 +864,41 @@ class KernelObjVisitor {
void VisitRecord(const CXXRecordDecl *Owner, ParentTy &Parent,
const CXXRecordDecl *Wrapper, Handlers &... handlers);

// Base case, only calls these when filtered.
template <typename... FilteredHandlers, typename ParentTy>
void VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent,
const CXXRecordDecl *Wrapper,
FilteredHandlers &... handlers) {
(void)std::initializer_list<int>{
(handlers.enterUnion(Owner, Parent), 0)...};
VisitRecordHelper(Wrapper, Wrapper->fields(), handlers...);
(void)std::initializer_list<int>{
(handlers.leaveUnion(Owner, Parent), 0)...};
}


template <typename... FilteredHandlers, typename ParentTy,
typename CurHandler, typename... Handlers>
std::enable_if_t<!CurHandler::VisitUnionBody>
VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent,
const CXXRecordDecl *Wrapper,
FilteredHandlers &... filtered_handlers,
CurHandler &cur_handler, Handlers &... handlers) {
VisitUnion<FilteredHandlers...>(
Owner, Parent, Wrapper, filtered_handlers..., handlers...);
}

template <typename... FilteredHandlers, typename ParentTy,
typename CurHandler, typename... Handlers>
std::enable_if_t<CurHandler::VisitUnionBody>
VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent,
const CXXRecordDecl *Wrapper,
FilteredHandlers &... filtered_handlers,
CurHandler &cur_handler, Handlers &... handlers) {
VisitUnion<FilteredHandlers..., CurHandler>(
Owner, Parent, Wrapper, filtered_handlers..., cur_handler, handlers...);
}

template <typename... Handlers>
void VisitRecordHelper(const CXXRecordDecl *Owner,
clang::CXXRecordDecl::base_class_const_range Range,
Expand Down Expand Up @@ -942,6 +984,11 @@ class KernelObjVisitor {
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl();
VisitRecord(Owner, Field, RD, handlers...);
}
} else if (FieldTy->isUnionType()) {
if (KF_FOR_EACH(handleUnionType, Field, FieldTy)) {
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl();
VisitUnion(Owner, Field, RD, handlers...);
}
} else if (FieldTy->isReferenceType())
KF_FOR_EACH(handleReferenceType, Field, FieldTy);
else if (FieldTy->isPointerType())
Expand Down Expand Up @@ -1005,6 +1052,7 @@ class SyclKernelFieldHandler {
}
virtual bool handleSyclHalfType(FieldDecl *, QualType) { return true; }
virtual bool handleStructType(FieldDecl *, QualType) { return true; }
virtual bool handleUnionType(FieldDecl *, QualType) { return true; }
virtual bool handleReferenceType(FieldDecl *, QualType) { return true; }
virtual bool handlePointerType(FieldDecl *, QualType) { return true; }
virtual bool handleArrayType(FieldDecl *, QualType) { return true; }
Expand All @@ -1024,6 +1072,8 @@ class SyclKernelFieldHandler {
virtual bool leaveStruct(const CXXRecordDecl *, const CXXBaseSpecifier &) {
return true;
}
virtual bool enterUnion(const CXXRecordDecl *, FieldDecl *) { return true; }
virtual bool leaveUnion(const CXXRecordDecl *, FieldDecl *) { return true; }

// The following are used for stepping through array elements.

Expand Down Expand Up @@ -1201,6 +1251,65 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
}
};

// A type to check the validity of passing union with accessor/sampler/stream
// member as a kernel argument types.
class SyclKernelUnionBodyChecker : public SyclKernelFieldHandler {
static constexpr const bool VisitUnionBody = true;
int UnionCount = 0;
bool IsInvalid = false;
DiagnosticsEngine &Diag;

public:
SyclKernelUnionBodyChecker(Sema &S)
: SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {}
bool isValid() { return !IsInvalid; }

bool enterUnion(const CXXRecordDecl *RD, FieldDecl *FD) {
++UnionCount;
return true;
}

bool leaveUnion(const CXXRecordDecl *RD, FieldDecl *FD) {
--UnionCount;
return true;
}

bool handlePointerType(FieldDecl *FD, QualType FieldTy) final {
if (UnionCount) {
IsInvalid = true;
Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type)
<< FieldTy;
}
return isValid();
}

bool handleSyclAccessorType(FieldDecl *FD, QualType FieldTy) final {
if (UnionCount) {
IsInvalid = true;
Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type)
<< FieldTy;
}
return isValid();
}

bool handleSyclSamplerType(FieldDecl *FD, QualType FieldTy) final {
if (UnionCount) {
IsInvalid = true;
Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type)
<< FieldTy;
}
return isValid();
}
bool handleSyclStreamType(FieldDecl *FD, QualType FieldTy) final {
if (UnionCount) {
IsInvalid = true;
Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type)
<< FieldTy;
}
return isValid();
}
};

// A type to Create and own the FunctionDecl for the kernel.
class SyclKernelDeclCreator : public SyclKernelFieldHandler {
FunctionDecl *KernelDecl;
Expand Down Expand Up @@ -1416,6 +1525,10 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
return true;
}

bool handleUnionType(FieldDecl *FD, QualType FieldTy) final {
return handleScalarType(FD, FieldTy);
}

bool handleSyclHalfType(FieldDecl *FD, QualType FieldTy) final {
addParam(FD, FieldTy);
return true;
Expand Down Expand Up @@ -1751,6 +1864,10 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
return true;
}

bool handleUnionType(FieldDecl *FD, QualType FieldTy) final {
return handleScalarType(FD, FieldTy);
}

bool enterStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final {
CXXCastPath BasePath;
QualType DerivedTy(RD->getTypeForDecl(), 0);
Expand Down Expand Up @@ -1955,6 +2072,10 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
return true;
}

bool handleUnionType(FieldDecl *FD, QualType FieldTy) final {
return handleScalarType(FD, FieldTy);
}

bool handleSyclStreamType(FieldDecl *FD, QualType FieldTy) final {
addParam(FD, FieldTy, SYCLIntegrationHeader::kind_std_layout);
return true;
Expand Down