Skip to content

Commit

Permalink
[SYCL] Don't use legacy ANSI-only Windows API for loading plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
againull committed Aug 17, 2023
1 parent 319f067 commit 3c334c2
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 38 deletions.
4 changes: 4 additions & 0 deletions sycl/include/sycl/detail/os_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <cstdlib> // for size_t
#include <string> // for string
#include <sys/stat.h> // for stat
#include <filesystem>

#ifdef _WIN32
#define __SYCL_RT_OS_WINDOWS
Expand Down Expand Up @@ -49,6 +50,9 @@ class __SYCL_EXPORT OSUtil {
/// Returns a directory component of a path.
static std::string getDirName(const char *Path);

/// Returns an absolute path to a directory where the object was found.
static std::filesystem::path getCurrentDSODirPath();

#ifdef __SYCL_RT_OS_WINDOWS
static constexpr const char *DirSep = "\\";
#else
Expand Down
5 changes: 3 additions & 2 deletions sycl/include/sycl/detail/pi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <string> // for char_traits, string
#include <type_traits> // for false_type, true_type
#include <vector> // for vector
#include <filesystem>

#ifdef XPTI_ENABLE_INSTRUMENTATION
// Forward declarations
Expand Down Expand Up @@ -171,7 +172,7 @@ __SYCL_EXPORT void contextSetExtendedDeleter(const sycl::context &constext,

// Function to load a shared library
// Implementation is OS dependent
void *loadOsLibrary(const std::string &Library);
void *loadOsLibrary(const std::filesystem::path &Library);

// Function to unload a shared library
// Implementation is OS dependent (see posix-pi.cpp and windows-pi.cpp)
Expand All @@ -180,7 +181,7 @@ int unloadOsLibrary(void *Library);
// Function to load the shared plugin library
// On Windows, this will have been pre-loaded by proxy loader.
// Implementation is OS dependent.
void *loadOsPluginLibrary(const std::string &Library);
void *loadOsPluginLibrary(const std::filesystem::path &Library);

// Function to unload the shared plugin library
// Implementation is OS dependent (see posix-pi.cpp and windows-pi.cpp)
Expand Down
64 changes: 42 additions & 22 deletions sycl/pi_win_proxy_loader/pi_win_proxy_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
// similar approach.

#include <cassert>
#include <filesystem>

#ifdef _WIN32

Expand Down Expand Up @@ -99,6 +100,25 @@ std::string getCurrentDSODir() {
return Path;
}

std::filesystem::path getCurrentDSODirPath() {
wchar_t Path[MAX_PATH];
//Path[0] = '\0';
//Path[sizeof(Path) - 1] = '\0';
auto Handle = getOSModuleHandle(reinterpret_cast<void *>(&getCurrentDSODir));
DWORD Ret = GetModuleFileName(
reinterpret_cast<HMODULE>(ExeModuleHandle == Handle ? 0 : Handle),
reinterpret_cast<LPWSTR>(&Path), sizeof(Path));
assert(Ret < sizeof(Path) && "Path is longer than PATH_MAX?");
assert(Ret > 0 && "GetModuleFileName failed");
(void)Ret;

BOOL RetCode = PathRemoveFileSpec(reinterpret_cast<LPWSTR>(&Path));
assert(RetCode && "PathRemoveFileSpec failed");
(void)RetCode;

return std::filesystem::path(std::wstring(Path));
}

// these are cribbed from include/sycl/detail/pi.hpp
// a new plugin must be added to both places.
#ifdef _MSC_VER
Expand All @@ -121,7 +141,7 @@ std::string getCurrentDSODir() {

// ------------------------------------

using MapT = std::map<std::string, void *>;
using MapT = std::map<std::filesystem::path, void *>;

MapT &getDllMap() {
static MapT dllMap;
Expand All @@ -141,58 +161,58 @@ void preloadLibraries() {
//
UINT SavedMode = SetErrorMode(SEM_FAILCRITICALERRORS);
// Exclude current directory from DLL search path
if (!SetDllDirectoryA("")) {
if (!SetDllDirectory(L"")) {
assert(false && "Failed to update DLL search path");
}

// this path duplicates sycl/detail/pi.cpp:initializePlugins
const std::string LibSYCLDir = getCurrentDSODir() + DirSep;
std::filesystem::path LibSYCLDir = getCurrentDSODirPath();

MapT &dllMap = getDllMap();

auto ocl_path = LibSYCLDir / __SYCL_OPENCL_PLUGIN_NAME;
dllMap.emplace(ocl_path, LoadLibrary(ocl_path.wstring().c_str()));

std::string ocl_path = LibSYCLDir + __SYCL_OPENCL_PLUGIN_NAME;
dllMap.emplace(ocl_path, LoadLibraryA(ocl_path.c_str()));

std::string l0_path = LibSYCLDir + __SYCL_LEVEL_ZERO_PLUGIN_NAME;
dllMap.emplace(l0_path, LoadLibraryA(l0_path.c_str()));
auto l0_path = LibSYCLDir / __SYCL_LEVEL_ZERO_PLUGIN_NAME;
dllMap.emplace(l0_path, LoadLibrary(l0_path.wstring().c_str()));

std::string cuda_path = LibSYCLDir + __SYCL_CUDA_PLUGIN_NAME;
dllMap.emplace(cuda_path, LoadLibraryA(cuda_path.c_str()));
auto cuda_path = LibSYCLDir / __SYCL_CUDA_PLUGIN_NAME;
dllMap.emplace(cuda_path, LoadLibrary(cuda_path.wstring().c_str()));

std::string esimd_path = LibSYCLDir + __SYCL_ESIMD_EMULATOR_PLUGIN_NAME;
dllMap.emplace(esimd_path, LoadLibraryA(esimd_path.c_str()));
auto esimd_path = LibSYCLDir / __SYCL_ESIMD_EMULATOR_PLUGIN_NAME;
dllMap.emplace(esimd_path, LoadLibrary(esimd_path.wstring().c_str()));

std::string hip_path = LibSYCLDir + __SYCL_HIP_PLUGIN_NAME;
dllMap.emplace(hip_path, LoadLibraryA(hip_path.c_str()));
auto hip_path = LibSYCLDir / __SYCL_HIP_PLUGIN_NAME;
dllMap.emplace(hip_path, LoadLibrary(hip_path.wstring().c_str()));

std::string ur_path = LibSYCLDir + __SYCL_UNIFIED_RUNTIME_PLUGIN_NAME;
dllMap.emplace(ur_path, LoadLibraryA(ur_path.c_str()));
auto ur_path = LibSYCLDir / __SYCL_UNIFIED_RUNTIME_PLUGIN_NAME;
dllMap.emplace(ur_path, LoadLibrary(ur_path.wstring().c_str()));

std::string nativecpu_path = LibSYCLDir + __SYCL_NATIVE_CPU_PLUGIN_NAME;
dllMap.emplace(nativecpu_path, LoadLibraryA(nativecpu_path.c_str()));
auto nativecpu_path = LibSYCLDir / __SYCL_NATIVE_CPU_PLUGIN_NAME;
dllMap.emplace(nativecpu_path, LoadLibrary(nativecpu_path.wstring().c_str()));

// Restore system error handling.
(void)SetErrorMode(SavedMode);
if (!SetDllDirectoryA(nullptr)) {
if (!SetDllDirectory(nullptr)) {
assert(false && "Failed to restore DLL search path");
}
}

/// windows_pi.cpp:loadOsPluginLibrary() calls this to get the DLL loaded
/// earlier.
__declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath) {
__declspec(dllexport) void *getPreloadedPlugin(const std::filesystem::path &PluginPath) {

MapT &dllMap = getDllMap();

auto match = dllMap.find(PluginPath); // result might be nullptr (not found),
// which is perfectly valid.
if (match == dllMap.end()) {
// unit testing? return nullptr (not found) rather than risk asserting below
if (PluginPath.find("unittests") != std::string::npos)
if (PluginPath.string().find("unittests") != std::string::npos)
return nullptr;

// Otherwise, asking for something we don't know about at all, is an issue.
std::cout << "unknown plugin: " << PluginPath << std::endl;
std::cout << "unknown plugin: " << PluginPath.string() << std::endl;
assert(false && "getPreloadedPlugin was given an unknown plugin path.");
return nullptr;
}
Expand Down
3 changes: 2 additions & 1 deletion sycl/pi_win_proxy_loader/pi_win_proxy_loader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#ifdef _WIN32
#include <string>
#include <filesystem>

__declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath);
__declspec(dllexport) void *getPreloadedPlugin(const std::filesystem::path &PluginPath);
#endif
30 changes: 30 additions & 0 deletions sycl/source/detail/os_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <cassert>
#include <limits>
#include <filesystem>

#if defined(__SYCL_RT_OS_LINUX)

Expand Down Expand Up @@ -138,6 +139,11 @@ std::string OSUtil::getDirName(const char *Path) {
return Tmp;
}

std::filesystem::path OSUtil::getCurrentDSODirPath() {
return std::filesystem::path(OSUtil::getCurrentDSODir());
}


#elif defined(__SYCL_RT_OS_WINDOWS)
// TODO: Just inline it.
using OSModuleHandle = intptr_t;
Expand Down Expand Up @@ -192,6 +198,26 @@ std::string OSUtil::getDirName(const char *Path) {
return Tmp;
}

std::filesystem::path OSUtil::getCurrentDSODirPath() {
wchar_t Path[MAX_PATH];
//Path[0] = '\0';
//Path[sizeof(Path) - 1] = '\0';
auto Handle = getOSModuleHandle(reinterpret_cast<void *>(&getCurrentDSODir));
DWORD Ret = GetModuleFileName(
reinterpret_cast<HMODULE>(ExeModuleHandle == Handle ? 0 : Handle),
reinterpret_cast<LPWSTR>(&Path), sizeof(Path));
assert(Ret < sizeof(Path) && "Path is longer than PATH_MAX?");
assert(Ret > 0 && "GetModuleFileName failed");
(void)Ret;

BOOL RetCode = PathRemoveFileSpec(reinterpret_cast<LPWSTR>(&Path));
assert(RetCode && "PathRemoveFileSpec failed");
(void)RetCode;

return std::filesystem::path(Path);
}


#elif defined(__SYCL_RT_OS_DARWIN)
std::string OSUtil::getCurrentDSODir() {
auto CurrentFunc = reinterpret_cast<const void *>(&getCurrentDSODir);
Expand All @@ -208,6 +234,10 @@ std::string OSUtil::getCurrentDSODir() {
return Path.substr(0, LastSlashPos);
}

std::filesystem::path OSUtil::getCurrentDSODirPath() {
return std::filesystem::path(OSUtil::getCurrentDSODir());
}

#endif // __SYCL_RT_OS

size_t OSUtil::getOSMemSize() {
Expand Down
8 changes: 4 additions & 4 deletions sycl/source/detail/pi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <sycl/detail/pi.hpp>
#include <sycl/detail/stl_type_traits.hpp>
#include <sycl/version.hpp>
#include <filesystem>

#include <bitset>
#include <cstdarg>
Expand Down Expand Up @@ -385,7 +386,7 @@ std::vector<std::pair<std::string, backend>> findPlugins() {

// Load the Plugin by calling the OS dependent library loading call.
// Return the handle to the Library.
void *loadPlugin(const std::string &PluginPath) {
void *loadPlugin(const std::filesystem::path &PluginPath) {
return loadOsPluginLibrary(PluginPath);
}

Expand Down Expand Up @@ -442,15 +443,14 @@ static void initializePlugins(std::vector<PluginPtr> &Plugins) {
std::cerr << "SYCL_PI_TRACE[all]: "
<< "No Plugins Found." << std::endl;

const std::string LibSYCLDir =
sycl::detail::OSUtil::getCurrentDSODir() + sycl::detail::OSUtil::DirSep;
std::filesystem::path LibSYCLDir = sycl::detail::OSUtil::getCurrentDSODirPath();

for (unsigned int I = 0; I < PluginNames.size(); I++) {
std::shared_ptr<PiPlugin> PluginInformation = std::make_shared<PiPlugin>(
PiPlugin{_PI_H_VERSION_STRING, _PI_H_VERSION_STRING,
/*Targets=*/nullptr, /*FunctionPointers=*/{}});

void *Library = loadPlugin(LibSYCLDir + PluginNames[I].first);
void *Library = loadPlugin(LibSYCLDir / PluginNames[I].first); // loadPlugin(path)

if (!Library) {
if (trace(PI_TRACE_ALL)) {
Expand Down
9 changes: 5 additions & 4 deletions sycl/source/detail/posix_pi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,26 @@

#include <dlfcn.h>
#include <string>
#include <filesystem>

namespace sycl {
inline namespace _V1 {
namespace detail::pi {

void *loadOsLibrary(const std::string &LibraryPath) {
void *loadOsLibrary(const std::filesystem::path &LibraryPath) {
// TODO: Check if the option RTLD_NOW is correct. Explore using
// RTLD_DEEPBIND option when there are multiple plugins.
void *so = dlopen(LibraryPath.c_str(), RTLD_NOW);
void *so = dlopen(LibraryPath.string().c_str(), RTLD_NOW);
if (!so && trace(TraceLevel::PI_TRACE_ALL)) {
char *Error = dlerror();
std::cerr << "SYCL_PI_TRACE[-1]: dlopen(" << LibraryPath
std::cerr << "SYCL_PI_TRACE[-1]: dlopen(" << LibraryPath.string()
<< ") failed with <" << (Error ? Error : "unknown error") << ">"
<< std::endl;
}
return so;
}

void *loadOsPluginLibrary(const std::string &PluginPath) {
void *loadOsPluginLibrary(const std::filesystem::path &PluginPath) {
return loadOsLibrary(PluginPath);
}

Expand Down
11 changes: 6 additions & 5 deletions sycl/source/detail/windows_pi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <string>
#include <windows.h>
#include <winreg.h>
#include <filesystem>

#include "pi_win_proxy_loader.hpp"

Expand All @@ -20,27 +21,27 @@ inline namespace _V1 {
namespace detail {
namespace pi {

void *loadOsLibrary(const std::string &LibraryPath) {
void *loadOsLibrary(const std::filesystem::path &LibraryPath) {
// Tells the system to not display the critical-error-handler message box.
// Instead, the system sends the error to the calling process.
// This is crucial for graceful handling of shared libs that can't be
// loaded, e.g. due to missing native run-times.

UINT SavedMode = SetErrorMode(SEM_FAILCRITICALERRORS);
// Exclude current directory from DLL search path
if (!SetDllDirectoryA("")) {
if (!SetDllDirectory(L"")) {
assert(false && "Failed to update DLL search path");
}
auto Result = (void *)LoadLibraryA(LibraryPath.c_str());
auto Result = (void *)LoadLibrary(LibraryPath.wstring().c_str());
(void)SetErrorMode(SavedMode);
if (!SetDllDirectoryA(nullptr)) {
if (!SetDllDirectory(nullptr)) {
assert(false && "Failed to restore DLL search path");
}

return Result;
}

void *loadOsPluginLibrary(const std::string &PluginPath) {
void *loadOsPluginLibrary(const std::filesystem::path &PluginPath) {
// We fetch the preloaded plugin from the pi_win_proxy_loader.
// The proxy_loader handles any required error suppression.
auto Result = getPreloadedPlugin(PluginPath);
Expand Down

0 comments on commit 3c334c2

Please sign in to comment.