61 changes: 46 additions & 15 deletions oneflow/user/kernels/upsample_nearest_kernel.cu
@@ -105,6 +105,30 @@ __global__ void UpsampleNearest2DBackward(const int64_t elem_cnt, const T* dy_dp
}
}

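// Fast backward path for exact 2x nearest upsample. Each thread produces one dx
// element by summing the 2x2 block of dy values that maps onto it. dy is read
// through Pack2X<T> (a pair of width-adjacent T values, assumed to match the
// packed layout used by UpsampleNearest2D2XForward), so each thread issues two
// vectorized loads instead of four scalar ones.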
template<typename T>
__global__ void UpsampleNearest2D2XBackward(const int32_t in_elem_cnt, const T* dy_dptr,
const int32_t dx_height, const int32_t dx_width,
T* dx_dptr) {
const int32_t dx_hw_size = dx_height * dx_width;
CUDA_1D_KERNEL_LOOP(index, in_elem_cnt) {
T dx_value = 0.0;
const int32_t nc_idx = index / dx_hw_size;
const int32_t dx_hw_off = index - nc_idx * dx_hw_size;
const int32_t dx_h = dx_hw_off / dx_width;
const int32_t dx_w = dx_hw_off - dx_h * dx_width;
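    // Viewed as Pack2X<T>, each dy row has dx_width packs; rows 2*dx_h and
    // 2*dx_h + 1 at pack column dx_w hold the four dy values for this dx element.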
const Pack2X<T>* dy_pack_dptr = reinterpret_cast<const Pack2X<T>*>(dy_dptr);
const Pack2X<T> dy_pack_value1 =
dy_pack_dptr[nc_idx * dx_hw_size * 2 + dx_h * 2 * dx_width + dx_w];
const Pack2X<T> dy_pack_value2 =
dy_pack_dptr[nc_idx * dx_hw_size * 2 + (dx_h * 2 + 1) * dx_width + dx_w];
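    // Accumulate the four dy values covering this dx element.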
dx_value += dy_pack_value1.x;
dx_value += dy_pack_value1.y;
dx_value += dy_pack_value2.x;
dx_value += dy_pack_value2.y;
dx_dptr[index] = dx_value;
}
}

template<typename T>
__global__ void UpsampleNearest3DForward(const int64_t elem_cnt, const T* in_dptr,
NdIndexOffsetHelper<int64_t, 5> in_helper,
@@ -188,8 +212,6 @@ class UpsampleNearestGrad1DGPUKernel final : public user_op::OpKernel {
void Compute(user_op::KernelComputeContext* ctx) const override {
user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);

Memset<DeviceType::kCUDA>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,
dx_tensor->shape_view().elem_cnt() * sizeof(T));
const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>("output_size");
double height_scale = ctx->Attr<double>("scale_factor");
@@ -204,6 +226,8 @@ class UpsampleNearestGrad1DGPUKernel final : public user_op::OpKernel {
ctx->stream(), dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),
dy_tensor->shape_view().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type()));
} else {
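      // Generic path: dx must be zero-initialized here because the backward
      // kernel accumulates gradient contributions into it.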
Memset<DeviceType::kCUDA>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,
dx_tensor->shape_view().elem_cnt() * sizeof(T));
NdIndexOffsetHelper<int64_t, 3> dy_helper(dy_tensor->shape_view().At(0),
dy_tensor->shape_view().At(1),
dy_tensor->shape_view().At(2));
@@ -263,7 +287,7 @@ class UpsampleNearest2DGPUKernel final : public user_op::OpKernel {
} else {
const int64_t n = x_tensor->shape_view().At(0);
const int64_t c = x_tensor->shape_view().At(1);
if (out_height == 2 * in_height && out_width == 2 * in_width && in_elem_cnt <= 1 << 30) {
if (out_height == 2 * in_height && out_width == 2 * in_width && in_elem_cnt <= 1 << 29) {
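        // The 2X kernels use int32_t indexing and address dy offsets up to ~4x
        // in_elem_cnt, so the fast path is limited to in_elem_cnt <= 2^29 to stay
        // within the int32 range.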
RUN_CUDA_KERNEL(UpsampleNearest2D2XForward<T>, ctx->stream(), in_elem_cnt, in_elem_cnt,
x_tensor->dptr<T>(), in_height, in_width, y_tensor->mut_dptr<T>());
} else {
@@ -289,13 +313,12 @@ class UpsampleNearest2DGradGPUKernel final : public user_op::OpKernel {
void Compute(user_op::KernelComputeContext* ctx) const override {
user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);

Memset<DeviceType::kCUDA>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,
dx_tensor->shape_view().elem_cnt() * sizeof(T));
const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>("output_size");
double height_scale = ctx->Attr<double>("height_scale");
double width_scale = ctx->Attr<double>("width_scale");
const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt();
const int64_t in_elem_cnt = dx_tensor->shape_view().elem_cnt();
const int64_t in_height = dx_tensor->shape_view().At(2);
const int64_t in_width = dx_tensor->shape_view().At(3);
const int64_t out_height = dy_tensor->shape_view().At(2);
@@ -309,16 +332,24 @@ class UpsampleNearest2DGradGPUKernel final : public user_op::OpKernel {
ctx->stream(), dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),
dy_tensor->shape_view().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type()));
} else {
NdIndexOffsetHelper<int64_t, 4> dy_helper(
dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1),
dy_tensor->shape_view().At(2), dy_tensor->shape_view().At(3));
NdIndexOffsetHelper<int64_t, 4> dx_helper(
dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1),
dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3));
RUN_CUDA_KERNEL((UpsampleNearest2DBackward<T>), ctx->stream(), elem_cnt, elem_cnt,
dy_tensor->dptr<T>(), dy_helper, dx_helper, dx_tensor->shape_view().At(2),
dx_tensor->shape_view().At(3), 1.f / height_scale, 1.f / width_scale,
dx_tensor->mut_dptr<T>());
if (out_height == 2 * in_height && out_width == 2 * in_width && in_elem_cnt <= 1 << 29) {
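        // 2x fast path: every dx element is computed from its four dy values and
        // written exactly once, so no preliminary Memset of dx is needed.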
RUN_CUDA_KERNEL(UpsampleNearest2D2XBackward<T>, ctx->stream(), in_elem_cnt, in_elem_cnt,
dy_tensor->dptr<T>(), dx_tensor->shape_view().At(2),
dx_tensor->shape_view().At(3), dx_tensor->mut_dptr<T>());
} else {
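        // Generic path: zero dx first because UpsampleNearest2DBackward
        // accumulates gradients into it.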
Memset<DeviceType::kCUDA>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,
dx_tensor->shape_view().elem_cnt() * sizeof(T));
NdIndexOffsetHelper<int64_t, 4> dy_helper(
dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1),
dy_tensor->shape_view().At(2), dy_tensor->shape_view().At(3));
NdIndexOffsetHelper<int64_t, 4> dx_helper(
dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1),
dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3));
RUN_CUDA_KERNEL((UpsampleNearest2DBackward<T>), ctx->stream(), elem_cnt, elem_cnt,
dy_tensor->dptr<T>(), dy_helper, dx_helper, dx_tensor->shape_view().At(2),
dx_tensor->shape_view().At(3), 1.f / height_scale, 1.f / width_scale,
dx_tensor->mut_dptr<T>());
}
}
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }