Optim upsample backward #9424

Merged · 13 commits · Nov 15, 2022
61 changes: 46 additions & 15 deletions oneflow/user/kernels/upsample_nearest_kernel.cu
@@ -105,6 +105,30 @@ __global__ void UpsampleNearest2DBackward(const int64_t elem_cnt, const T* dy_dp
}
}

+template<typename T>
+__global__ void UpsampleNearest2D2XBackward(const int32_t in_elem_cnt, const T* dy_dptr,
+                                            const int32_t dx_height, const int32_t dx_width,
+                                            T* dx_dptr) {
+  const int32_t dx_hw_size = dx_height * dx_width;
+  CUDA_1D_KERNEL_LOOP(index, in_elem_cnt) {
+    T dx_value = 0.0;
+    const int32_t nc_idx = index / dx_hw_size;
+    const int32_t dx_hw_off = index - nc_idx * dx_hw_size;
+    const int32_t dx_h = dx_hw_off / dx_width;
+    const int32_t dx_w = dx_hw_off - dx_h * dx_width;
+    const Pack2X<T>* dy_pack_dptr = reinterpret_cast<const Pack2X<T>*>(dy_dptr);
+    const Pack2X<T> dy_pack_value1 =
+        dy_pack_dptr[nc_idx * dx_hw_size * 2 + dx_h * 2 * dx_width + dx_w];
+    const Pack2X<T> dy_pack_value2 =
+        dy_pack_dptr[nc_idx * dx_hw_size * 2 + (dx_h * 2 + 1) * dx_width + dx_w];
+    dx_value += dy_pack_value1.x;
+    dx_value += dy_pack_value1.y;
+    dx_value += dy_pack_value2.x;
+    dx_value += dy_pack_value2.y;
+    dx_dptr[index] = dx_value;
+  }
+}
+
template<typename T>
__global__ void UpsampleNearest3DForward(const int64_t elem_cnt, const T* in_dptr,
NdIndexOffsetHelper<int64_t, 5> in_helper,
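
Context for the new UpsampleNearest2D2XBackward kernel above (not part of the diff): in the exact-2x case each dx element is the sum of the 2x2 dy block it was broadcast to in the forward pass, and the kernel loads each of that block's two rows as a single Pack2X<T> value, halving the number of loads and writing dx exactly once. Pack2X<T> is defined elsewhere in this file; the sketch below assumes it is a plain two-lane struct, and the CPU reference function is hypothetical, written only to spell out the same indexing scalar-by-scalar.

#include <cstdint>

// Assumed stand-in for Pack2X<T>; the real definition lives elsewhere in
// upsample_nearest_kernel.cu.
template<typename T>
struct alignas(2 * sizeof(T)) Pack2X {
  T x;
  T y;
};

// Hypothetical CPU reference: dx[i, h, w] = sum of the 2x2 dy block at
// (2h, 2w), for i in [0, nc). dy is nc x (2*dx_h) x (2*dx_w), contiguous.
template<typename T>
void UpsampleNearest2D2XBackwardRef(const T* dy, int64_t nc, int64_t dx_h, int64_t dx_w,
                                    T* dx) {
  const int64_t dy_w = dx_w * 2;
  for (int64_t i = 0; i < nc; ++i) {
    const T* dy_nc = dy + i * (dx_h * 2) * dy_w;
    for (int64_t h = 0; h < dx_h; ++h) {
      for (int64_t w = 0; w < dx_w; ++w) {
        dx[(i * dx_h + h) * dx_w + w] = dy_nc[(2 * h) * dy_w + 2 * w]
                                        + dy_nc[(2 * h) * dy_w + 2 * w + 1]
                                        + dy_nc[(2 * h + 1) * dy_w + 2 * w]
                                        + dy_nc[(2 * h + 1) * dy_w + 2 * w + 1];
      }
    }
  }
}
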
@@ -188,8 +212,6 @@ class UpsampleNearestGrad1DGPUKernel final : public user_op::OpKernel {
void Compute(user_op::KernelComputeContext* ctx) const override {
user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);

-    Memset<DeviceType::kCUDA>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,
-                              dx_tensor->shape_view().elem_cnt() * sizeof(T));
const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>("output_size");
double height_scale = ctx->Attr<double>("scale_factor");
@@ -204,6 +226,8 @@ class UpsampleNearestGrad1DGPUKernel final : public user_op::OpKernel {
ctx->stream(), dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),
dy_tensor->shape_view().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type()));
} else {
+      Memset<DeviceType::kCUDA>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,
+                                dx_tensor->shape_view().elem_cnt() * sizeof(T));
NdIndexOffsetHelper<int64_t, 3> dy_helper(dy_tensor->shape_view().At(0),
dy_tensor->shape_view().At(1),
dy_tensor->shape_view().At(2));
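
The 1D grad change above is the pattern this PR applies throughout: when the output size matches the input size, the gradient is copied through with a single Memcpy, so the previously unconditional zero-fill of dx was wasted bandwidth; the Memset now runs only on the scatter path that actually accumulates into dx. A minimal sketch of the pattern, with hypothetical names and host-side std::memcpy/std::memset standing in for OneFlow's stream primitives:

#include <cstdint>
#include <cstring>

// Hypothetical sketch of the dispatch pattern used by the grad kernels here.
template<typename T>
void NearestGradDispatch(bool same_size, const T* dy, T* dx, int64_t dx_elem_cnt) {
  if (same_size) {
    // Identity scale: the gradient passes through untouched; no zero-fill.
    std::memcpy(dx, dy, dx_elem_cnt * sizeof(T));
  } else {
    // Scatter path: dx must start at zero before contributions accumulate.
    std::memset(dx, 0, dx_elem_cnt * sizeof(T));
    // ... launch the generic backward kernel ...
  }
}
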
@@ -263,7 +287,7 @@ class UpsampleNearest2DGPUKernel final : public user_op::OpKernel {
} else {
const int64_t n = x_tensor->shape_view().At(0);
const int64_t c = x_tensor->shape_view().At(1);
-      if (out_height == 2 * in_height && out_width == 2 * in_width && in_elem_cnt <= 1 << 30) {
+      if (out_height == 2 * in_height && out_width == 2 * in_width && in_elem_cnt <= 1 << 29) {
RUN_CUDA_KERNEL(UpsampleNearest2D2XForward<T>, ctx->stream(), in_elem_cnt, in_elem_cnt,
x_tensor->dptr<T>(), in_height, in_width, y_tensor->mut_dptr<T>());
} else {
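
The guard on the 2X fast path tightens from 1 << 30 to 1 << 29 elements. A plausible reading (the PR itself does not state the reason): the 2X kernels do their indexing in int32_t, and in Pack2X<T> units the 2x-upsampled side holds 2 * in_elem_cnt packs, so the old bound put the largest pack offset exactly at the int32_t limit, while the new bound leaves a 2x margin. A quick bound check under that assumption:

#include <cassert>
#include <cstdint>

int main() {
  const int64_t guard = int64_t{1} << 29;  // new limit in this diff
  // The 2x-upsampled tensor has 4 * guard elements, i.e. 2 * guard Pack2X
  // entries, so the largest pack offset a thread computes is 2 * guard - 1.
  assert(2 * guard - 1 < INT32_MAX);  // holds: 2^30 - 1 < 2^31 - 1
  // With the previous 1 << 30 guard the largest offset would be 2^31 - 1,
  // exactly INT32_MAX, leaving no headroom in the int32_t arithmetic.
  return 0;
}
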
@@ -289,13 +313,12 @@ class UpsampleNearest2DGradGPUKernel final : public user_op::OpKernel {
void Compute(user_op::KernelComputeContext* ctx) const override {
user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);

-    Memset<DeviceType::kCUDA>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,
-                              dx_tensor->shape_view().elem_cnt() * sizeof(T));
const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
const std::vector<int64_t> output_size = ctx->Attr<std::vector<int64_t>>("output_size");
double height_scale = ctx->Attr<double>("height_scale");
double width_scale = ctx->Attr<double>("width_scale");
const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt();
+    const int64_t in_elem_cnt = dx_tensor->shape_view().elem_cnt();
const int64_t in_height = dx_tensor->shape_view().At(2);
const int64_t in_width = dx_tensor->shape_view().At(3);
const int64_t out_height = dy_tensor->shape_view().At(2);
@@ -309,16 +332,24 @@
ctx->stream(), dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),
dy_tensor->shape_view().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type()));
} else {
-      NdIndexOffsetHelper<int64_t, 4> dy_helper(
-          dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1),
-          dy_tensor->shape_view().At(2), dy_tensor->shape_view().At(3));
-      NdIndexOffsetHelper<int64_t, 4> dx_helper(
-          dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1),
-          dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3));
-      RUN_CUDA_KERNEL((UpsampleNearest2DBackward<T>), ctx->stream(), elem_cnt, elem_cnt,
-                      dy_tensor->dptr<T>(), dy_helper, dx_helper, dx_tensor->shape_view().At(2),
-                      dx_tensor->shape_view().At(3), 1.f / height_scale, 1.f / width_scale,
-                      dx_tensor->mut_dptr<T>());
+      if (out_height == 2 * in_height && out_width == 2 * in_width && in_elem_cnt <= 1 << 29) {
+        RUN_CUDA_KERNEL(UpsampleNearest2D2XBackward<T>, ctx->stream(), in_elem_cnt, in_elem_cnt,
+                        dy_tensor->dptr<T>(), dx_tensor->shape_view().At(2),
+                        dx_tensor->shape_view().At(3), dx_tensor->mut_dptr<T>());
+      } else {
+        Memset<DeviceType::kCUDA>(ctx->stream(), dx_tensor->mut_dptr<T>(), 0,
+                                  dx_tensor->shape_view().elem_cnt() * sizeof(T));
+        NdIndexOffsetHelper<int64_t, 4> dy_helper(
+            dy_tensor->shape_view().At(0), dy_tensor->shape_view().At(1),
+            dy_tensor->shape_view().At(2), dy_tensor->shape_view().At(3));
+        NdIndexOffsetHelper<int64_t, 4> dx_helper(
+            dx_tensor->shape_view().At(0), dx_tensor->shape_view().At(1),
+            dx_tensor->shape_view().At(2), dx_tensor->shape_view().At(3));
+        RUN_CUDA_KERNEL((UpsampleNearest2DBackward<T>), ctx->stream(), elem_cnt, elem_cnt,
+                        dy_tensor->dptr<T>(), dy_helper, dx_helper, dx_tensor->shape_view().At(2),
+                        dx_tensor->shape_view().At(3), 1.f / height_scale, 1.f / width_scale,
+                        dx_tensor->mut_dptr<T>());
+      }
}
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
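
Net effect in the 2D grad kernel: for the common exact-2x case the backward pass is now a single gather kernel in which each thread reads four dy values and writes its dx element once, so neither the Memset nor the index-helper arithmetic of the generic path is needed; the generic scatter path, with its zero-fill, remains the fallback for every other scale and for tensors above the 1 << 29 guard.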