optimize broadcast (#8566)
* optimize broadcast

* Update elemwise_binary_broadcast_op.h
piiswrong committed Nov 11, 2017
1 parent 9572340 commit 23a9294
Showing 2 changed files with 94 additions and 3 deletions.
65 changes: 65 additions & 0 deletions src/operator/mxnet_op.h
@@ -222,6 +222,37 @@ MSHADOW_XINLINE Shape<ndim> calc_stride(const Shape<ndim>& shape) {
  return stride;
}

/* Increment coordinates and modify index */
template<int ndim>
MSHADOW_XINLINE void inc(Shape<ndim>* coord, const Shape<ndim>& shape,
                         index_t* idx, const Shape<ndim>& stride) {
  ++(*coord)[ndim-1];
  *idx += stride[ndim-1];
  #pragma unroll
  for (int i = ndim - 1; i > 0 && (*coord)[i] >= shape[i]; --i) {
    (*coord)[i] -= shape[i];
    ++(*coord)[i-1];
    *idx = *idx + stride[i-1] - shape[i] * stride[i];
  }
}

/* Increment coordinates and modify index */
template<int ndim>
MSHADOW_XINLINE void inc(Shape<ndim>* coord, const Shape<ndim>& shape,
                         index_t* idx1, const Shape<ndim>& stride1,
                         index_t* idx2, const Shape<ndim>& stride2) {
  ++(*coord)[ndim-1];
  *idx1 += stride1[ndim-1];
  *idx2 += stride2[ndim-1];
  #pragma unroll
  for (int i = ndim - 1; i > 0 && (*coord)[i] >= shape[i]; --i) {
    (*coord)[i] -= shape[i];
    ++(*coord)[i-1];
    *idx1 = *idx1 + stride1[i-1] - shape[i] * stride1[i];
    *idx2 = *idx2 + stride2[i-1] - shape[i] * stride2[i];
  }
}
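
Note: the two inc() overloads above advance a running output coordinate by one element and keep the matching flat operand index (or indices) up to date, so per-element broadcast indexing costs a few additions instead of a full unravel plus dot product. A minimal standalone sketch of the same idea (plain C++, not MXNet code; the 2x3 output shape, fixed 2-D rank, and names are made-up illustration values):

#include <array>
#include <cstdio>

constexpr int kNDim = 2;

// Advance `coord` by one row-major element of `shape` and keep `idx`,
// the flat index into an operand whose broadcast dimensions have stride 0.
void inc2(std::array<int, kNDim>* coord, const std::array<int, kNDim>& shape,
          int* idx, const std::array<int, kNDim>& stride) {
  ++(*coord)[kNDim - 1];
  *idx += stride[kNDim - 1];
  for (int i = kNDim - 1; i > 0 && (*coord)[i] >= shape[i]; --i) {
    (*coord)[i] -= shape[i];
    ++(*coord)[i - 1];
    *idx = *idx + stride[i - 1] - shape[i] * stride[i];
  }
}

int main() {
  std::array<int, kNDim> oshape = {2, 3};    // output is 2x3
  std::array<int, kNDim> lstride = {1, 0};   // lhs is 2x1, broadcast along axis 1
  std::array<int, kNDim> coord = {0, 0};
  int lidx = 0;                              // out[0] reads lhs[0]
  for (int i = 1; i < oshape[0] * oshape[1]; ++i) {
    inc2(&coord, oshape, &lidx, lstride);
    std::printf("out[%d] reads lhs[%d]\n", i, lidx);  // out[1..2] -> lhs[0], out[3..5] -> lhs[1]
  }
  return 0;
}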

/*!
 * \brief Simple copy data from one blob to another
 * \param to Destination blob
@@ -355,6 +386,24 @@ struct Kernel<OP, cpu> {
    for (int i = 0; i < N; ++i) {
      OP::Map(i, args...);
    }
#endif
  }

  template<typename ...Args>
  inline static void LaunchEx(mshadow::Stream<cpu> *s, const int N, Args... args) {
#ifdef _OPENMP
    const int omp_cores = Engine::Get()->num_omp_threads_per_worker();
    if (omp_cores <= 1) {
      OP::Map(0, N, args...);
    } else {
      int length = (N + omp_cores - 1) / omp_cores;
#pragma omp parallel for num_threads(omp_cores)
      for (int i = 0; i < N; i += length) {
        OP::Map(i, i + length > N ? N - i : length, args...);
      }
    }
#else
    OP::Map(0, N, args...);
#endif
  }
};
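
Note: on CPU, LaunchEx splits [0, N) into ceil(N / omp_cores) sized blocks and hands each block to OP::Map(start, length, ...), so any per-block setup inside Map is amortized over many elements. A small standalone sketch of that partitioning (plain C++; N = 10 and omp_cores = 4 are example values):

#include <algorithm>
#include <cstdio>

int main() {
  const int N = 10, omp_cores = 4;
  const int length = (N + omp_cores - 1) / omp_cores;    // 3 elements per block
  for (int i = 0; i < N; i += length) {
    const int len = std::min(length, N - i);             // last block may be short
    std::printf("Map(base=%d, length=%d)\n", i, len);    // (0,3) (3,3) (6,3) (9,1)
  }
  return 0;
}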
@@ -368,6 +417,13 @@ __global__ void mxnet_generic_kernel(int N, Args... args) {
  }
}

template<typename OP, typename ...Args>
__global__ void mxnet_generic_kernel_ex(int N, Args... args) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
    OP::Map(i, 1, args...);
  }
}
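
Note: the _ex GPU kernel calls OP::Map(i, 1, args...), i.e. one element per thread, so the same Map(base, length, ...) signature serves both the chunked CPU launch and the per-element GPU launch.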

template<typename OP>
struct Kernel<OP, gpu> {
  template<typename ...Args>
@@ -378,6 +434,15 @@ struct Kernel<OP, gpu> {
      <<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
        N, args...);
  }

  template<typename ...Args>
  inline static void LaunchEx(mshadow::Stream<gpu> *s, int N, Args... args) {
    using namespace mshadow::cuda;
    int ngrid = std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum);
    mxnet_generic_kernel_ex<OP, Args...>
      <<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
        N, args...);
  }
};
#endif // __CUDACC__

32 changes: 29 additions & 3 deletions src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -133,13 +133,34 @@ inline int BinaryBroadcastShapeCompact(const TShape& lshape, const TShape& rshap
  return j;
}

namespace mxnet_op {
template<int ndim, typename DType, typename OP>
struct binary_broadcast_kernel {
  MSHADOW_XINLINE static void Map(int base, int length, OpReqType req,
                                  const Shape<ndim>& lstride, const Shape<ndim>& rstride,
                                  const Shape<ndim>& oshape, DType* lhs, DType* rhs,
                                  DType* out, int lsize, int rsize) {
    Shape<ndim> coord = unravel(base, oshape);
    index_t lidx = dot(coord, lstride);
    index_t ridx = dot(coord, rstride);
    KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx]));
    // starts from 1 to avoid extra inc at end of loop
    for (int i = 1; i < length; ++i) {
      inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
      KERNEL_ASSIGN(out[base+i], req, OP::Map(lhs[lidx], rhs[ridx]));
    }
  }
};

} // namespace mxnet_op
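
Note: binary_broadcast_kernel::Map pays for unravel() and the two dot() products only once per block, at the starting index base; the remaining length - 1 elements are reached with inc(), which is what makes the chunked LaunchEx path cheaper than unraveling every output index.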

template<typename xpu, typename OP>
void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
                            const OpContext& ctx,
                            const std::vector<TBlob>& inputs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
  using namespace broadcast;
  using namespace mxnet_op;
  TShape new_lshape, new_rshape, new_oshape;
  int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_,
                                         &new_lshape, &new_rshape, &new_oshape);
@@ -149,8 +170,13 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    BROADCAST_NDIM_SWITCH(ndim, NDim, {
      BinaryBroadcastComputeImpl<NDim, DType, OP>(s, req[0], inputs[0].reshape(new_lshape),
                                                  inputs[1].reshape(new_rshape), outputs[0].reshape(new_oshape));
      Shape<NDim> oshape = new_oshape.get<NDim>();
      Shape<NDim> lstride = calc_stride(new_lshape.get<NDim>());
      Shape<NDim> rstride = calc_stride(new_rshape.get<NDim>());
      Kernel<binary_broadcast_kernel<NDim, DType, OP>, xpu>::LaunchEx(
          s, new_oshape.Size(), req[0], lstride, rstride, oshape,
          inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), outputs[0].dptr<DType>(),
          inputs[0].Size(), inputs[1].Size());
    });
  });
}
