Skip to content

[offload][SYCL] Add SYCL Module splitting. #131347

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions llvm/include/llvm/Transforms/Utils/SYCLUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
//===------------ SYCLUtils.h - SYCL utility functions --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Utility functions for SYCL.
//===----------------------------------------------------------------------===//
#ifndef LLVM_FRONTEND_SYCL_UTILS_H
#define LLVM_FRONTEND_SYCL_UTILS_H

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"

#include <optional>
#include <string>
#include <utility>

namespace llvm {

class Module;
class Function;
class raw_ostream;

namespace sycl {

/// Selects how a module is split into parts.
///
/// convertStringToSplitMode() maps the strings "source", "kernel" and "none"
/// to IRSM_PER_TU, IRSM_PER_KERNEL and IRSM_NONE respectively.
enum class IRSplitMode {
  IRSM_PER_TU,     // one module per translation unit
  IRSM_PER_KERNEL, // one module per kernel
  IRSM_NONE        // no splitting
};

/// Parses the textual name of a split mode.
///
/// The recognized spellings are "source" (IRSM_PER_TU), "kernel"
/// (IRSM_PER_KERNEL) and "none" (IRSM_NONE); \p S has to match one of them
/// exactly.
///
/// \returns IRSplitMode value if \p S is recognized. Otherwise, std::nullopt is
/// returned.
std::optional<IRSplitMode> convertStringToSplitMode(StringRef S);

/// FunctionCategorizer used for splitting in SYCL compilation flow.
class FunctionCategorizer {
public:
FunctionCategorizer(IRSplitMode SM);

FunctionCategorizer() = delete;
FunctionCategorizer(FunctionCategorizer &) = delete;
FunctionCategorizer &operator=(const FunctionCategorizer &) = delete;
FunctionCategorizer(FunctionCategorizer &&) = default;
FunctionCategorizer &operator=(FunctionCategorizer &&) = default;

/// Returns integer specifying the category for the entry point.
/// If the given function isn't an entry point then returns std::nullopt.
std::optional<int> operator()(const Function &F);

private:
struct KeyInfo {
static SmallString<0> getEmptyKey() { return SmallString<0>(""); }

static SmallString<0> getTombstoneKey() { return SmallString<0>("-"); }

static bool isEqual(const SmallString<0> &LHS, const SmallString<0> &RHS) {
return LHS == RHS;
}

static unsigned getHashValue(const SmallString<0> &S) {
return llvm::hash_value(StringRef(S));
}
};

IRSplitMode SM;
DenseMap<SmallString<0>, int, KeyInfo> StrKeyToID;
};

/// An LLVM Module together with the additional information describing it.
///
/// The module is referenced by the path of the file it was written to instead
/// of being kept in memory: split modules are stored on disk because of the
/// high RAM consumption during the whole splitting process.
struct ModuleAndSYCLMetadata {
  /// Path of the file that holds the module.
  std::string ModuleFilePath;
  /// Symbols of the module (presumably the output of makeSymbolTable();
  /// confirm against the callers).
  std::string Symbols;

  /// Creates a record for the module stored in \p File with the given
  /// \p Symbols.
  ModuleAndSYCLMetadata(const Twine &File, std::string Symbols)
      : ModuleFilePath(File.str()), Symbols(std::move(Symbols)) {}

  // A record without a module path is meaningless.
  ModuleAndSYCLMetadata() = delete;

