Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 28 additions & 23 deletions paddle/phi/kernels/funcs/correlation_funcs.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,37 +67,42 @@ __forceinline__ __device__ T blockReduceSum(T val) {
}

// Writes zero to the first `num` elements of `x`.
//
// Every thread starts at its flat 1-D global index and advances by the total
// number of launched threads (blockDim.x * gridDim.x), so any 1-D launch
// configuration covers the whole range.
//
// x:   destination buffer; elements [0, num) are written.
// num: number of elements to clear; a value <= 0 makes the kernel a no-op.
template <typename T>
__global__ void set_zero(T *x, int64_t num) {
  // blockIdx / blockDim / gridDim are 32-bit unsigned. Widen before
  // multiplying so neither the start index nor the stride can wrap.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < num;
       i += stride) {
    x[i] = static_cast<T>(0);
  }
}

// Copies `input` into the interior of the padded buffer `rinput`, moving the
// channel index from the slowest position within a sample to the fastest.
//
// Element indexing used by the kernel:
//   input : ((n * channel + c) * H + h) * W + w
//   rinput: ((n * (H + 2 * pad_size) + h + pad_size) * (W + 2 * pad_size)
//            + w + pad_size) * channel + c
// Only the interior of `rinput` is written; the `pad_size` border is left
// untouched by this kernel.
//
// Launch layout: 1-D grid, 1-D block. Each block visits (n, h, w) positions
// starting at blockIdx.x in steps of gridDim.x, and the threads of a block
// split the channels in steps of blockDim.x, so any 1-D configuration is
// valid.
//
// input:    source buffer, read at the `input` index above.
// rinput:   destination buffer, written at the `rinput` index above.
// N:        number of samples.
// channel:  number of channels per sample.
// H, W:     spatial extents of `input`.
// pad_size: border width added on each spatial side of `rinput`.
template <typename T>
__global__ void channel_first(const T *input,
                              T *rinput,
                              const int64_t N,
                              const int64_t channel,
                              const int64_t H,
                              const int64_t W,
                              const int pad_size) {
  // Loop-invariant extents, computed once.
  const int64_t hw = H * W;
  const int64_t chw = channel * hw;
  const int64_t total = N * hw;

  // Keep the padded extents in 64 bits; storing them in `int` would narrow
  // H / W again and undo the point of the 64-bit parameters.
  const int64_t pad = static_cast<int64_t>(pad_size);
  const int64_t p_H = H + 2 * pad;
  const int64_t p_W = W + 2 * pad;
  const int64_t p_dimcw = channel * p_W;
  const int64_t p_dimchw = p_dimcw * p_H;

  const int64_t block_stride = static_cast<int64_t>(gridDim.x);
  for (int64_t pos = static_cast<int64_t>(blockIdx.x); pos < total;
       pos += block_stride) {
    // Decompose the flat position into (n, h, w). `total` is 0 whenever
    // H * W is 0, so the divisions below are never reached with a zero
    // divisor.
    const int64_t n = pos / hw;
    const int64_t rem = pos % hw;
    const int64_t h = rem / W;
    const int64_t w = rem % W;

    const T *src = input + n * chw + h * W + w;
    T *dst = rinput + n * p_dimchw + (h + pad) * p_dimcw + (w + pad) * channel;

    for (int64_t c = threadIdx.x; c < channel; c += blockDim.x) {
      dst[c] = src[c * hw];
    }
  }
}

Expand Down
Loading