Skip to content

Commit

Permalink
Fixing device selector. (codeplaysoftware#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
mehdi-goli authored and Adam Harries committed Aug 28, 2018
1 parent 7e43b67 commit 4a47a04
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 26 deletions.
23 changes: 16 additions & 7 deletions include/executors/executor_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ namespace blas {
template <class ExecutionPolicy>
class Executor {
public:
using Queue_Interface_Type = Queue_Interface<ExecutionPolicy>;
using Return_Type = void;
template <typename Tree>
void execute(Tree t) = delete;
template <typename T>
inline T* allocate(size_t num_bytes);
template <typename T>
inline void* deallocate(T* p);
inline Queue_Interface<ExecutionPolicy> get_policy_handler();
inline Queue_Interface_Type get_policy_handler();
template <typename first_event_t, typename... next_event_t>
void wait(first_event_t first_event, next_event_t... next_events);
void wait();
Expand All @@ -58,8 +60,12 @@ class Executor {
*/
template <>
class Executor<Sequential> {
public:
using Queue_Interface_Type = Queue_Interface<Sequential>;
using Return_Type = void;

private:
Queue_Interface<Sequential> q_interface;
Queue_Interface_Type q_interface;

public:
template <typename Tree>
Expand All @@ -70,9 +76,7 @@ class Executor<Sequential> {
}
};

inline Queue_Interface<Sequential> get_policy_handler() {
return q_interface;
}
inline Queue_Interface_Type get_policy_handler() { return q_interface; }
template <typename first_event_t, typename... next_event_t>
void wait(first_event_t, next_event_t...) {}
void wait() {}
Expand All @@ -84,7 +88,12 @@ class Executor<Sequential> {
*/
template <>
class Executor<Parallel> {
Queue_Interface<Parallel> q_interface;
public:
using Queue_Interface_Type = Queue_Interface<Parallel>;
using Return_Type = void;

private:
Queue_Interface_Type q_interface;

public:
template <typename Tree>
Expand All @@ -95,7 +104,7 @@ class Executor<Parallel> {
t.eval(i);
}
};
inline Queue_Interface<Parallel> get_policy_handler() { return q_interface; }
inline Queue_Interface_Type get_policy_handler() { return q_interface; }
template <typename first_event_t, typename... next_event_t>
void wait(first_event_t, next_event_t...) {}
void wait() {}
Expand Down
12 changes: 8 additions & 4 deletions include/executors/executor_sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,14 @@ static cl::sycl::event execute_tree(cl::sycl::queue q_, Tree t,
*/
template <>
class Executor<SYCL> {
Queue_Interface<SYCL> q_interface;

public:
using Queue_Interface_Type = Queue_Interface<SYCL>;
using Return_Type = cl::sycl::event;

private:
Queue_Interface_Type q_interface;

public:
template <
typename T,
cl::sycl::access::mode AcM = cl::sycl::access::mode::read_write,
Expand All @@ -260,11 +264,11 @@ class Executor<SYCL> {
*/
Executor(cl::sycl::queue q) : q_interface(q){};

inline Queue_Interface<SYCL> get_policy_handler() { return q_interface; }
inline Queue_Interface_Type get_policy_handler() { return q_interface; }

cl::sycl::queue get_queue() const { return q_interface.get_queue(); }

inline Queue_Interface<SYCL>::device_type get_device_type() {
inline Queue_Interface_Type::device_type get_device_type() {
return q_interface.get_device_type();
}

Expand Down
36 changes: 27 additions & 9 deletions include/interface/blas3_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,14 @@ typename Executor::Return_Type _select_gemm(
auto buffer_a = make_matrix_view(ex, _A, _M, _K, _lda, 0);
auto buffer_b = make_matrix_view(ex, _B, _K, _N, _ldb, 0);
auto buffer_c = make_matrix_view(ex, _C, _M, _N, _ldc, 0);

#ifndef NAIVE_GEMM
#define ENABLE_GEMM_TRANSPOSE(_trans_a, _trans_b) \
if (_TransA == _trans_a && _TransB == _trans_b) { \
if (ex.has_local_memory()) { \
if (ex.has_local_memory() && \
(ex.get_device_type() != \
Executor::Queue_Interface_Type::device_type::SYCL_RCAR_CVENGINE) && \
(ex.get_device_type() != \
Executor::Queue_Interface_Type::device_type::SYCL_RCAR_HOST_CPU)) { \
auto gemm = make_gemm<DoubleBuffer, ConflictA, ConflictB, ClSize, TileT, \
_trans_a, _trans_b>(buffer_a, buffer_b, buffer_c, \
T(_alpha), T(_beta)); \
Expand All @@ -73,7 +77,15 @@ typename Executor::Return_Type _select_gemm(
} \
return ret; \
}

#else
#define ENABLE_GEMM_TRANSPOSE(_trans_a, _trans_b) \
if (_TransA == _trans_a && _TransB == _trans_b) { \
auto gemm = make_gemm_no_local_mem<WgSize, _trans_a, _trans_b>( \
buffer_a, buffer_b, buffer_c, T(_alpha), T(_beta)); \
ret = ex.gemm_executor(gemm); \
return ret; \
}
#endif
const bool NoTrans = false;
const bool Trans = true;

Expand All @@ -92,12 +104,12 @@ typename Executor::Return_Type _select_gemm(
*
* See netlib.org/blas for details.
*/
template <typename ExecutorType, typename ContainerT0, typename ContainerT1,
template <typename Executor, typename ContainerT0, typename ContainerT1,
typename ContainerT2, typename T, typename IndexType>
cl::sycl::event _gemm(Executor<ExecutorType>& ex, char _TransA, char _TransB,
IndexType _M, IndexType _N, IndexType _K, T _alpha,
ContainerT0 _A, IndexType _lda, ContainerT1 _B,
IndexType _ldb, T _beta, ContainerT2 _C, IndexType _ldc) {
cl::sycl::event _gemm(Executor& ex, char _TransA, char _TransB, IndexType _M,
IndexType _N, IndexType _K, T _alpha, ContainerT0 _A,
IndexType _lda, ContainerT1 _B, IndexType _ldb, T _beta,
ContainerT2 _C, IndexType _ldc) {
_TransA = tolower(_TransA);
_TransB = tolower(_TransB);

Expand All @@ -121,10 +133,16 @@ cl::sycl::event _gemm(Executor<ExecutorType>& ex, char _TransA, char _TransB,
_ldc); \
}
#ifndef NAIVE_GEMM
if (ex.get_device_type() == Queue_Interface<SYCL>::device_type::INTELGPU) {
if (ex.get_device_type() ==
Executor::Queue_Interface_Type::device_type::SYCL_INTEL_GPU) {
BIND_DATA_SIZE(1024, 4096, 1024) TO_TPARAMS(128, false, 4, 4, 16, 16);
BIND_DATA_SIZE(10, 1024, 1024) TO_TPARAMS(128, false, 2, 2, 8, 8);
BIND_DEFAULT TO_TPARAMS(128, false, 8, 8, 8, 8);
} else if ((ex.get_device_type() == Executor::Queue_Interface_Type::
device_type::SYCL_RCAR_CVENGINE) &&
(ex.get_device_type() == Executor::Queue_Interface_Type::
device_type::SYCL_RCAR_HOST_CPU)) {
BIND_DEFAULT TO_TPARAMS(32, false, 8, 8, 8, 8);
} else {
BIND_DATA_SIZE(10, 1024, 1024) TO_TPARAMS(128, true, 1, 1, 16, 16);
BIND_DEFAULT TO_TPARAMS(128, false, 8, 8, 16, 16);
Expand Down
30 changes: 24 additions & 6 deletions include/queue/queue_sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,15 @@ class Queue_Interface<SYCL> {
using generic_buffer_data_type = cl::sycl::codeplay::buffer_data_type_t;

public:
enum device_type { UNSUPPORTED_DEVICE, INTELGPU, AMDGPU };
enum device_type {
SYCL_CPU,
SYCL_HOST,
SYCL_UNSUPPORTED_DEVICE,
SYCL_INTEL_GPU,
SYCL_AMD_GPU,
SYCL_RCAR_CVENGINE,
SYCL_RCAR_HOST_CPU
};

explicit Queue_Interface(cl::sycl::queue q)
: q_(q),
Expand All @@ -62,14 +70,24 @@ class Queue_Interface<SYCL> {
auto platform = dev.get_platform();
auto plat_name =
platform.template get_info<cl::sycl::info::platform::name>();
auto device_type =
dev.template get_info<cl::sycl::info::device::device_type>();
std::transform(plat_name.begin(), plat_name.end(), plat_name.begin(),
::tolower);
if (plat_name.find("amd") != std::string::npos && dev.is_gpu()) {
return AMDGPU;
} else if (plat_name.find("intel") != std::string::npos && dev.is_gpu()) {
return INTELGPU;
if (plat_name.find("amd") != std::string::npos &&
device_type == cl::sycl::info::device_type::gpu) {
return SYCL_AMD_GPU;
} else if (plat_name.find("intel") != std::string::npos &&
device_type == cl::sycl::info::device_type::gpu) {
return SYCL_INTEL_GPU;
} else if (plat_name.find("computeaorta") != std::string::npos &&
device_type == cl::sycl::info::device_type::accelerator) {
return SYCL_RCAR_CVENGINE;
} else if (plat_name.find("computeaorta") != std::string::npos &&
device_type == cl::sycl::info::device_type::cpu) {
return SYCL_RCAR_HOST_CPU;
} else {
return UNSUPPORTED_DEVICE;
return SYCL_UNSUPPORTED_DEVICE;
}
throw std::runtime_error("couldn't find device");
}
Expand Down

0 comments on commit 4a47a04

Please sign in to comment.