diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 5b8e109d..564ad814 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -222,6 +222,37 @@ MSHADOW_XINLINE Shape<ndim> calc_stride(const Shape<ndim>& shape) {
   return stride;
 }
 
+/* Increment coordinates and modify index */
+template<int ndim>
+MSHADOW_XINLINE void inc(Shape<ndim>* coord, const Shape<ndim>& shape,
+                         index_t* idx, const Shape<ndim>& stride) {
+  ++(*coord)[ndim-1];
+  *idx += stride[ndim-1];
+  #pragma unroll
+  for (int i = ndim - 1; i > 0 && (*coord)[i] >= shape[i]; --i) {
+    (*coord)[i] -= shape[i];
+    ++(*coord)[i-1];
+    *idx = *idx + stride[i-1] - shape[i] * stride[i];
+  }
+}
+
+/* Increment coordinates and modify index */
+template<int ndim>
+MSHADOW_XINLINE void inc(Shape<ndim>* coord, const Shape<ndim>& shape,
+                         index_t* idx1, const Shape<ndim>& stride1,
+                         index_t* idx2, const Shape<ndim>& stride2) {
+  ++(*coord)[ndim-1];
+  *idx1 += stride1[ndim-1];
+  *idx2 += stride2[ndim-1];
+  #pragma unroll
+  for (int i = ndim - 1; i > 0 && (*coord)[i] >= shape[i]; --i) {
+    (*coord)[i] -= shape[i];
+    ++(*coord)[i-1];
+    *idx1 = *idx1 + stride1[i-1] - shape[i] * stride1[i];
+    *idx2 = *idx2 + stride2[i-1] - shape[i] * stride2[i];
+  }
+}
+
 /*!
  * \brief Simple copy data from one blob to another
  * \param to Destination blob
@@ -355,6 +386,24 @@ struct Kernel<OP, cpu> {
     for (int i = 0; i < N; ++i) {
       OP::Map(i, args...);
     }
+#endif
+  }
+
+  template<typename ...Args>
+  inline static void LaunchEx(mshadow::Stream<cpu> *s, const int N, Args... args) {
+#ifdef _OPENMP
+    const int omp_cores = Engine::Get()->num_omp_threads_per_worker();
+    if (omp_cores <= 1) {
+      OP::Map(0, N, args...);
+    } else {
+      int length = (N + omp_cores - 1) / omp_cores;
+      #pragma omp parallel for num_threads(omp_cores)
+      for (int i = 0; i < N; i += length) {
+        OP::Map(i, i + length > N ? N - i : length, args...);
+      }
+    }
+#else
+    OP::Map(0, N, args...);
 #endif
   }
 };
@@ -368,6 +417,13 @@ __global__ void mxnet_generic_kernel(int N, Args... args) {
   }
 }
 
+template<typename OP, typename ...Args>
+__global__ void mxnet_generic_kernel_ex(int N, Args... args) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
+    OP::Map(i, 1, args...);
+  }
+}
+
 template<typename OP>
 struct Kernel<OP, gpu> {
   template<typename ...Args>
@@ -378,6 +434,15 @@ struct Kernel<OP, gpu> {
       <<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
         N, args...);
   }
+
+  template<typename ...Args>
+  inline static void LaunchEx(mshadow::Stream<gpu> *s, int N, Args... args) {
+    using namespace mshadow::cuda;
+    int ngrid = std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum);
+    mxnet_generic_kernel_ex<OP, Args...>
+      <<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
+        N, args...);
+  }
 };
 #endif  // __CUDACC__
 
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index 7aae9cc8..1aab7146 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -133,13 +133,34 @@ inline int BinaryBroadcastShapeCompact(const TShape& lshape, const TShape& rshap
   return j;
 }
 
+namespace mxnet_op {
+template<int ndim, typename DType, typename OP>
+struct binary_broadcast_kernel {
+  MSHADOW_XINLINE static void Map(int base, int length, OpReqType req,
+                                  const Shape<ndim>& lstride, const Shape<ndim>& rstride,
+                                  const Shape<ndim>& oshape, DType* lhs, DType* rhs,
+                                  DType* out, int lsize, int rsize) {
+    Shape<ndim> coord = unravel(base, oshape);
+    index_t lidx = dot(coord, lstride);
+    index_t ridx = dot(coord, rstride);
+    KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx]));
+    // starts from 1 to avoid extra inc at end of loop
+    for (int i = 1; i < length; ++i) {
+      inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
+      KERNEL_ASSIGN(out[base+i], req, OP::Map(lhs[lidx], rhs[ridx]));
+    }
+  }
+};
+
+}  // namespace mxnet_op
+
 template<typename xpu, typename OP>
 void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
                             const OpContext& ctx,
                             const std::vector<TBlob>& inputs,
                             const std::vector<OpReqType>& req,
                             const std::vector<TBlob>& outputs) {
-  using namespace broadcast;
+  using namespace mxnet_op;
   TShape new_lshape, new_rshape, new_oshape;
   int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_,
                                          &new_lshape, &new_rshape, &new_oshape);
@@ -149,8 +170,13 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(ndim, NDim, {
-        BinaryBroadcastComputeImpl<NDim, DType, OP>(s, req[0], inputs[0].reshape(new_lshape),
-          inputs[1].reshape(new_rshape), outputs[0].reshape(new_oshape));
+        Shape<NDim> oshape = new_oshape.get<NDim>();
+        Shape<NDim> lstride = calc_stride(new_lshape.get<NDim>());
+        Shape<NDim> rstride = calc_stride(new_rshape.get<NDim>());
+        Kernel<binary_broadcast_kernel<NDim, DType, OP>, xpu>::LaunchEx(
+            s, new_oshape.Size(), req[0], lstride, rstride, oshape,
+            inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), outputs[0].dptr<DType>(),
+            inputs[0].Size(), inputs[1].Size());
       });
     });
   }
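
Two notes on the design, outside the patch itself.

The `inc()` helpers are the core trick: instead of recomputing `unravel(i, oshape)` and `dot(coord, stride)` for every output element, the CPU path unravels once per chunk and then advances the coordinate and the per-operand flat indices with a cheap carry-propagation step. Because `calc_stride()` assigns stride 0 to axes of extent 1, a carry into a broadcast axis leaves that operand's index unchanged, which is exactly broadcast semantics. Below is a minimal standalone sketch of the same scheme that cross-checks the incremental index against a full `unravel` + `dot` recomputation; `std::array` stands in for `mshadow::Shape<ndim>` and `long` for `index_t`, so this is illustrative, not the patch's code verbatim.

```cpp
#include <array>
#include <cassert>
#include <cstdio>

constexpr int ndim = 3;
using Shape = std::array<long, ndim>;  // stand-in for mshadow::Shape<ndim>

// Row-major strides; broadcast axes (extent 1) get stride 0,
// mirroring calc_stride() in the patch.
Shape calc_stride(const Shape& shape) {
  Shape stride;
  long cumprod = 1;
  for (int i = ndim - 1; i >= 0; --i) {
    stride[i] = (shape[i] > 1) ? cumprod : 0;
    cumprod *= shape[i];
  }
  return stride;
}

Shape unravel(long idx, const Shape& shape) {
  Shape coord;
  for (int i = ndim - 1; i >= 0; --i) {
    coord[i] = idx % shape[i];
    idx /= shape[i];
  }
  return coord;
}

long dot(const Shape& coord, const Shape& stride) {
  long ret = 0;
  for (int i = 0; i < ndim; ++i) ret += coord[i] * stride[i];
  return ret;
}

// Same carry-propagation scheme as the single-index inc() in the patch.
void inc(Shape* coord, const Shape& shape, long* idx, const Shape& stride) {
  ++(*coord)[ndim - 1];
  *idx += stride[ndim - 1];
  for (int i = ndim - 1; i > 0 && (*coord)[i] >= shape[i]; --i) {
    (*coord)[i] -= shape[i];
    ++(*coord)[i - 1];
    *idx = *idx + stride[i - 1] - shape[i] * stride[i];
  }
}

int main() {
  const Shape oshape = {2, 3, 4};   // output shape
  const Shape lshape = {2, 1, 4};   // lhs broadcasts along axis 1
  const Shape lstride = calc_stride(lshape);

  Shape coord = unravel(0, oshape);
  long lidx = dot(coord, lstride);
  for (long i = 0; i < 2 * 3 * 4; ++i) {
    // The incrementally maintained index must match a full recomputation.
    assert(lidx == dot(unravel(i, oshape), lstride));
    inc(&coord, oshape, &lidx, lstride);
  }
  std::printf("incremental indices match unravel+dot\n");
  return 0;
}
```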
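The `LaunchEx` entry points change the kernel contract from `Map(i, ...)` (one element per call) to `Map(base, length, ...)` (one contiguous chunk per call), so per-chunk setup such as the initial `unravel` is paid once per chunk rather than once per element. On the GPU path, `mxnet_generic_kernel_ex` passes `length = 1`, so each CUDA thread still processes a single element. Here is a hedged sketch of the CPU-side chunking, written as a free function rather than the `Kernel<OP, cpu>` specialization; `plus_one` is a made-up functor, and `omp_get_max_threads()` stands in for `Engine::Get()->num_omp_threads_per_worker()`:

```cpp
#include <cstdio>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif

// A Map(base, length, ...) functor in the style LaunchEx expects:
// each call owns the contiguous range [base, base + length).
struct plus_one {
  static void Map(int base, int length, float* out, const float* in) {
    for (int i = base; i < base + length; ++i) out[i] = in[i] + 1.f;
  }
};

// Simplified CPU LaunchEx: split [0, N) into one contiguous chunk per
// thread, so per-chunk state (like the coord/idx walk) is set up once.
template <typename OP, typename... Args>
void LaunchEx(int N, Args... args) {
#ifdef _OPENMP
  const int nthreads = omp_get_max_threads();  // illustrative stand-in
  if (nthreads <= 1) {
    OP::Map(0, N, args...);
  } else {
    const int length = (N + nthreads - 1) / nthreads;
    #pragma omp parallel for num_threads(nthreads)
    for (int i = 0; i < N; i += length) {
      OP::Map(i, i + length > N ? N - i : length, args...);
    }
  }
#else
  OP::Map(0, N, args...);
#endif
}

int main() {
  std::vector<float> in(10, 1.f), out(10, 0.f);
  LaunchEx<plus_one>(static_cast<int>(in.size()), out.data(), in.data());
  for (float v : out) std::printf("%.0f ", v);  // prints ten 2s
  std::printf("\n");
  return 0;
}
```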