Skip to content

[SYCL] Fix for detection of free function calls. #3003

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2705,23 +2705,18 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {

// Sets a flag if the kernel is a parallel_for that calls the
// free function API "this_item".
void setThisItemIsCalled(const CXXRecordDecl *KernelObj,
FunctionDecl *KernelFunc) {
void setThisItemIsCalled(FunctionDecl *KernelFunc) {
if (getKernelInvocationKind(KernelFunc) != InvokeParallelFor)
return;

const CXXMethodDecl *WGLambdaFn = getOperatorParens(KernelObj);
if (!WGLambdaFn)
return;

// The call graph for this translation unit.
CallGraph SYCLCG;
SYCLCG.addToCallGraph(SemaRef.getASTContext().getTranslationUnitDecl());
using ChildParentPair =
std::pair<const FunctionDecl *, const FunctionDecl *>;
llvm::SmallPtrSet<const FunctionDecl *, 16> Visited;
llvm::SmallVector<ChildParentPair, 16> WorkList;
WorkList.push_back({WGLambdaFn, nullptr});
WorkList.push_back({KernelFunc, nullptr});

while (!WorkList.empty()) {
const FunctionDecl *FD = WorkList.back().first;
Expand Down Expand Up @@ -2772,7 +2767,7 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
bool IsSIMDKernel = isESIMDKernelType(KernelObj);
Header.startKernel(Name, NameType, StableName, KernelObj->getLocation(),
IsSIMDKernel);
setThisItemIsCalled(KernelObj, KernelFunc);
setThisItemIsCalled(KernelFunc);
}

bool handleSyclAccessorType(const CXXRecordDecl *RD,
Expand Down
25 changes: 22 additions & 3 deletions clang/test/CodeGenSYCL/parallel_for_this_item.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3EMU",
// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3OWL",
// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3RAT",
// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3FOX"
// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3FOX",
// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3BEE"
// CHECK-NEXT: };

// CHECK:template <> struct KernelInfo<class GNU> {
Expand Down Expand Up @@ -97,6 +98,22 @@
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr bool callsAnyThisFreeFunction() { return 1; }
// CHECK-NEXT:};
// CHECK-NEXT:template <> struct KernelInfo<class BEE> {
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr const char* getName() { return "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3BEE"; }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr unsigned getNumParams() { return 0; }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr const kernel_param_desc_t& getParamDesc(unsigned i) {
// CHECK-NEXT: return kernel_signatures[i+0];
// CHECK-NEXT: }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr bool isESIMD() { return 0; }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr bool callsThisItem() { return 1; }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr bool callsAnyThisFreeFunction() { return 1; }
// CHECK-NEXT:};

#include "sycl.hpp"

Expand Down Expand Up @@ -135,8 +152,10 @@ int main() {
cgh.parallel_for<class RAT>(range<1>(1), [=](id<1> I) { f(); });

// This kernel does not call sycl::this_item, but does call this_id
cgh.parallel_for<class FOX>(range<1>(1),
[=](id<1> I) { this_id<1>(); });
cgh.parallel_for<class FOX>(range<1>(1), [=](id<1> I) { this_id<1>(); });

// This kernel calls sycl::this_item
cgh.parallel_for<class BEE>(range<1>(1), [=](auto I) { this_item<1>(); });
Copy link
Contributor

@elizabethandrews elizabethandrews Jan 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see. So getOperatorParens(KernelObj) doesn't return operator method when auto is used? Do you understand why? It is not obvious to me why. @premanandrao could you also please take a look?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rdeodhar, is this the case that was failing before?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this particular construct in an existing runtime test was failing. I have replicated the situation in the existing clang lit test. Owing to some templating of members, the operator() of the KernelObj was not being found. However, the code looking for the operator() is unnecessary because the kernel function is already available.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm... That means KernelObj is defined differently when parameter type is auto? Do you know why this happens? I'm approving the patch since the answer to this question is probably outside the scope of this patch. However, if you know the answer I am curious :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO, there is likely value to instead fix the poorly named getOperatorParens instead.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@erichkeane would you know why its getting wrapped?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@erichkeane well that's a fun case -- I think you can use decls_begin() and decls_end() to do this:

static const CXXMethodDecl *getOperatorParens(const CXXRecordDecl *Rec) {
  auto MethIter = llvm::find_if(Rec->methods(), [](const CXXMethodDecl *MD) {
    return MD->getOverloadedOperator() == OO_Call;
  });
  if (MethIter != Rec->methods_end())
   return *MethIter;

  using function_template_iterator = specific_decl_iterator<FunctionTemplateDecl>;
  using function_template_range = llvm::iterator_range<specific_decl_iterator<FunctionTemplateDecl>>;
  auto FTIter = llvm::find_if(function_template_range(function_template_iterator(Rec->decls_begin()),
                                                      function_template_iterator(Rec->decls_end())),
                                           [](const FunctionTemplateDecl *FTD) {
                                             if (const auto *MD = dyn_cast<CXXMethodDecl>(FTD->getTemplatedDecl()))
                                               return MD->getOverloadedOperator() == OO_Call;
                                             return false;
                                           });
  if (FTIter != function_template_iterator(Rec->decls_end()))
   return *FTIter;
  return nullptr;
}

(This is totally untested and was written in a web browser, so YMMV.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@erichkeane would you know why its getting wrapped?

Because we ALWAYS wrap templates. We need something in the AST to contain BOTH the uninstantiated and instantiated template.

We should probably update the getOperatorParens, but, IMO, it should probably return an array of valid (non-dependent) operators. I'm guessing this whole idea of using getOperatorParens is flawed though. Trying to figure out the call operator of the type rather than checking the callgraph is likely the wrong way about it. @kbobrovs : Please take a look at how you're using getOperatorParens and make sure it would work in a case where there is multiple operator()s, including instantiated template versions.

IN the case of what @AaronBallman just posted (thanks btw!), I don't know if that works for 2 reasons: First, I think the return *FTIter; line is wrong, since that returns a FunctionTemplateDecl*, instead of a CXXMethodDecl*. Second, it has the same problem of only returning the FIRST operator () :(

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, that should be return cast<CXXMethodDecl>(*FTIter) instead, and it definitely only finds the first. As for how to handle multiple operator() definitions -- I agree that the function should either return a container of all the operators or take the container as a parameter.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kbobrovs : Please take a look at how you're using getOperatorParens and make sure it would work in a case where there is multiple operator()s, including instantiated template versions.

AFAIR, this simplistic approach was based on assumption that the API code is under our control, and we know that there are no other operator '()' calls from the parallel_for_kernel.

});

return 0;
Expand Down