Skip to content

Commit 9b97c56

Browse files
authored
Merge pull request #19 from listenlink/upstream
intelblas_gemm clean patch
2 parents 74e9a03 + 461b83d commit 9b97c56

File tree

18 files changed

+219389
-57504
lines changed

18 files changed

+219389
-57504
lines changed

CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -47,8 +47,6 @@ foreach(FILE ${LIBISAAC_SRC})
4747
set(LIBISAAC_SRC_STR "${_TMP} ${LIBISAAC_SRC_STR}")
4848
endforeach()
4949

50-
51-
5250
#Include directories
5351
set(INCLUDE_DIRECTORIES_STR)
5452
get_property(INCLUDE_DIRECTORIES_LST DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES)

include/isaac/driver/dispatch.h

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -116,8 +116,11 @@ class dispatch
116116
static cl_int clGetKernelWorkGroupInfo(cl_kernel, cl_device_id, cl_kernel_work_group_info, size_t, void *, size_t *);
117117
static cl_kernel clCreateKernel(cl_program, const char *, cl_int *);
118118
static cl_mem clCreateBuffer(cl_context, cl_mem_flags, size_t, void *, cl_int *);
119+
static cl_mem clCreateImage(cl_context, cl_mem_flags, const cl_image_format *, const cl_image_desc *, void *, cl_int *);
119120
static cl_program clCreateProgramWithSource(cl_context, cl_uint, const char **, const size_t *, cl_int *);
120121
static cl_int clReleaseKernel(cl_kernel);
122+
static cl_int clEnqueueCopyBufferToImage(cl_command_queue, cl_mem, cl_mem, size_t, const size_t *, const size_t *, cl_uint, const cl_event *, cl_event *);
123+
static cl_int clSetEventCallback(cl_event, cl_int, void (CL_CALLBACK * /* pfn_notify */)(cl_event, cl_int, void *), void *);
121124

122125
//CUDA
123126
static CUresult cuCtxDestroy_v2(CUcontext ctx);
@@ -202,8 +205,11 @@ class dispatch
202205
static void* clGetKernelWorkGroupInfo_;
203206
static void* clCreateKernel_;
204207
static void* clCreateBuffer_;
208+
static void* clCreateImage_;
205209
static void* clCreateProgramWithSource_;
206210
static void* clReleaseKernel_;
211+
static void* clEnqueueCopyBufferToImage_;
212+
static void* clSetEventCallback_;
207213

208214
//CUDA
209215
static void* cuCtxDestroy_v2_;

include/isaac/jit/generation/gemm.h

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,39 @@ class cublas_gemm : public external_base
4747
bool init_;
4848
};
4949

50+
class intelblas_gemm : public external_base
51+
{
52+
bool init();
53+
public:
54+
intelblas_gemm(char A_trans, char B_trans);
55+
int is_invalid(expression_tree const &, driver::Device const &) const;
56+
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
57+
void enqueue(driver::CommandQueue & queue, driver::Program const &, std::string const &, runtime::execution_handler const & h);
58+
expression_type type() const;
59+
private:
60+
std::string generate_impl(std::string const & suffix, expression_tree const &, driver::Device const & device, symbolic::symbols_table const &) const;
61+
const char A_trans_;
62+
const char B_trans_;
63+
bool init_;
64+
};
65+
66+
class intelblas_gemm_image : public external_base
67+
{
68+
bool init();
69+
public:
70+
intelblas_gemm_image(char A_trans, char B_trans);
71+
int is_invalid(expression_tree const &, driver::Device const &) const;
72+
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
73+
void enqueue(driver::CommandQueue & queue, driver::Program const &, std::string const &, runtime::execution_handler const & h);
74+
expression_type type() const;
75+
private:
76+
std::string generate_impl(std::string const & suffix, expression_tree const &, driver::Device const & device, symbolic::symbols_table const &) const;
77+
const char A_trans_;
78+
const char B_trans_;
79+
bool init_;
80+
};
81+
82+
5083
class gemm : public parameterized_base
5184
{
5285
private:

lib/api/blas/clBLAS.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -278,8 +278,8 @@ extern "C"
278278
}\
279279
sc::int_t As1 = (sc::int_t)M, As2 = (sc::int_t)K;\
280280
sc::int_t Bs1 = (sc::int_t)K, Bs2 = (sc::int_t)N;\
281-
if(transA==clblasTrans) std::swap(As1, As2);\
282-
if(transB==clblasTrans) std::swap(Bs1, Bs2);\
281+
if(transA==clblasTrans || transA==clblasConjTrans) std::swap(As1, As2);\
282+
if(transB==clblasTrans || transB==clblasConjTrans) std::swap(Bs1, Bs2);\
283283
/*Struct*/\
284284
sc::array A(As1, As2, TYPE_ISAAC, sc::driver::Buffer(mA, false), (sc::int_t)offA, (sc::int_t)lda);\
285285
sc::array B(Bs1, Bs2, TYPE_ISAAC, sc::driver::Buffer(mB, false), (sc::int_t)offB, (sc::int_t)ldb);\

