Skip to content

Commit 42124bb

Browse files
pradeepmlloreda
authored andcommitted
Fix c32, c64 multiplication in convolution kernels
1 parent 20403a6 commit 42124bb

File tree

5 files changed

+76
-26
lines changed

5 files changed

+76
-26
lines changed

src/backend/opencl/kernel/convolve.cl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ void convolve(global T *out, KParam oInfo, global T const *signal, KParam sInfo,
5454
int lx = get_local_id(0) + padding + (EXPAND ? 0 : fLen>>1);
5555
accType accum = (accType)(0);
5656
for(int f=0; f<fLen; ++f) {
57-
accum = accum + ((accType)localMem[lx-f] * (accType)impulse[f]);
57+
//binOp will do MUL_OP for convolution operation
58+
accum = accum + binOp((accType)localMem[lx-f], (accType)impulse[f]);
5859
}
5960
dst[gx] = (T)accum;
6061
}
@@ -122,7 +123,9 @@ void convolve(global T *out, KParam oInfo, global T const *signal, KParam sInfo,
122123
for(int fi=0; fi<FLEN0; ++fi) {
123124
accType f_val = impulse[fj*FLEN0+fi];
124125
T s_val = localMem[(cj-fj)*shrdLen0+(ci-fi)];
125-
accum = accum + ((accType)s_val*(accType)f_val);
126+
127+
//binOp will do MUL_OP for convolution operation
128+
accum = accum + binOp((accType)s_val, (accType)f_val);
126129
}
127130
}
128131
dst[gy*oInfo.strides[1]+gx] = (T)accum;
@@ -202,7 +205,9 @@ void convolve(global T *out, KParam oInfo, global T const *signal, KParam sInfo,
202205
for(int fi=0; fi<fLen0; ++fi) {
203206
accType f_val = impulse[index(fi, fj, fk, fLen0, fStride)];
204207
T s_val = localMem[index(ci-fi, cj-fj, ck-fk, shrdLen0, skStride)];
205-
accum = accum + ((accType)s_val*(accType)f_val);
208+
209+
//binOp will do MUL_OP for convolution operation
210+
accum = accum + binOp((accType)s_val, (accType)f_val);
206211
}
207212
}
208213
}