  // Plain value type: copyable and movable.
  ModuleAndSYCLMetadata(const ModuleAndSYCLMetadata &) = default;
  ModuleAndSYCLMetadata(ModuleAndSYCLMetadata &&) = default;
  ModuleAndSYCLMetadata &operator=(const ModuleAndSYCLMetadata &) = default;
  ModuleAndSYCLMetadata &operator=(ModuleAndSYCLMetadata &&) = default;
};

/// Builds the list of entry points of \p M.
///
/// \returns the names of all functions of \p M that are definitions and use
/// the SPIR_KERNEL, AMDGPU_KERNEL or PTX_Kernel calling convention, in the
/// order in which they appear in the module. Every name is followed by '\n'.
std::string makeSymbolTable(const Module &M);

/// A table of strings: the outer vector holds the rows, every inner vector
/// holds the cells of one row. writeStringTable() treats row 0 as the column
/// titles.
using StringTable = SmallVector<SmallVector<SmallString<64>>>;

/// Writes \p Table to \p OS, one row per line, with the cells of a row
/// separated by '|'. The first row (column titles) is additionally enclosed in
/// square brackets.
///
/// \p Table must contain at least the title row, the title row must not be
/// empty and every other row must have as many cells as the title row; these
/// requirements are checked by assertions only.
void writeStringTable(const StringTable &Table, raw_ostream &OS);

} // namespace sycl
} // namespace llvm

#endif // LLVM_FRONTEND_SYCL_UTILS_H
42 changes: 42 additions & 0 deletions llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//===-------- SplitModuleByCategory.h - module split ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Functionality to split a module by categories.
//===----------------------------------------------------------------------===//

#ifndef LLVM_FRONTEND_SYCL_SPLIT_MODULE_H
#define LLVM_FRONTEND_SYCL_SPLIT_MODULE_H

#include "llvm/ADT/STLFunctionalExtras.h"

#include <memory>
#include <optional>
#include <string>

namespace llvm {

class Module;
class Function;

namespace sycl {

/// FunctionCategorizer returns integer category for the given Function.
/// Otherwise, it returns std::nullopt if function doesn't have a category.
///
/// NOTE(review): llvm/Transforms/Utils/SYCLUtils.h declares a class named
/// FunctionCategorizer in this same namespace. A translation unit that
/// includes both headers will presumably fail to compile because of the
/// conflicting declarations; confirm and consider renaming one of them.
using FunctionCategorizer = function_ref<std::optional<int>(const Function &F)>;

/// Callback that receives one module produced by the split. The module is
/// passed by value, so the callback takes ownership of it.
using PostSplitCallbackType = function_ref<void(std::unique_ptr<Module> Part)>;

/// Splits the given module \p M. Ownership of \p M is transferred to this
/// function.
/// \p FC supplies the category of a function (see FunctionCategorizer).
/// Every split image is passed to \p Callback for further possible
/// processing.
void splitModuleByCategory(std::unique_ptr<Module> M, FunctionCategorizer FC,
                           PostSplitCallbackType Callback);

} // namespace sycl
} // namespace llvm

