Skip to content

Commit 0d9b256

Browse files
cszdrgzhengshengning
authored and committed
correlation supports big tensor (PaddlePaddle#75383)
* fix * fix test * fix
1 parent ed270cd commit 0d9b256

File tree

4 files changed

+316
-211
lines changed

4 files changed

+316
-211
lines changed

paddle/phi/kernels/funcs/correlation_funcs.cu.h

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -67,37 +67,42 @@ __forceinline__ __device__ T blockReduceSum(T val) {
6767
}
6868

6969
template <typename T>
// Fills the first `num` elements of device array `x` with zero.
// Grid-stride loop: correct for any launch configuration, and `num` is
// int64_t so tensors with more than 2^31 elements are supported.
__global__ void set_zero(T *x, int64_t num) {
  // Widen BEFORE multiplying: blockIdx.x, blockDim.x and gridDim.x are
  // 32-bit unsigned, so their raw products can silently wrap for very
  // large grids even though the loop variable is 64-bit.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < num; i += stride)
    x[i] = static_cast<T>(0);
}
7575

7676
template <typename T>
// Rearranges NCHW `input` into a zero-padded channels-last layout in
// `rinput`: rinput[n][h + pad][w + pad][c] = input[n][c][h][w].
// Launch model: one block per (n, h, w) position, grid-striding over
// positions along gridDim.x; threads within the block cooperatively copy
// the channel dimension. All extents are int64_t to support big tensors.
__global__ void channel_first(const T *input,
                              T *rinput,
                              const int64_t N,
                              const int64_t channel,
                              const int64_t H,
                              const int64_t W,
                              const int pad_size) {
  int64_t global_idx = static_cast<int64_t>(blockIdx.x);
  const int64_t stride = static_cast<int64_t>(gridDim.x);

  // Padded spatial extents. These must stay 64-bit: truncating the
  // int64_t H/W into `int` would silently corrupt every index computed
  // below once a padded dimension exceeds INT_MAX.
  const int64_t p_H = H + 2 * pad_size;
  const int64_t p_W = W + 2 * pad_size;
  const int64_t p_dimcw = channel * p_W;
  const int64_t p_dimchw = channel * p_H * p_W;
  const int64_t total = N * H * W;  // N is already int64_t; no cast needed.

  while (global_idx < total) {
    // Decompose the flat position index into (n, h, w).
    int64_t idx = global_idx;
    const int64_t n = idx / (H * W);
    idx %= (H * W);
    const int64_t h = idx / W;
    const int64_t w = idx % W;

    for (int64_t c = threadIdx.x; c < channel; c += blockDim.x) {
      // NCHW source element -> padded channels-last destination slot.
      rinput[n * p_dimchw + (h + pad_size) * p_dimcw +
             (w + pad_size) * channel + c] =
          input[n * (channel * H * W) + c * (H * W) + h * W + w];
    }

    global_idx += stride;
  }
}
103108

0 commit comments

Comments
 (0)