@@ -67,37 +67,42 @@ __forceinline__ __device__ T blockReduceSum(T val) {
6767}
6868
// Writes static_cast<T>(0) to x[0 .. num-1].
//
// Launch layout: 1-D grid of 1-D blocks (only the .x components are read).
// Each thread starts at its flat global index and advances by the total
// number of launched threads, so any grid/block size covers the full range.
//
// x   : destination buffer, must hold at least `num` elements.
// num : number of elements to clear; nothing is written when num <= 0.
template <typename T>
__global__ void set_zero(T *x, int64_t num) {
  // blockIdx / blockDim / gridDim components are 32-bit unsigned. Widen one
  // operand *before* multiplying so the product is formed in 64 bits instead
  // of wrapping modulo 2^32 on very large launches.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < num;
       i += stride) {
    x[i] = static_cast<T>(0);
  }
}
7575
// Rearranges `input`, read as [N][channel][H][W], into `rinput`, written as
// [N][H + 2*pad_size][W + 2*pad_size][channel], placing element (n, c, h, w)
// at (n, h + pad_size, w + pad_size, c).
//
// Only the interior of `rinput` is written here; the padding border is left
// untouched by this kernel.
//
// Launch layout: only blockIdx.x / gridDim.x and threadIdx.x / blockDim.x are
// read. Blocks step over the N*H*W spatial positions with a stride of
// gridDim.x; within a block, threads step over channels with a stride of
// blockDim.x. Any 1-D grid/block size therefore covers the whole tensor.
//
// input    : source buffer, at least N*channel*H*W elements.
// rinput   : destination buffer, at least
//            N*(H+2*pad_size)*(W+2*pad_size)*channel elements.
// N        : batch size.
// channel  : number of channels.
// H, W     : spatial height and width of `input`.
// pad_size : border width added on each side of H and W in `rinput`.
template <typename T>
__global__ void channel_first(const T *input,
                              T *rinput,
                              const int64_t N,
                              const int64_t channel,
                              const int64_t H,
                              const int64_t W,
                              const int pad_size) {
  // Keep every extent in 64 bits: storing the padded height/width in a plain
  // int would truncate large H/W and corrupt the destination strides below.
  const int64_t pad = static_cast<int64_t>(pad_size);
  const int64_t p_H = H + 2 * pad;
  const int64_t p_W = W + 2 * pad;
  const int64_t p_dimcw = channel * p_W;   // one padded row, all channels
  const int64_t p_dimchw = p_dimcw * p_H;  // one padded image, all channels

  // Loop-invariant source strides and total number of spatial positions.
  const int64_t dimhw = H * W;
  const int64_t dimchw = channel * dimhw;
  const int64_t total = N * dimhw;

  const int64_t stride = static_cast<int64_t>(gridDim.x);
  for (int64_t pos = static_cast<int64_t>(blockIdx.x); pos < total;
       pos += stride) {
    // Decompose the flat position into (n, h, w).
    const int64_t n = pos / dimhw;
    const int64_t rem = pos % dimhw;
    const int64_t h = rem / W;
    const int64_t w = rem % W;

    // Offsets of channel 0 for this position in each layout.
    const int64_t src_base = n * dimchw + h * W + w;
    const int64_t dst_base =
        n * p_dimchw + (h + pad) * p_dimcw + (w + pad) * channel;

    // Consecutive threads write consecutive channels of the destination.
    for (int64_t c = threadIdx.x; c < channel; c += blockDim.x) {
      rinput[dst_base + c] = input[src_base + c * dimhw];
    }
  }
}
103108
0 commit comments