src/backend/opencl/kernel/convolve/conv2_impl.hpp

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,31 @@ void conv2Helper(const conv_kparam_t& param, Param out, const Param signal, cons
4242
size_t LOC_SIZE = (THREADS_X+2*(f0-1))*(THREADS_Y+2*(f1-1));
4343

4444
std::ostringstream options;
45-
options << " -D T=" << dtype_traits<T>::getName()
46-
<< " -D accType="<< dtype_traits<aT>::getName()
47-
<< " -D BASE_DIM="<< 2 /* hard constant specific to this convolution type */
48-
<< " -D FLEN0=" << f0
49-
<< " -D FLEN1=" << f1
50-
<< " -D EXPAND="<< expand
51-
<< " -D C_SIZE="<< LOC_SIZE;
52-
if (std::is_same<T, double>::value ||
53-
std::is_same<T, cdouble>::value) {
54-
options << " -D USE_DOUBLE";
45+
options << " -D T=" << dtype_traits<T>::getName()
46+
<< " -D Ti=" << dtype_traits<T>::getName()
47+
<< " -D To=" << dtype_traits<aT>::getName()
48+
<< " -D accType=" << dtype_traits<aT>::getName()
49+
<< " -D BASE_DIM=" << 2 /* hard constant specific to this convolution type */
50+
<< " -D FLEN0=" << f0
51+
<< " -D FLEN1=" << f1
52+
<< " -D EXPAND=" << expand
53+
<< " -D C_SIZE=" << LOC_SIZE
54+
<< " -D " << binOpName<af_mul_t>();
55+
56+
if((af_dtype) dtype_traits<T>::af_type == c32 ||
57+
(af_dtype) dtype_traits<T>::af_type == c64) {
58+
options << " -D CPLX=1";
59+
} else {
60+
options << " -D CPLX=0";
5561
}
62+
if (std::is_same<T, double>::value || std::is_same<T, cdouble>::value)
63+
options << " -D USE_DOUBLE";
64+
65+
const char *ker_strs[] = {ops_cl, convolve_cl};
66+
const int ker_lens[] = {ops_cl_len, convolve_cl_len};
5667
Program prog;
57-
buildProgram(prog, convolve_cl, convolve_cl_len, options.str());
68+
buildProgram(prog, 2, ker_strs, ker_lens, options.str());
69+
5870
entry.prog = new Program(prog);
5971
entry.ker = new Kernel(*entry.prog, "convolve");
6072

src/backend/opencl/kernel/convolve/conv_common.hpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#pragma once
1111
#include <af/defines.h>
1212

13+
#include <kernel_headers/ops.hpp>
1314
#include <kernel_headers/convolve.hpp>
1415

1516
#include <string>
@@ -22,6 +23,7 @@
2223
#include <platform.hpp>
2324
#include <debug_opencl.hpp>
2425
#include <cache.hpp>
26+
#include <kernel/names.hpp>
2527

2628
using cl::Buffer;
2729
using cl::Program;
@@ -105,13 +107,27 @@ void convNHelper(const conv_kparam_t& param, Param& out, const Param& signal, co
105107
if (entry.prog==0 && entry.ker==0) {
106108
std::ostringstream options;
107109
options << " -D T=" << dtype_traits<T>::getName()
110+
<< " -D Ti=" << dtype_traits<T>::getName()
111+
<< " -D To=" << dtype_traits<aT>::getName()
108112
<< " -D accType=" << dtype_traits<aT>::getName()
109113
<< " -D BASE_DIM=" << bDim
110-
<< " -D EXPAND=" << expand;
114+
<< " -D EXPAND=" << expand
115+
<< " -D " << binOpName<af_mul_t>();
116+
117+
if((af_dtype) dtype_traits<T>::af_type == c32 ||
118+
(af_dtype) dtype_traits<T>::af_type == c64) {
119+
options << " -D CPLX=1";
120+
} else {
121+
options << " -D CPLX=0";
122+
}
111123
if (std::is_same<T, double>::value || std::is_same<T, cdouble>::value)
112124
options << " -D USE_DOUBLE";
125+
126+
const char *ker_strs[] = {ops_cl, convolve_cl};
127+
const int ker_lens[] = {ops_cl_len, convolve_cl_len};
113128
Program prog;
114-
buildProgram(prog, convolve_cl, convolve_cl_len, options.str());
129+
buildProgram(prog, 2, ker_strs, ker_lens, options.str());
130+
115131
entry.prog = new Program(prog);
116132
entry.ker = new Kernel(*entry.prog, "convolve");
117133

src/backend/opencl/kernel/convolve_separable.cl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ void convolve(global T *out, KParam oInfo, global T const *signal,
7171
// below conditional statement is based on MACRO value passed while kernel compilation
7272
int s_idx = (CONV_DIM==0 ? (ly*shrdLen+(i-f)) : ((i-f)*shrdLen+lx));
7373
T s_val = localMem[s_idx];
74-
accum = accum + ((accType)s_val*(accType)f_val);
74+
75+
//binOp will do MUL_OP for convolution operation
76+
accum = accum + binOp((accType)s_val, (accType)f_val);
7577
}
7678
dst[oy*oInfo.strides[1]+ox] = (T)accum;
7779
}

src/backend/opencl/kernel/convolve_separable.cpp

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
* http://arrayfire.com/licenses/BSD-3-Clause
88
********************************************************/
99

10+
#include <kernel_headers/ops.hpp>
1011
#include <kernel_headers/convolve_separable.hpp>
12+
1113
#include <program.hpp>
1214
#include <traits.hpp>
1315
#include <string>
@@ -18,6 +20,7 @@
1820
#include <debug_opencl.hpp>
1921
#include <memory.hpp>
2022
#include <cache.hpp>
23+
#include <kernel/names.hpp>
2124

2225
using cl::Buffer;
2326
using cl::Program;
@@ -64,18 +67,30 @@ void convSep(Param out, const Param signal, const Param filter)
6467
size_t locSize = (conv_dim==0 ? C0_SIZE : C1_SIZE);
6568

6669
std::ostringstream options;
67-
options << " -D T=" << dtype_traits<T>::getName()
68-
<< " -D accType="<< dtype_traits<accType>::getName()
69-
<< " -D CONV_DIM="<< conv_dim
70-
<< " -D EXPAND="<< expand
71-
<< " -D FLEN="<< fLen
72-
<< " -D LOCAL_MEM_SIZE="<<locSize;
73-
if (std::is_same<T, double>::value ||
74-
std::is_same<T, cdouble>::value) {
70+
options << " -D T=" << dtype_traits<T>::getName()
71+
<< " -D Ti=" << dtype_traits<T>::getName()
72+
<< " -D To=" << dtype_traits<accType>::getName()
73+
<< " -D accType=" << dtype_traits<accType>::getName()
74+
<< " -D CONV_DIM=" << conv_dim
75+
<< " -D EXPAND=" << expand
76+
<< " -D FLEN=" << fLen
77+
<< " -D LOCAL_MEM_SIZE="<<locSize
78+
<< " -D " << binOpName<af_mul_t>();
79+
80+
if((af_dtype) dtype_traits<T>::af_type == c32 ||
81+
(af_dtype) dtype_traits<T>::af_type == c64) {
82+
options << " -D CPLX=1";
83+
} else {
84+
options << " -D CPLX=0";
85+
}
86+
if (std::is_same<T, double>::value || std::is_same<T, cdouble>::value) {
7587
options << " -D USE_DOUBLE";
7688
}
89+
90+
const char *ker_strs[] = {ops_cl, convolve_separable_cl};
91+
const int ker_lens[] = {ops_cl_len, convolve_separable_cl_len};
7792
Program prog;
78-
buildProgram(prog, convolve_separable_cl, convolve_separable_cl_len, options.str());
93+
buildProgram(prog, 2, ker_strs, ker_lens, options.str());
7994

8095
entry.prog = new Program(prog);
8196
entry.ker = new Kernel(*entry.prog, "convolve");

0 commit comments

Comments
 (0)