Skip to content

Commit

Permalink
make slice operator 20x faster on GPU (apache#11124)
Browse files Browse the repository at this point in the history
* gpu slice kernel

* remove unused line
  • Loading branch information
eric-haibin-lin authored and piiswrong committed Jun 2, 2018
1 parent 7eb78d8 commit a068fae
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 21 deletions.
1 change: 0 additions & 1 deletion src/operator/tensor/indexing_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
#include "../mxnet_op.h"
#include "./sort_op.h"
#include "./init_op.h"
#include "./matrix_op-inl.h"
#include "../../engine/openmp.h"

namespace mxnet {
Expand Down
130 changes: 110 additions & 20 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -706,12 +706,42 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
return oshape.ndim() != 0 && oshape.Size() != 0;
}

template<int ndim>
struct slice_forward {
template<int ndim, int req, typename xpu>
struct slice_forward;

template<int ndim, int req>
struct slice_forward<ndim, req, gpu> {
  // GPU variant: one thread per OUTPUT ELEMENT. `i` indexes the flattened
  // output tensor (row-major), i.e. i = flat_row * oshape[ndim-1] + column.
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data,
                                  const mshadow::Shape<ndim> dshape,
                                  const mshadow::Shape<ndim> oshape,
                                  const common::StaticArray<int, ndim> begin,
                                  const common::StaticArray<int, ndim> step) {
    const int out_cols = oshape[ndim-1];   // output last-dim size
    const int data_cols = dshape[ndim-1];  // source last-dim size
    // Split the flat element index into (flattened row, column within row).
    const int col = i % out_cols;
    int row = i / out_cols;
    // Walk the leading dimensions innermost-first to recover the flattened
    // row offset of the matching source row in `data`.
    int data_row = 0;
    int dim_stride = 1;
#pragma unroll
    for (int k = ndim - 2; k >= 0; --k) {
      const int out_coord = row % oshape[k];
      data_row += dim_stride * (out_coord * step[k] + begin[k]);
      row /= oshape[k];
      dim_stride *= dshape[k];
    }
    // Apply the last-dimension slice (begin/step) to locate the source column.
    const int data_col = begin[ndim-1] + col * step[ndim-1];
    KERNEL_ASSIGN(out[i], req, data[data_row * data_cols + data_col]);
  }
};

template<int ndim, int req>
struct slice_forward<ndim, req, cpu> {
// i is the i-th row after flattening out into 2D tensor
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data,
const OpReqType req,
const mshadow::Shape<ndim> dshape,
const mshadow::Shape<ndim> oshape,
const common::StaticArray<int, ndim> begin,
Expand Down Expand Up @@ -756,19 +786,27 @@ void SliceOpForward(const nnvm::NodeAttrs& attrs,
common::StaticArray<int, ndim> begin, end, step;
GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step);
MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
mxnet_op::Kernel<slice_forward<ndim>, xpu>::Launch(s, out.shape_.FlatTo2D()[0],
out.dptr<DType>(), data.dptr<DType>(), req[0],
data.shape_.get<ndim>(), out.shape_.get<ndim>(), begin, step);
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
int num_threads = out.shape_.FlatTo2D()[0];
if (std::is_same<xpu, gpu>::value) {
num_threads *= out.shape_.get<ndim>()[ndim - 1];
}
mxnet_op::Kernel<slice_forward<ndim, Req, xpu>, xpu>::Launch(s, num_threads,
out.dptr<DType>(), data.dptr<DType>(),
data.shape_.get<ndim>(), out.shape_.get<ndim>(), begin, step);
})
})
})
}

template<int ndim>
struct slice_assign {
template<int ndim, int req, typename xpu>
struct slice_assign;

template<int ndim, int req>
struct slice_assign<ndim, req, cpu> {
// i is the i-th row after flattening out into 2D tensor
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out, const DType* val,
const OpReqType req,
const mshadow::Shape<ndim> oshape,
const mshadow::Shape<ndim> vshape,
const common::StaticArray<int, ndim> begin,
Expand All @@ -794,6 +832,34 @@ struct slice_assign {
}
};

template<int ndim, int req>
struct slice_assign<ndim, req, gpu> {
  // GPU variant: one thread per element of `val`. `i` indexes the flattened
  // value tensor (row-major); the element is scattered into the sliced
  // region of `out` described by begin/step.
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* val,
                                  const mshadow::Shape<ndim> oshape,
                                  const mshadow::Shape<ndim> vshape,
                                  const common::StaticArray<int, ndim> begin,
                                  const common::StaticArray<int, ndim> step) {
    const int val_cols = vshape[ndim-1];  // value last-dim size
    const int out_cols = oshape[ndim-1];  // destination last-dim size
    // Split the flat element index into (flattened row, column within row).
    const int col = i % val_cols;
    int row = i / val_cols;
    // Walk the leading dimensions innermost-first to recover the flattened
    // row offset of the destination row in `out`.
    int out_row = 0;
    int dim_stride = 1;
#pragma unroll
    for (int k = ndim - 2; k >= 0; --k) {
      const int val_coord = row % vshape[k];
      out_row += dim_stride * (val_coord * step[k] + begin[k]);
      row /= vshape[k];
      dim_stride *= oshape[k];
    }
    // Apply the last-dimension slice (begin/step) to locate the target column.
    const int out_col = begin[ndim-1] + col * step[ndim-1];
    KERNEL_ASSIGN(out[out_row * out_cols + out_col], req, val[i]);
  }
};