#endif // LLVM_FRONTEND_SYCL_SPLIT_MODULE_H
2 changes: 2 additions & 0 deletions llvm/lib/Transforms/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ add_llvm_component_library(LLVMTransformUtils
SimplifyLibCalls.cpp
SizeOpts.cpp
SplitModule.cpp
SplitModuleByCategory.cpp
StripNonLineTableDebugInfo.cpp
SYCLUtils.cpp
SymbolRewriter.cpp
UnifyFunctionExitNodes.cpp
UnifyLoopExits.cpp
Expand Down
117 changes: 117 additions & 0 deletions llvm/lib/Transforms/Utils/SYCLUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
//===------------ SYCLUtils.cpp - SYCL utility functions ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// SYCL utility functions.
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Utils/SYCLUtils.h"

#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

#include <optional>

using namespace llvm;
using namespace sycl;

namespace {

/// Computes the string key that identifies the split category of \p F.
///
/// \param SM split mode; only IRSM_PER_KERNEL and IRSM_PER_TU are expected.
/// \param F function to compute the key for.
/// \returns the name of \p F in IRSM_PER_KERNEL mode, or the value of the
/// "sycl-module-id" function attribute of \p F in IRSM_PER_TU mode.
SmallString<0> computeFunctionCategoryForSplitting(IRSplitMode SM,
                                                   const Function &F) {
  static constexpr char ATTR_SYCL_MODULE_ID[] = "sycl-module-id";
  // The key is built directly from the StringRef; going through a temporary
  // std::string would only add an extra copy.
  switch (SM) {
  case IRSplitMode::IRSM_PER_KERNEL:
    return SmallString<0>(F.getName());
  case IRSplitMode::IRSM_PER_TU:
    return SmallString<0>(
        F.getFnAttribute(ATTR_SYCL_MODULE_ID).getValueAsString());
  default:
    llvm_unreachable("other modes aren't expected");
  }
}

/// \returns true if \p F uses one of the kernel calling conventions
/// (SPIR_KERNEL, AMDGPU_KERNEL or PTX_Kernel).
bool isKernel(const Function &F) {
  switch (F.getCallingConv()) {
  case CallingConv::SPIR_KERNEL:
  case CallingConv::AMDGPU_KERNEL:
  case CallingConv::PTX_Kernel:
    return true;
  default:
    return false;
  }
}

/// \returns true if \p F is an entry point, i.e. a kernel that has a
/// definition.
bool isEntryPoint(const Function &F) {
  // Declarations are never entry points: including them into the groups of
  // entry points would produce an incorrect list of symbols. Every defined
  // kernel, on the other hand, is always an entry point.
  return !F.isDeclaration() && isKernel(F);
}

} // anonymous namespace

namespace llvm {
namespace sycl {

/// Translates the textual split mode \p S ("source", "kernel" or "none") into
/// the corresponding IRSplitMode. Any other string yields std::nullopt.
std::optional<IRSplitMode> convertStringToSplitMode(StringRef S) {
  if (S == "source")
    return IRSplitMode::IRSM_PER_TU;
  if (S == "kernel")
    return IRSplitMode::IRSM_PER_KERNEL;
  if (S == "none")
    return IRSplitMode::IRSM_NONE;

  return std::nullopt;
}

/// Creates a categorizer for the split mode \p SM.
///
/// IRSplitMode::IRSM_NONE is not a valid argument: without splitting there are
/// no categories to compute. This is a precondition of the caller, so it is
/// checked with an assertion.
FunctionCategorizer::FunctionCategorizer(IRSplitMode SM) : SM(SM) {
  assert(SM != IRSplitMode::IRSM_NONE &&
         "FunctionCategorizer doesn't support the 'none' split mode");
}

/// Returns the category ID of the entry point \p F, or std::nullopt if \p F is
/// not an entry point.
///
/// Entry points with equal string keys share one ID. A key that has not been
/// seen before gets the next free ID, which equals the number of keys
/// registered so far.
std::optional<int> FunctionCategorizer::operator()(const Function &F) {
  if (!isEntryPoint(F))
    return std::nullopt; // skip the function.

  // The ID a new key would receive; it has to be taken before the insertion
  // attempt changes the size of the map.
  const int NextID = static_cast<int>(StrKeyToID.size());

  // A single try_emplace both finds an existing entry and registers a new
  // one, so the key is looked up only once.
  return StrKeyToID
      .try_emplace(computeFunctionCategoryForSplitting(SM, F), NextID)
      .first->second;
}

std::string makeSymbolTable(const Module &M) {
SmallString<0> Data;
raw_svector_ostream OS(Data);
for (const auto &F : M)
if (isEntryPoint(F))
OS << F.getName() << '\n';

return std::string(OS.str());
}

/// Prints \p Table to \p OS: the title row (row 0) enclosed in square brackets
/// followed by the remaining rows, one row per line, with cells joined by '|'.
///
/// The table must have a non-empty title row and every other row must have the
/// same number of cells as the title row.
void writeStringTable(const StringTable &Table, raw_ostream &OS) {
  assert(!Table.empty() && "table should contain at least column titles");
  const auto &Titles = Table[0];
  assert(!Titles.empty() && "table should be non-empty");
  OS << '[' << join(Titles.begin(), Titles.end(), "|") << "]\n";

  const size_t NumRows = Table.size();
  for (size_t RowIdx = 1; RowIdx != NumRows; ++RowIdx) {
    const auto &Row = Table[RowIdx];
    assert(Row.size() == Titles.size() && "row's size should be equal");
    OS << join(Row.begin(), Row.end(), "|") << '\n';
  }
}

} // namespace sycl
} // namespace llvm
Loading
Loading