lib/driver/device.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -301,6 +301,9 @@ std::string Device::infos() const
301301
return oss.str();
302302
}
303303

304+
Device::handle_type const & Device::handle() const
305+
{ return h_; }
306+
304307
// Properties
305308
#define WRAP_ATTRIBUTE(ret, fname, CUNAME, CLNAME) \
306309
ret Device::fname() const\

lib/driver/dispatch.cpp

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -142,6 +142,9 @@ cl_int dispatch::clBuildProgram(cl_program a, cl_uint b, const cl_device_id * c,
142142
cl_context dispatch::clCreateContext(const cl_context_properties * a, cl_uint b, const cl_device_id * c, void (*d)(const char *, const void *, size_t, void *), void * e, cl_int * f)
143143
{ return f_impl<dispatch::clinit>(opencl_, dispatch::clCreateContext, dispatch::clCreateContext_, "clCreateContext", a, b, c, d, e, f); }
144144

145+
cl_int dispatch::clSetEventCallback(cl_event event, cl_int a, void(CL_CALLBACK *pfn_notify)(cl_event, cl_int, void *), void * arg)
146+
{ return f_impl<dispatch::clinit>(opencl_, dispatch::clSetEventCallback, dispatch::clSetEventCallback_, "clSetEventCallback", event, a, pfn_notify, arg); }
147+
145148
OCL_DEFINE9(cl_int, clEnqueueNDRangeKernel, cl_command_queue, cl_kernel, cl_uint, const size_t*, const size_t*, const size_t*, cl_uint, const cl_event*, cl_event*)
146149
OCL_DEFINE4(cl_int, clSetKernelArg, cl_kernel, cl_uint, size_t, const void *)
147150
OCL_DEFINE1(cl_int, clReleaseMemObject, cl_mem)
@@ -171,8 +174,10 @@ OCL_DEFINE5(cl_int, clGetKernelInfo, cl_kernel, cl_kernel_info, size_t, void *,
171174
OCL_DEFINE6(cl_int, clGetKernelWorkGroupInfo, cl_kernel, cl_device_id, cl_kernel_work_group_info, size_t, void *, size_t *)
172175
OCL_DEFINE3(cl_kernel, clCreateKernel, cl_program, const char *, cl_int *)
173176
OCL_DEFINE5(cl_mem, clCreateBuffer, cl_context, cl_mem_flags, size_t, void *, cl_int *)
177+
OCL_DEFINE6(cl_mem, clCreateImage, cl_context, cl_mem_flags, const cl_image_format *, const cl_image_desc *, void *, cl_int *)
174178
OCL_DEFINE5(cl_program, clCreateProgramWithSource, cl_context, cl_uint, const char **, const size_t *, cl_int *)
175179
OCL_DEFINE1(cl_int, clReleaseKernel, cl_kernel)
180+
OCL_DEFINE9(cl_int, clEnqueueCopyBufferToImage, cl_command_queue, cl_mem, cl_mem, size_t, const size_t *, const size_t *, cl_uint, const cl_event *, cl_event *)
176181

177182
//CUDA
178183
CUDA_DEFINE1(CUresult, cuCtxDestroy_v2, CUcontext)
@@ -291,8 +296,11 @@ void* dispatch::clGetKernelInfo_;
291296
void* dispatch::clGetKernelWorkGroupInfo_;
292297
void* dispatch::clCreateKernel_;
293298
void* dispatch::clCreateBuffer_;
299+
void* dispatch::clCreateImage_;
294300
void* dispatch::clCreateProgramWithSource_;
295301
void* dispatch::clReleaseKernel_;
302+
void* dispatch::clEnqueueCopyBufferToImage_;
303+
void* dispatch::clSetEventCallback_;
296304

297305
//CUDA
298306
void* dispatch::cuCtxDestroy_v2_;

lib/driver/kernel.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -127,6 +127,9 @@ void Kernel::setSizeArg(unsigned int index, size_t N)
127127
}
128128
}
129129

130+
Kernel::handle_type const & Kernel::handle() const
131+
{ return h_; }
132+
130133
}
131134

132135
}

lib/jit/generation/elementwise_2d.cpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,6 @@
2828
#include "tools/loop.hpp"
2929
#include "tools/vector_types.hpp"
3030

31-
3231
namespace isaac
3332
{
3433
namespace templates

0 commit comments

Comments
 (0)