Skip to content

[ur] Enable better formating for tests #329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,4 @@ BasedOnStyle: LLVM
IndentWidth: 4
InsertBraces: true
ReflowComments: false
ColumnLimit: 0
...
48 changes: 30 additions & 18 deletions examples/collector/collector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,17 @@
#include "ur_api.h"
#include "xpti/xpti_trace_framework.h"

constexpr uint16_t TRACE_FN_BEGIN = static_cast<uint16_t>(xpti::trace_point_type_t::function_with_args_begin);
constexpr uint16_t TRACE_FN_END = static_cast<uint16_t>(xpti::trace_point_type_t::function_with_args_end);
constexpr uint16_t TRACE_FN_BEGIN =
static_cast<uint16_t>(xpti::trace_point_type_t::function_with_args_begin);
constexpr uint16_t TRACE_FN_END =
static_cast<uint16_t>(xpti::trace_point_type_t::function_with_args_end);
constexpr std::string_view UR_STREAM_NAME = "ur";

/**
* @brief Formats the function parameters and arguments for urInit
*/
std::ostream &operator<<(std::ostream &os, const struct ur_init_params_t *params) {
std::ostream &operator<<(std::ostream &os,
const struct ur_init_params_t *params) {
os << ".device_flags = ";
if (*params->pdevice_flags & UR_DEVICE_INIT_FLAG_GPU) {
os << "UR_DEVICE_INIT_FLAG_GPU";
Expand All @@ -48,14 +51,15 @@ std::ostream &operator<<(std::ostream &os, const struct ur_init_params_t *params
* This example only implements a handler for one function, `urInit`, but it's
* trivial to expand it to support more.
*/
static std::unordered_map<std::string_view,
std::function<void(const xpti::function_with_args_t *, std::ostream &)>>
handlers =
{
{"urInit", [](const xpti::function_with_args_t *fn_args, std::ostream &os) {
auto params = static_cast<const struct ur_init_params_t *>(fn_args->args_data);
os << params;
}}};
static std::unordered_map<
std::string_view,
std::function<void(const xpti::function_with_args_t *, std::ostream &)>>
handlers = {{"urInit", [](const xpti::function_with_args_t *fn_args,
std::ostream &os) {
auto params = static_cast<const struct ur_init_params_t *>(
fn_args->args_data);
os << params;
}}};

/**
* @brief Tracing callback invoked by the dispatcher on every event.
Expand All @@ -74,7 +78,8 @@ XPTI_CALLBACK_API void trace_cb(uint16_t trace_type,
auto *args = static_cast<const xpti::function_with_args_t *>(user_data);
std::ostringstream out;
if (trace_type == TRACE_FN_BEGIN) {
out << "function_with_args_begin(" << instance << ") - " << args->function_name << "(";
out << "function_with_args_begin(" << instance << ") - "
<< args->function_name << "(";
auto it = handlers.find(args->function_name);
if (it == handlers.end()) {
out << "unimplemented";
Expand All @@ -84,7 +89,9 @@ XPTI_CALLBACK_API void trace_cb(uint16_t trace_type,
out << ");";
} else if (trace_type == TRACE_FN_END) {
auto result = static_cast<const ur_result_t *>(args->ret_data);
out << "function_with_args_end(" << instance << ") - " << args->function_name << "(...) -> ur_result_t(" << *result << ");";
out << "function_with_args_end(" << instance << ") - "
<< args->function_name << "(...) -> ur_result_t(" << *result
<< ");";
} else {
out << "unsupported trace type";
}
Expand All @@ -105,12 +112,18 @@ XPTI_CALLBACK_API void xptiTraceInit(unsigned int major_version,
const char *version_str,
const char *stream_name) {
if (!stream_name || std::string_view(stream_name) != UR_STREAM_NAME) {
std::cout << "Invalid stream name: " << stream_name << ". Expected " << UR_STREAM_NAME << ". Aborting." << std::endl;
std::cout << "Invalid stream name: " << stream_name << ". Expected "
<< UR_STREAM_NAME << ". Aborting." << std::endl;
return;
}

if (UR_MAKE_VERSION(major_version, minor_version) != UR_API_VERSION_CURRENT) {
std::cout << "Invalid stream version: " << major_version << "." << minor_version << ". Expected " << UR_MAJOR_VERSION(UR_API_VERSION_CURRENT) << "." << UR_MINOR_VERSION(UR_API_VERSION_CURRENT) << ". Aborting." << std::endl;
if (UR_MAKE_VERSION(major_version, minor_version) !=
UR_API_VERSION_CURRENT) {
std::cout << "Invalid stream version: " << major_version << "."
<< minor_version << ". Expected "
<< UR_MAJOR_VERSION(UR_API_VERSION_CURRENT) << "."
<< UR_MINOR_VERSION(UR_API_VERSION_CURRENT) << ". Aborting."
<< std::endl;
return;
}

Expand All @@ -130,6 +143,5 @@ XPTI_CALLBACK_API void xptiTraceInit(unsigned int major_version,
*
* Can be used to cleanup state or resources.
*/
XPTI_CALLBACK_API void xptiTraceFinish(const char *stream_name) {
/* noop */
XPTI_CALLBACK_API void xptiTraceFinish(const char *stream_name) { /* noop */
}
35 changes: 24 additions & 11 deletions examples/hello_world/hello_world.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,51 +29,64 @@ int main(int argc, char *argv[]) {

status = urPlatformGet(1, nullptr, &platformCount);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urPlatformGet failed with return code: " << status << std::endl;
std::cout << "urPlatformGet failed with return code: " << status
<< std::endl;
goto out;
}

platforms.resize(platformCount);
status = urPlatformGet(platformCount, platforms.data(), nullptr);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urPlatformGet failed with return code: " << status << std::endl;
std::cout << "urPlatformGet failed with return code: " << status
<< std::endl;
goto out;
}

for (auto p : platforms) {
ur_api_version_t api_version = {};
status = urPlatformGetApiVersion(p, &api_version);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urPlatformGetApiVersion failed with return code: " << status << std::endl;
std::cout << "urPlatformGetApiVersion failed with return code: "
<< status << std::endl;
goto out;
}
std::cout << "API version: " << UR_MAJOR_VERSION(api_version) << "." << UR_MINOR_VERSION(api_version) << std::endl;
std::cout << "API version: " << UR_MAJOR_VERSION(api_version) << "."
<< UR_MINOR_VERSION(api_version) << std::endl;

uint32_t deviceCount = 0;
status = urDeviceGet(p, UR_DEVICE_TYPE_GPU, 0, nullptr, &deviceCount);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urDeviceGet failed with return code: " << status << std::endl;
std::cout << "urDeviceGet failed with return code: " << status
<< std::endl;
goto out;
}

std::vector<ur_device_handle_t> devices(deviceCount);
status = urDeviceGet(p, UR_DEVICE_TYPE_GPU, deviceCount, devices.data(), nullptr);
status = urDeviceGet(p, UR_DEVICE_TYPE_GPU, deviceCount, devices.data(),
nullptr);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urDeviceGet failed with return code: " << status << std::endl;
std::cout << "urDeviceGet failed with return code: " << status
<< std::endl;
goto out;
}
for (auto d : devices) {
ur_device_type_t device_type;
status = urDeviceGetInfo(d, UR_DEVICE_INFO_TYPE, sizeof(ur_device_type_t), static_cast<void *>(&device_type), nullptr);
status = urDeviceGetInfo(
d, UR_DEVICE_INFO_TYPE, sizeof(ur_device_type_t),
static_cast<void *>(&device_type), nullptr);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urDeviceGetInfo failed with return code: " << status << std::endl;
std::cout << "urDeviceGetInfo failed with return code: "
<< status << std::endl;
goto out;
}
static const size_t DEVICE_NAME_MAX_LEN = 1024;
char device_name[DEVICE_NAME_MAX_LEN] = {0};
status = urDeviceGetInfo(d, UR_DEVICE_INFO_NAME, DEVICE_NAME_MAX_LEN - 1, static_cast<void *>(&device_name), nullptr);
status =
urDeviceGetInfo(d, UR_DEVICE_INFO_NAME, DEVICE_NAME_MAX_LEN - 1,
static_cast<void *>(&device_name), nullptr);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urDeviceGetInfo failed with return code: " << status << std::endl;
std::cout << "urDeviceGetInfo failed with return code: "
<< status << std::endl;
goto out;
}
if (device_type == UR_DEVICE_TYPE_GPU) {
Expand Down
8 changes: 8 additions & 0 deletions include/.clang-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
Language: Cpp
BasedOnStyle: LLVM
IndentWidth: 4
InsertBraces: true
ReflowComments: false
ColumnLimit: 0
...
6 changes: 1 addition & 5 deletions source/common/logger/ur_level.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@

namespace logger {

enum class Level { DEBUG,
INFO,
WARN,
ERR,
QUIET };
enum class Level { DEBUG, INFO, WARN, ERR, QUIET };

inline constexpr auto level_to_str(Level level) {
switch (level) {
Expand Down
17 changes: 10 additions & 7 deletions source/common/logger/ur_logger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ inline Logger &get_logger(std::string name = "common") {
return logger;
}

inline void init(std::string name) {
get_logger(name);
}
inline void init(std::string name) { get_logger(name); }

template <typename... Args>
inline void debug(const char *format, Args &&...args) {
Expand All @@ -45,7 +43,9 @@ inline void error(const char *format, Args &&...args) {

inline void setLevel(logger::Level level) { get_logger().setLevel(level); }

inline void setFlushLevel(logger::Level level) { get_logger().setFlushLevel(level); }
inline void setFlushLevel(logger::Level level) {
get_logger().setFlushLevel(level);
}

/// @brief Create an instance of the logger with parameters obtained from the respective
/// environment variable or with default configuration if the env var is empty,
Expand Down Expand Up @@ -102,10 +102,13 @@ inline Logger create_logger(std::string logger_name) {
values = kv->second;
}

sink = values.size() == 2 ? sink_from_str(logger_name, values[0], values[1])
: sink_from_str(logger_name, values[0]);
sink = values.size() == 2
? sink_from_str(logger_name, values[0], values[1])
: sink_from_str(logger_name, values[0]);
} catch (const std::invalid_argument &e) {
std::cerr << "Error when creating a logger instance from environment variable" << e.what();
std::cerr
<< "Error when creating a logger instance from environment variable"
<< e.what();
return Logger(std::make_unique<logger::StderrSink>(logger_name));
}
sink->setFlushLevel(flush_level);
Expand Down
9 changes: 3 additions & 6 deletions source/common/logger/ur_logger_details.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,11 @@ class Logger {
this->sink->setFlushLevel(level);
}

template <typename... Args>
void debug(const char *format, Args &&...args) {
template <typename... Args> void debug(const char *format, Args &&...args) {
log(logger::Level::DEBUG, format, std::forward<Args>(args)...);
}

template <typename... Args>
void info(const char *format, Args &&...args) {
template <typename... Args> void info(const char *format, Args &&...args) {
log(logger::Level::INFO, format, std::forward<Args>(args)...);
}

Expand All @@ -58,8 +56,7 @@ class Logger {
log(logger::Level::WARN, format, std::forward<Args>(args)...);
}

template <typename... Args>
void error(const char *format, Args &&...args) {
template <typename... Args> void error(const char *format, Args &&...args) {
log(logger::Level::ERR, format, std::forward<Args>(args)...);
}

Expand Down
46 changes: 31 additions & 15 deletions source/common/logger/ur_sinks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ class Sink {
std::ostream *ostream;
logger::Level flush_level;

Sink(std::string logger_name) : logger_name(logger_name) { flush_level = logger::Level::ERR; }
Sink(std::string logger_name) : logger_name(logger_name) {
flush_level = logger::Level::ERR;
}

private:
std::string logger_name;
Expand All @@ -45,20 +47,21 @@ class Sink {
if (*(++fmt) == '{') {
*ostream << *fmt++;
} else {
throw std::runtime_error("No arguments provided and braces not escaped!");
throw std::runtime_error(
"No arguments provided and braces not escaped!");
}
} else if (*fmt == '}') {
if (*(++fmt) == '}') {
*ostream << *fmt++;
} else {
throw std::runtime_error("Closing curly brace not escaped!");
throw std::runtime_error(
"Closing curly brace not escaped!");
}
}
}
}

template <typename Arg>
void format(const char *fmt, Arg &&arg) {
template <typename Arg> void format(const char *fmt, Arg &&arg) {
while (*fmt != '\0') {
while (*fmt != '{' && *fmt != '}' && *fmt != '\0') {
*ostream << *fmt++;
Expand All @@ -77,7 +80,8 @@ class Sink {
if (*(++fmt) == '}') {
*ostream << *fmt++;
} else {
throw std::runtime_error("Closing curly brace not escaped!");
throw std::runtime_error(
"Closing curly brace not escaped!");
}
}
}
Expand All @@ -104,7 +108,8 @@ class Sink {
if (*(++fmt) == '}') {
*ostream << *fmt++;
} else {
throw std::runtime_error("Closing curly brace not escaped!");
throw std::runtime_error(
"Closing curly brace not escaped!");
}
}
}
Expand All @@ -115,9 +120,12 @@ class Sink {

class StdoutSink : public Sink {
public:
StdoutSink(std::string logger_name) : Sink(logger_name) { this->ostream = &std::cout; }
StdoutSink(std::string logger_name) : Sink(logger_name) {
this->ostream = &std::cout;
}

StdoutSink(std::string logger_name, Level flush_lvl) : StdoutSink(logger_name) {
StdoutSink(std::string logger_name, Level flush_lvl)
: StdoutSink(logger_name) {
this->flush_level = flush_lvl;
}

Expand All @@ -126,9 +134,12 @@ class StdoutSink : public Sink {

class StderrSink : public Sink {
public:
StderrSink(std::string logger_name) : Sink(logger_name) { this->ostream = &std::cerr; }
StderrSink(std::string logger_name) : Sink(logger_name) {
this->ostream = &std::cerr;
}

StderrSink(std::string logger_name, Level flush_lvl) : StderrSink(logger_name) {
StderrSink(std::string logger_name, Level flush_lvl)
: StderrSink(logger_name) {
this->flush_level = flush_lvl;
}

Expand All @@ -137,7 +148,8 @@ class StderrSink : public Sink {

class FileSink : public Sink {
public:
FileSink(std::string logger_name, std::string file_path) : Sink(logger_name) {
FileSink(std::string logger_name, std::string file_path)
: Sink(logger_name) {
ofstream = std::ofstream(file_path, std::ofstream::out);
if (ofstream.rdstate() != std::ofstream::goodbit) {
throw std::invalid_argument(
Expand All @@ -147,21 +159,25 @@ class FileSink : public Sink {
this->ostream = &ofstream;
}

FileSink(std::string logger_name, std::string file_path, Level flush_lvl) : FileSink(logger_name, file_path) {
FileSink(std::string logger_name, std::string file_path, Level flush_lvl)
: FileSink(logger_name, file_path) {
this->flush_level = flush_lvl;
}

private:
std::ofstream ofstream;
};

inline std::unique_ptr<Sink> sink_from_str(std::string logger_name, std::string name, std::string file_path = "") {
inline std::unique_ptr<Sink> sink_from_str(std::string logger_name,
std::string name,
std::string file_path = "") {
if (name == "stdout") {
return std::make_unique<logger::StdoutSink>(logger_name);
} else if (name == "stderr") {
return std::make_unique<logger::StderrSink>(logger_name);
} else if (name == "file" && !file_path.empty()) {
return std::make_unique<logger::FileSink>(logger_name, file_path.c_str());
return std::make_unique<logger::FileSink>(logger_name,
file_path.c_str());
}

throw std::invalid_argument(
Expand Down
Loading