Skip to content

Kokkos ensure kokkos function #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ find_package(Clang REQUIRED)
add_library(KokkosClangTidyModule MODULE
src/KokkosTidyModule.cpp

src/EnsureKokkosFunctionCheck.cpp
src/EnsureKokkosFunctionCheck.h
src/ImplicitThisCaptureCheck.cpp
src/ImplicitThisCaptureCheck.h
src/KokkosMatchers.cpp
Expand Down
7 changes: 7 additions & 0 deletions docs/kokkos-ensure-kokkos-function.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
.. title:: clang-tidy - kokkos-ensure-kokkos-function

kokkos-ensure-kokkos-function
=============================

This check ensures that the user has annotated functions called by Kokkos with
one of the KOKKOS_FUNCTION style annotations.
228 changes: 228 additions & 0 deletions src/EnsureKokkosFunctionCheck.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#include "EnsureKokkosFunctionCheck.h"
#include "KokkosMatchers.h"
#include "clang/AST/ASTContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"

using namespace clang::ast_matchers;

namespace clang {
namespace tidy {
namespace kokkos {
namespace {

std::string KF_Regex = "KOKKOS_.*FUNCTION"; // NOLINT

auto notKFunc(std::string const &AllowedFuncRegex) {
auto AllowedFuncMatch = unless(matchesName(AllowedFuncRegex));
return functionDecl(unless(matchesAttr(KF_Regex)),
unless(isExpansionInSystemHeader()), AllowedFuncMatch);
}

bool isAnnotated(CXXMethodDecl const *Method) {
// If the method is annotated the match will not be empty
return !match(cxxMethodDecl(matchesAttr(KF_Regex)), *Method,
Method->getASTContext())
.empty();
}

// TODO one day we might want to check if the lambda is local to our current
// function context, but until someone complains that's a lot of work. The
// other case we aren't going to deal with is: void foo(){ struct S { static
// void func(){} }; S::func(); }
bool callExprIsToLambaOp(CallExpr const *CE) {
if (auto const *CMD =
dyn_cast_or_null<CXXMethodDecl>(CE->getDirectCallee())) {
if (auto const *Parent = CMD->getParent()) {
if (Parent->isLambda()) {
return true;
}
}
}
return false;
}

auto checkLambdaBody(CXXRecordDecl const *Lambda,
std::string const &AllowedFuncRegex) {
assert(Lambda->isLambda());
llvm::SmallPtrSet<CallExpr const *, 1> BadCallSet;
auto const *FD = Lambda->getLambdaCallOperator();
if (!FD) {
return BadCallSet;
}

auto notKCalls = // NOLINT
callExpr(callee(notKFunc(AllowedFuncRegex))).bind("CE");

auto BadCalls = match(functionDecl(forEachDescendant(notKCalls)), *FD,
FD->getASTContext());

for (auto BadCall : BadCalls) {
auto const *CE = BadCall.getNodeAs<CallExpr>("CE");
if (callExprIsToLambaOp(CE)) { // function call handles nullptr
continue;
}

BadCallSet.insert(CE);
}

return BadCallSet;
}

// Recurses through the tree of all calls to functions with visble bodies
void recurseCallExpr(
llvm::SmallPtrSet<CXXMethodDecl const *, 8> const &FunctorMethods,
CallExpr const *Call,
llvm::SmallPtrSet<CXXMethodDecl const *, 4> &Results) {

// Get the body of the called function
auto const *CallDecl = Call->getCalleeDecl();
if (CallDecl == nullptr || !CallDecl->hasBody()) {
return;
}

auto &ASTContext = CallDecl->getASTContext();

// Check if the called function is a member function of the functor
// if yes then write the result back out.
if (auto const *Method = dyn_cast<CXXMethodDecl>(CallDecl)) {
if (FunctorMethods.count(Method) > 0) {
Results.insert(Method);
}
}

// Match all callexprs in our body
auto CEs = match(compoundStmt(forEachDescendant(callExpr().bind("CE"))),
*(CallDecl->getBody()), ASTContext);

// Check all those calls for uses of members of the functor as well
for (auto BN : CEs) {
if (auto const *CE = BN.getNodeAs<CallExpr>("CE")) {
recurseCallExpr(FunctorMethods, CE, Results);
}
}
}

// Find methods from our functor called in the tree of Kokkos::parallel_x
auto checkFunctorBody(CXXRecordDecl const *Functor, CallExpr const *CallSite) {
llvm::SmallPtrSet<CXXMethodDecl const *, 8> FunctorMethods;
for (auto const *Method : Functor->methods()) {
FunctorMethods.insert(Method);
}
llvm::SmallPtrSet<CXXMethodDecl const *, 4> Results;
recurseCallExpr(FunctorMethods, CallSite, Results);

return Results;
}

} // namespace

EnsureKokkosFunctionCheck::EnsureKokkosFunctionCheck(StringRef Name,
ClangTidyContext *Context)
: ClangTidyCheck(Name, Context) {
AllowIfExplicitHost = std::stoi(Options.get("AllowIfExplicitHost", "0").str());
AllowedFunctionsRegex = Options.get("AllowedFunctionsRegex", "a^");
// This can't be empty because the regex ast matchers assert !empty
assert(!AllowedFunctionsRegex.empty());
}

void EnsureKokkosFunctionCheck::storeOptions(
ClangTidyOptions::OptionMap &Opts) {
Options.store(Opts, "AllowedFunctionsRegex", AllowedFunctionsRegex);
Options.store(Opts, "AllowIfExplicitHost",
std::to_string(AllowIfExplicitHost));
}

void EnsureKokkosFunctionCheck::registerMatchers(MatchFinder *Finder) {
auto notKCalls = // NOLINT
callExpr(callee(notKFunc(AllowedFunctionsRegex))).bind("CE");

// We have to be sure that we don't match functionDecls in systems headers,
// because they might call our Functor, which if it is a lambda will not be
// marked with KOKKOS_FUNCTION
Finder->addMatcher(functionDecl(matchesAttr(KF_Regex),
unless(isExpansionInSystemHeader()),
forEachDescendant(notKCalls))
.bind("ParentFD"),
this);

// Need to check the Functor also
auto Functor = expr(hasType(
cxxRecordDecl(unless(isExpansionInSystemHeader())).bind("Functor")));
Finder->addMatcher(callExpr(isKokkosParallelCall(), hasAnyArgument(Functor))
.bind("KokkosCE"),
this);
}

void EnsureKokkosFunctionCheck::check(const MatchFinder::MatchResult &Result) {

auto const *ParentFD = Result.Nodes.getNodeAs<FunctionDecl>("ParentFD");
auto const *CE = Result.Nodes.getNodeAs<CallExpr>("CE");
auto const *Functor = Result.Nodes.getNodeAs<CXXRecordDecl>("Functor");

if (ParentFD != nullptr) {
if (callExprIsToLambaOp(CE)) { // Avoid false positives for local lambdas
return;
}

diag(CE->getBeginLoc(),
"function %0 called in %1 is missing a KOKKOS_X_FUNCTION annotation")
<< CE->getDirectCallee() << ParentFD;
diag(CE->getDirectCallee()->getLocation(), "Function %0 declared here",
DiagnosticIDs::Note)
<< CE->getDirectCallee();
}

if (Functor != nullptr) {
auto const *CE = Result.Nodes.getNodeAs<CallExpr>("KokkosCE");
if (AllowIfExplicitHost != 0 && explicitlyDefaultHostExecutionSpace(CE)) {
return;
}

if (Functor->isLambda()) {
auto BadCalls = checkLambdaBody(Functor, AllowedFunctionsRegex);
for (auto const *BadCall : BadCalls) {
diag(BadCall->getBeginLoc(),
"Function %0 called in a lambda was missing "
"KOKKOS_X_FUNCTION annotation.")
<< BadCall->getDirectCallee();
diag(BadCall->getDirectCallee()->getBeginLoc(),
"Function %0 was delcared here", DiagnosticIDs::Note)
<< BadCall->getDirectCallee();
}
} else {
for (auto const *CalledMethod : checkFunctorBody(Functor, CE)) {
if (isAnnotated(CalledMethod)) {
continue;
}

diag(CE->getBeginLoc(),
"Called a member function of %0 that requires a "
"KOKKOS_X_FUNCTION annotation.")
<< CalledMethod->getParent();
diag(CalledMethod->getBeginLoc(),
"Member Function %0 of %1 was delcared here", DiagnosticIDs::Note)
<< CalledMethod << CalledMethod->getParent();
}
}
}
}

} // namespace kokkos
} // namespace tidy
} // namespace clang
43 changes: 43 additions & 0 deletions src/EnsureKokkosFunctionCheck.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef LLVM_CLANG_TOOLS_EXTRA_CLANG_TIDY_KOKKOS_ENSUREKOKKOSFUNCTIONCHECK_H
#define LLVM_CLANG_TOOLS_EXTRA_CLANG_TIDY_KOKKOS_ENSUREKOKKOSFUNCTIONCHECK_H

#include "clang-tidy/ClangTidyCheck.h"

namespace clang {
namespace tidy {
namespace kokkos {

/// Check that ensures user provided functions were properly annotated
class EnsureKokkosFunctionCheck : public ClangTidyCheck {
public:
EnsureKokkosFunctionCheck(StringRef Name, ClangTidyContext *Context);
void registerMatchers(ast_matchers::MatchFinder *Finder) override;
void check(const ast_matchers::MatchFinder::MatchResult &Result) override;
void storeOptions(ClangTidyOptions::OptionMap &Opts) override;

private:
std::string AllowedFunctionsRegex;
int AllowIfExplicitHost;
};

} // namespace kokkos
} // namespace tidy
} // namespace clang

#endif // LLVM_CLANG_TOOLS_EXTRA_CLANG_TIDY_KOKKOS_ENSUREKOKKOSFUNCTIONCHECK_H
15 changes: 4 additions & 11 deletions src/ImplicitThisCaptureCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,12 @@ std::optional<SourceLocation> capturesThis(CXXRecordDecl const *CRD) {
ImplicitThisCaptureCheck::ImplicitThisCaptureCheck(StringRef Name,
ClangTidyContext *Context)
: ClangTidyCheck(Name, Context) {
CheckIfExplicitHost = std::stoi(Options.get("CheckIfExplicitHost", "0").str());
HostTypeDefRegex =
Options.get("HostTypeDefRegex", "Kokkos::DefaultHostExecutionSpace");
AllowIfExplicitHost = std::stoi(Options.get("AllowIfExplicitHost", "0").str());
}

void ImplicitThisCaptureCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
Options.store(Opts, "CheckIfExplicitHost",
std::to_string(CheckIfExplicitHost));
Options.store(Opts, "HostTypeDefRegex", HostTypeDefRegex);
Options.store(Opts, "AllowIfExplicitHost",
std::to_string(AllowIfExplicitHost));
}

void ImplicitThisCaptureCheck::registerMatchers(MatchFinder *Finder) {
Expand All @@ -74,11 +71,7 @@ void ImplicitThisCaptureCheck::registerMatchers(MatchFinder *Finder) {
void ImplicitThisCaptureCheck::check(const MatchFinder::MatchResult &Result) {
auto const *CE = Result.Nodes.getNodeAs<CallExpr>("x");

if (CheckIfExplicitHost) {
if (explicitlyUsingHostExecutionSpace(CE, HostTypeDefRegex)) {
return;
}
}
AllowIfExplicitHost = std::stoi(Options.get("AllowIfExplicitHost", "0").str());

auto const *Lambda = Result.Nodes.getNodeAs<CXXRecordDecl>("Lambda");
auto CaptureLocation = capturesThis(Lambda);
Expand Down
3 changes: 1 addition & 2 deletions src/ImplicitThisCaptureCheck.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ class ImplicitThisCaptureCheck : public ClangTidyCheck {
void storeOptions(ClangTidyOptions::OptionMap &Opts) override;
void check(const ast_matchers::MatchFinder::MatchResult &Result) override;
private:
int CheckIfExplicitHost;
std::string HostTypeDefRegex;
int AllowIfExplicitHost;
};

} // namespace kokkos
Expand Down
Loading