template<typename xpu>
void SliceOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand All @@ -818,9 +884,15 @@ void SliceOpBackward(const nnvm::NodeAttrs& attrs,
common::StaticArray<int, ndim> begin, end, step;
GetIndexRange(igrad.shape_, param.begin, param.end, param.step, &begin, &end, &step);
MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
mxnet_op::Kernel<slice_assign<ndim>, xpu>::Launch(s, ograd.shape_.FlatTo2D()[0],
igrad.dptr<DType>(), ograd.dptr<DType>(), req[0],
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
int num_threads = ograd.shape_.FlatTo2D()[0];
if (std::is_same<xpu, gpu>::value) {
num_threads *= ograd.shape_.get<ndim>()[ndim - 1];
}
mxnet_op::Kernel<slice_assign<ndim, Req, xpu>, xpu>::Launch(s, num_threads,
igrad.dptr<DType>(), ograd.dptr<DType>(),
igrad.shape_.get<ndim>(), ograd.shape_.get<ndim>(), begin, step);
})
})
})
}
Expand Down Expand Up @@ -876,9 +948,15 @@ void SliceAssignOpForward(const nnvm::NodeAttrs& attrs,
common::StaticArray<int, ndim> begin, end, step;
GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step);
MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
mxnet_op::Kernel<slice_assign<ndim>, xpu>::Launch(s, val.shape_.FlatTo2D()[0],
out.dptr<DType>(), val.dptr<DType>(), req[0],
out.shape_.get<ndim>(), val.shape_.get<ndim>(), begin, step);
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
int num_threads = val.shape_.FlatTo2D()[0];
if (std::is_same<xpu, gpu>::value) {
num_threads *= val.shape_.get<ndim>()[ndim - 1];
}
mxnet_op::Kernel<slice_assign<ndim, Req, xpu>, xpu>::Launch(s, num_threads,
out.dptr<DType>(), val.dptr<DType>(),
out.shape_.get<ndim>(), val.shape_.get<ndim>(), begin, step);
})
})
})
}
Expand Down Expand Up @@ -1242,9 +1320,15 @@ void SliceLikeForward(const nnvm::NodeAttrs& attrs,
common::StaticArray<int, ndim> begin, end, step;
GetIndexRange(data.shape_, param_begin, param_end, param_step, &begin, &end, &step);
MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
mxnet_op::Kernel<slice_forward<ndim>, xpu>::Launch(s, out.shape_.FlatTo2D()[0],
out.dptr<DType>(), data.dptr<DType>(), req[0],
data.shape_.get<ndim>(), out.shape_.get<ndim>(), begin, step);
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
int num_threads = out.shape_.FlatTo2D()[0];
if (std::is_same<xpu, gpu>::value) {
num_threads *= out.shape_.get<ndim>()[ndim - 1];
}
mxnet_op::Kernel<slice_forward<ndim, Req, xpu>, xpu>::Launch(s,
num_threads, out.dptr<DType>(), data.dptr<DType>(),
data.shape_.get<ndim>(), out.shape_.get<ndim>(), begin, step);
})
})
})
}
Expand Down Expand Up @@ -1282,9 +1366,15 @@ void SliceLikeBackward(const nnvm::NodeAttrs& attrs,
common::StaticArray<int, ndim> begin, end, step;
GetIndexRange(ograd.shape_, param_begin, param_end, param_step, &begin, &end, &step);
MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
mxnet_op::Kernel<slice_assign<ndim>, xpu>::Launch(s, ograd.shape_.FlatTo2D()[0],
igrad.dptr<DType>(), ograd.dptr<DType>(), req[0],
igrad.shape_.get<ndim>(), ograd.shape_.get<ndim>(), begin, step);
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
int num_threads = ograd.shape_.FlatTo2D()[0];
if (std::is_same<xpu, gpu>::value) {
num_threads *= ograd.shape_.get<ndim>()[ndim - 1];
}
mxnet_op::Kernel<slice_assign<ndim, Req, xpu>, xpu>::Launch(s, num_threads,
igrad.dptr<DType>(), ograd.dptr<DType>(),
igrad.shape_.get<ndim>(), ograd.shape_.get<ndim>(), begin, step);
})
})
})
}
Expand Down

0 comments on commit a068fae

Please sign in to comment.