@@ -82,7 +82,7 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
8282 int64_t n, c_out, out_y, out_x;
8383 Layout::unpack_indices (global_idx, P, n, c_out, out_y, out_x);
8484
85- T acc = 0 ;
85+ float acc = 0 . 0f ;
8686
8787 for (int64_t c_in = 0 ; c_in < P.IC ; ++c_in) {
8888 kernel_bounds bounds = calculate_kernel_bounds (out_x, out_y, P);
@@ -93,21 +93,15 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
9393 for (int64_t kx = bounds.x_min ; kx < bounds.x_max ; ++kx) {
9494 const int64_t in_x = calculate_input_coord (out_x, kx, P.ST_X , P.DL_X , P.PD_X );
9595
96- T input_val;
97- if (std::is_same<T, half>::value) {
98- input_val = __float2half (input[Layout::input_index (n, c_in, in_y, in_x, P)]);
99- } else {
100- input_val = input[Layout::input_index (n, c_in, in_y, in_x, P)];
101- }
102-
103- T kernel_val = kernel[Layout::kernel_index (c_out, c_in, ky, kx, P)];
96+ const float input_val = input[Layout::input_index (n, c_in, in_y, in_x, P)];
97+ const float kernel_val = kernel[Layout::kernel_index (c_out, c_in, ky, kx, P)];
10498 acc += (input_val * kernel_val);
10599 }
106100 }
107101 }
108102
109103 // [N, OC, OH, OW]
110- output[Layout::output_index (n, c_out, out_y, out_x, P)] = ( float ) acc;
104+ output[Layout::output_index (n, c_out, out_y, out_x, P)] = acc;
111105}
112106
113107template <typename T>
0 commit comments