Feature/pad3d (apache#3742)
* added support for padding 3d images

* fixed typos + added Python test for 5d input.

* fixed CUDA bugs.

* fixed CUDA index bug
sbodenstein authored and piiswrong committed Nov 6, 2016
1 parent 5409dde commit 9107fc3
Showing 4 changed files with 486 additions and 75 deletions.
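For orientation before the diff, here is a minimal sketch of exercising the new rank-5 CPU path through the pad_image interface added in pad.cc below. It is illustration only, not part of the commit: pad3d_example is a hypothetical name, NewTensor/FreeSpace are real mshadow helpers, and the initializer-list TShape construction is assumed for brevity.

#include "mshadow/tensor.h"
using namespace mshadow;

void pad3d_example() {
  // Zero-pad a 1x1x4x5x6 NCDHW volume by 1 in depth, 2 in height, 1 in width;
  // pad_width keeps the batch/channel entries (indices 0-3) at zero.
  Tensor<cpu, 5, float> data = NewTensor<cpu, float>(Shape5(1, 1, 4, 5, 6), 1.0f);
  Tensor<cpu, 5, float> out = NewTensor<cpu, float>(Shape5(1, 1, 6, 9, 8), 0.0f);
  mxnet::TShape pad({0, 0, 0, 0, 1, 1, 2, 2, 1, 1});
  pad_image(out, data, pad, mxnet::op::pad_enum::kConstant, 0.0f);
  FreeSpace(&data); FreeSpace(&out);
}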
26 changes: 18 additions & 8 deletions src/operator/pad-inl.h
@@ -80,10 +80,16 @@ class PadOp : public Operator {
in_data[pad_enum::kData].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> out =
out_data[pad_enum::kOut].get<xpu, 4, DType>(s);
- pad_image_2d(out, data, param_.pad_width, param_.mode, constant_value);
+ pad_image(out, data, param_.pad_width, param_.mode, constant_value);
+ } else if ((rank == 5) && !pad[0] && !pad[1] && !pad[2] && !pad[3]) {
+ Tensor<xpu, 5, DType> data =
+ in_data[pad_enum::kData].get<xpu, 5, DType>(s);
+ Tensor<xpu, 5, DType> out =
+ out_data[pad_enum::kOut].get<xpu, 5, DType>(s);
+ pad_image(out, data, param_.pad_width, param_.mode, constant_value);
} else {
LOG(FATAL) << "Only 4d input tensors and padding applied to the last "
"two dimensions is currently implemented. ";
LOG(FATAL) << "Only 4d or 5d input tensors with padding applied to "
"dimensions > 1 is currently implemented.";
}

// Assign(out, req[pad_enum::kOut], F<mshadow_op::identity>(data));
@@ -104,17 +110,21 @@ class PadOp : public Operator {
// Get any size input + output into required form
auto pad = param_.pad_width;
int rank = in_grad[pad_enum::kData].ndim();
- // Currently only support rank 4
if ((rank == 4) && !pad[0] && !pad[1] && !pad[2] && !pad[3]) {
Tensor<xpu, 4, DType> in = in_grad[pad_enum::kData].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> out =
out_grad[pad_enum::kOut].get<xpu, 4, DType>(s);
if (req[pad_enum::kData] == kWriteTo) in = 0.0f;

- pad_image_2d_grad(in, out, param_.pad_width, param_.mode);
+ pad_image_grad(in, out, param_.pad_width, param_.mode);
+ } else if ((rank == 5) && !pad[0] && !pad[1] && !pad[2] && !pad[3]) {
+ Tensor<xpu, 5, DType> in = in_grad[pad_enum::kData].get<xpu, 5, DType>(s);
+ Tensor<xpu, 5, DType> out =
+ out_grad[pad_enum::kOut].get<xpu, 5, DType>(s);
+ if (req[pad_enum::kData] == kWriteTo) in = 0.0f;
+ pad_image_grad(in, out, param_.pad_width, param_.mode);
} else {
- LOG(FATAL) << "Only 4d input tensors and padding applied to the last "
- "two dimensions is currently implemented. ";
+ LOG(FATAL) << "Only 4d and 5d input tensors with padding applied to "
+ "dimensions > 1 are currently supported.";
}
}
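Both rank branches above rely on mxnet's pad_width layout: one (before, after) pair per axis, in axis order, so the !pad[0] ... !pad[3] guards reject any padding of the batch and channel axes of NCHW/NCDHW input. The following helper is hypothetical, for illustration only, to make the indexing concrete:

// Output extent of axis `ax`: entries 2*ax and 2*ax + 1 of pad_width are the
// before/after padding of that axis, so pad[4]/pad[6]/pad[8] are the
// pad_f/pad_t/pad_l values read by the kernels in pad.cc below.
inline int padded_extent(int in_extent, const mxnet::TShape &pad, int ax) {
  return pad[2 * ax] + in_extent + pad[2 * ax + 1];
}
// E.g. pad_width = {0,0, 0,0, 1,1, 2,2, 3,3} maps a (2,3,4,5,6) input
// to (2,3,6,9,12).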

237 changes: 214 additions & 23 deletions src/operator/pad.cc
@@ -15,8 +15,8 @@ namespace mshadow {
// single_image_2d_edge adapted from Torch
// https://github.com/torch/nn/blob/master/lib/THNN/generic/SpatialReplicationPadding.c
template <typename DType>
- void single_image_2d_edge(const Tensor<cpu, 3, DType> dst,
- const Tensor<cpu, 3, DType> src, mxnet::TShape pad) {
+ void single_image_edge(const Tensor<cpu, 3, DType> dst,
+ const Tensor<cpu, 3, DType> src, mxnet::TShape pad) {
const int nslices = src.size(0);
const int iheight = src.size(1);
const int iwidth = src.size(2);
@@ -63,9 +63,9 @@ void single_image_2d_edge(const Tensor<cpu, 3, DType> dst,
}

template <typename DType>
- void single_image_2d_edge_grad(const Tensor<cpu, 3, DType> &grad_in,
- const Tensor<cpu, 3, DType> grad_out,
- mxnet::TShape pad) {
+ void single_image_edge_grad(const Tensor<cpu, 3, DType> &grad_in,
+ const Tensor<cpu, 3, DType> grad_out,
+ mxnet::TShape pad) {
const int nslices = grad_in.size(0);
const int iheight = grad_in.size(1);
const int iwidth = grad_in.size(2);
@@ -115,9 +115,9 @@ void single_image_2d_edge_grad(const Tensor<cpu, 3, DType> &grad_in,

// Case 2: Zero Padding
template <typename DType>
- void single_image_2d_constant(const Tensor<cpu, 3, DType> &dst,
- const Tensor<cpu, 3, DType> src,
- mxnet::TShape pad, DType constant_value) {
+ void single_image_constant(const Tensor<cpu, 3, DType> &dst,
+ const Tensor<cpu, 3, DType> src, mxnet::TShape pad,
+ DType constant_value) {
const int pad_t = pad[4];
const int pad_l = pad[6];
int c, w, h;
@@ -137,9 +137,9 @@ void single_image_2d_constant(const Tensor<cpu, 3, DType> &dst,
}

template <typename DType>
- void single_image_2d_constant_grad(const Tensor<cpu, 3, DType> &in_grad,
- const Tensor<cpu, 3, DType> out_grad,
- mxnet::TShape pad) {
+ void single_image_constant_grad(const Tensor<cpu, 3, DType> &in_grad,
+ const Tensor<cpu, 3, DType> out_grad,
+ mxnet::TShape pad) {
const int pad_t = pad[4];
const int pad_l = pad[6];
int c, h, w;
@@ -153,38 +153,229 @@ void single_image_2d_constant_grad(const Tensor<cpu, 3, DType> &in_grad,
}
}

- // General 2d image case
+ ////////////////////////////////////////////////////////////////////////////////
+ // Special Case: 3d image (so only pad width + height + depth)

// Case 1: Edge Padding (or Replication Padding)
// single_image_3d_edge adapted from Torch
// https://github.com/torch/nn/blob/master/lib/THNN/generic/VolumetricReplicationPadding.c
template <typename DType>
void single_image_edge(const Tensor<cpu, 4, DType> dst,
const Tensor<cpu, 4, DType> src, mxnet::TShape pad) {
const int nslices = src.size(0);
const int idepth = src.size(1);
const int iheight = src.size(2);
const int iwidth = src.size(3);

const int odepth = dst.size(1);
const int oheight = dst.size(2);
const int owidth = dst.size(3);

const int pad_f = pad[4];
const int pad_t = pad[6];
const int pad_l = pad[8];
int iStartX = std::max(0, -pad_l);
int iStartY = std::max(0, -pad_t);
int iStartZ = std::max(0, -pad_f);
int oStartX = std::max(0, pad_l);
int oStartY = std::max(0, pad_t);
int oStartZ = std::max(0, pad_f);

int k, ip_x, ip_y, ip_z;
#pragma omp parallel for private(k, ip_x, ip_y, ip_z)
for (k = 0; k < nslices; k++) {
int i, j, z;
for (z = 0; z < odepth; z++) {
for (i = 0; i < oheight; i++) {
for (j = 0; j < owidth; j++) {
if (j < pad_l) {
ip_x = pad_l;
} else if (j >= pad_l && j < iwidth + pad_l) {
ip_x = j;
} else {
ip_x = iwidth + pad_l - 1;
}
ip_x = ip_x - oStartX + iStartX;

if (i < pad_t) {
ip_y = pad_t;
} else if (i >= pad_t && i < iheight + pad_t) {
ip_y = i;
} else {
ip_y = iheight + pad_t - 1;
}
ip_y = ip_y - oStartY + iStartY;

if (z < pad_f) {
ip_z = pad_f;
} else if (z >= pad_f && z < idepth + pad_f) {
ip_z = z;
} else {
ip_z = idepth + pad_f - 1;
}
ip_z = ip_z - oStartZ + iStartZ;

DType *dest_p = dst.dptr_ + k * owidth * oheight * odepth +
z * owidth * oheight + i * owidth + j;
DType *src_p = src.dptr_ + k * iwidth * iheight * idepth +
ip_z * iwidth * iheight + ip_y * iwidth + ip_x;
*dest_p = *src_p;
}
}
}
}
}
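// Illustration only (not part of the commit; edge_src_index is a hypothetical
// helper): for the non-negative pad values this operator accepts, iStart* is 0
// and oStart* equals the before-padding, so each three-way branch above
// collapses to a clamp of the shifted output coordinate, i.e.
// ip_x == edge_src_index(j, pad_l, iwidth), and likewise for ip_y and ip_z.
// Requires <algorithm> for std::min/std::max.
inline int edge_src_index(int out_idx, int pad_before, int in_extent) {
  return std::min(std::max(out_idx - pad_before, 0), in_extent - 1);
}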

template <typename DType>
void single_image_edge_grad(const Tensor<cpu, 4, DType> &grad_in,
const Tensor<cpu, 4, DType> grad_out,
mxnet::TShape pad) {
const int nslices = grad_in.size(0);
const int idepth = grad_in.size(1);
const int iheight = grad_in.size(2);
const int iwidth = grad_in.size(3);

const int odepth = grad_out.size(1);
const int oheight = grad_out.size(2);
const int owidth = grad_out.size(3);

const int pad_f = pad[4];
const int pad_t = pad[6];
const int pad_l = pad[8];
int iStartX = std::max(0, -pad_l);
int iStartY = std::max(0, -pad_t);
int iStartZ = std::max(0, -pad_f);
int oStartX = std::max(0, pad_l);
int oStartY = std::max(0, pad_t);
int oStartZ = std::max(0, pad_f);

int k, ip_x, ip_y, ip_z;
#pragma omp parallel for private(k, ip_x, ip_y, ip_z)
for (k = 0; k < nslices; k++) {
int i, j, z;
for (z = 0; z < odepth; z++) {
for (i = 0; i < oheight; i++) {
for (j = 0; j < owidth; j++) {
if (j < pad_l) {
ip_x = pad_l;
} else if (j >= pad_l && j < iwidth + pad_l) {
ip_x = j;
} else {
ip_x = iwidth + pad_l - 1;
}
ip_x = ip_x - oStartX + iStartX;

if (i < pad_t) {
ip_y = pad_t;
} else if (i >= pad_t && i < iheight + pad_t) {
ip_y = i;
} else {
ip_y = iheight + pad_t - 1;
}
ip_y = ip_y - oStartY + iStartY;

if (z < pad_f) {
ip_z = pad_f;
} else if (z >= pad_f && z < idepth + pad_f) {
ip_z = z;
} else {
ip_z = idepth + pad_f - 1;
}
ip_z = ip_z - oStartZ + iStartZ;

DType *src_p = grad_out.dptr_ + k * owidth * oheight * odepth +
z * owidth * oheight + i * owidth + j;
DType *dest_p = grad_in.dptr_ + k * iwidth * iheight * idepth +
ip_z * iwidth * iheight + ip_y * iwidth + ip_x;
*dest_p += *src_p;
}
}
}
}
}

// Case 2: Zero Padding
template <typename DType>
- void pad_image_2d(const Tensor<cpu, 4, DType> &dst,
- const Tensor<cpu, 4, DType> src, mxnet::TShape pad, int mode,
- DType constant_value) {
+ void single_image_constant(const Tensor<cpu, 4, DType> &dst,
+ const Tensor<cpu, 4, DType> src, mxnet::TShape pad,
+ DType constant_value) {
const int pad_f = pad[4];
const int pad_t = pad[6];
const int pad_l = pad[8];
int c, d, w, h;
#pragma omp parallel for private(c, d, w, h)
for (c = 0; c < dst.size(0); ++c) {
for (d = 0; d < dst.size(1); ++d) {
for (h = 0; h < dst.size(2); ++h) {
for (w = 0; w < dst.size(3); ++w) {
if ((w < pad_l) || (h < pad_t) || (d < pad_f) ||
(d >= (src.size(1) + pad_f)) || (h >= (src.size(2) + pad_t)) ||
(w >= (src.size(3) + pad_l))) {
dst[c][d][h][w] = constant_value;
} else {
dst[c][d][h][w] = src[c][d - pad_f][h - pad_t][w - pad_l];
}
}
}
}
}
}

template <typename DType>
void single_image_constant_grad(const Tensor<cpu, 4, DType> &in_grad,
const Tensor<cpu, 4, DType> out_grad,
mxnet::TShape pad) {
const int pad_f = pad[4];
const int pad_t = pad[6];
const int pad_l = pad[8];
int c, d, w, h;
#pragma omp parallel for private(c, d, w, h)
for (c = 0; c < in_grad.size(0); ++c) {
for (d = 0; d < in_grad.size(1); ++d) {
for (h = 0; h < in_grad.size(2); ++h) {
for (w = 0; w < in_grad.size(3); ++w) {
in_grad[c][d][h][w] += out_grad[c][d + pad_f][h + pad_t][w + pad_l];
}
}
}
}
}
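// Illustration only (not part of the commit; example_constant_pad3d is a
// hypothetical name): a minimal check of the 3d constant path. NewTensor and
// FreeSpace are mshadow helpers; the initializer-list TShape construction is
// assumed for brevity.
inline void example_constant_pad3d() {
  // One 2x2x2 volume with one voxel of depth padding on each side:
  // pad[4] = pad_f = 1, while the height/width entries stay zero.
  Tensor<cpu, 4, float> src = NewTensor<cpu, float>(Shape4(1, 2, 2, 2), 1.0f);
  Tensor<cpu, 4, float> dst = NewTensor<cpu, float>(Shape4(1, 4, 2, 2), -1.0f);
  mxnet::TShape pad({0, 0, 0, 0, 1, 1, 0, 0, 0, 0});
  single_image_constant(dst, src, pad, 0.0f);
  // dst[0][0] and dst[0][3] are now all zeros, dst[0][1]/dst[0][2] copy src,
  // and single_image_constant_grad reads the same interior window back.
  FreeSpace(&src); FreeSpace(&dst);
}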

////////////////////////////////////////////////////////////////////////////////
// Interface to 2d and 3d image pad methods

+ template <int dim, typename DType>
+ void pad_image(const Tensor<cpu, dim, DType> &dst,
+ const Tensor<cpu, dim, DType> src, mxnet::TShape pad, int mode,
+ DType constant_value) {
for (index_t n = 0; n < dst.size(0); ++n) {
switch (mode) {
case mxnet::op::pad_enum::kEdge:
- single_image_2d_edge(dst[n], src[n], pad);
+ single_image_edge(dst[n], src[n], pad);
break;
case mxnet::op::pad_enum::kConstant:
- single_image_2d_constant(dst[n], src[n], pad, constant_value);
+ single_image_constant(dst[n], src[n], pad, constant_value);
break;
}
}
}

- template <typename DType>
- void pad_image_2d_grad(const Tensor<cpu, 4, DType> &in_grad,
- const Tensor<cpu, 4, DType> out_grad, mxnet::TShape pad,
- int mode) {
+ template <int dim, typename DType>
+ void pad_image_grad(const Tensor<cpu, dim, DType> &in_grad,
+ const Tensor<cpu, dim, DType> out_grad, mxnet::TShape pad,
+ int mode) {
for (index_t n = 0; n < in_grad.size(0); ++n) {
switch (mode) {
case mxnet::op::pad_enum::kEdge:
- single_image_2d_edge_grad(in_grad[n], out_grad[n], pad);
+ single_image_edge_grad(in_grad[n], out_grad[n], pad);
break;
case mxnet::op::pad_enum::kConstant:
- single_image_2d_constant_grad(in_grad[n], out_grad[n], pad);
+ single_image_constant_grad(in_grad[n], out_grad[n], pad);
break;
}
}
}
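// Illustration only (not part of the commit): these dim-templated wrappers
// contain no runtime rank switch. dst[n] strips the batch axis, so overload
// resolution on the slice rank picks the 2d (rank-3 slice) or 3d (rank-4
// slice) single_image_* kernels. A compile-time sketch of that fact, assuming
// <type_traits> and <utility> are available:
static_assert(
    std::is_same<decltype(std::declval<Tensor<cpu, 5, float> >()[0]),
                 Tensor<cpu, 4, float> >::value,
    "a 5d batch element is a 4d volume");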

} // namespace mshadow

namespace mxnet {