Add TernaryOp that takes three inputs: lhs, mhs, rhs.
Also, add a ternary operator `fill_element_0index`, which fills, in each line of `lhs`, the element selected by the index in `rhs` with the corresponding value in `mhs`.
sxjscience committed Jan 6, 2016
1 parent feafd75 commit 14b6dd2
Showing 3 changed files with 104 additions and 0 deletions.
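As a quick illustration of the semantics described in the commit message, here is a minimal reference sketch in plain C++ (not part of the commit): std::vector stands in for NDArray, and FillElement0IndexReference is an illustrative name.

#include <cassert>
#include <cstddef>
#include <vector>

// Reference semantics of fill_element_0index: the result is lhs with, in each
// line i, the element at column rhs[i] (0-based) replaced by mhs[i]. The real
// operator evaluates this through mshadow's mat_fill_row_element on CPU or GPU.
std::vector<std::vector<float>> FillElement0IndexReference(
    const std::vector<std::vector<float>> &lhs,
    const std::vector<float> &mhs,
    const std::vector<float> &rhs) {
  assert(lhs.size() == mhs.size() && mhs.size() == rhs.size());
  std::vector<std::vector<float>> out = lhs;
  for (std::size_t i = 0; i < out.size(); ++i) {
    const std::size_t col = static_cast<std::size_t>(rhs[i]);  // indices are passed as real_t
    assert(col < out[i].size());
    out[i][col] = mhs[i];
  }
  return out;
}

For example, with lhs = [[1, 2, 3], [4, 5, 6]], mhs = [9, 8] and rhs = [2, 0], the result is [[1, 2, 9], [8, 5, 6]].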
70 changes: 70 additions & 0 deletions src/ndarray/ndarray.cc
@@ -17,6 +17,70 @@ DMLC_REGISTRY_ENABLE(::mxnet::NDArrayFunctionReg);
} // namespace dmlc

namespace mxnet {

/*!
 * \brief run a ternary operation
 * \param lhs left operand
 * \param mhs middle operand
 * \param rhs right operand
 * \param out the output ndarray
 */
template<typename OP>
void TernaryOp(const NDArray &lhs,
               const NDArray &mhs,
               const NDArray &rhs,
               NDArray *out) {
  // no check if all of them are on cpu
  if (lhs.ctx().dev_mask() != cpu::kDevMask ||
      mhs.ctx().dev_mask() != cpu::kDevMask ||
      rhs.ctx().dev_mask() != cpu::kDevMask) {
    CHECK((lhs.ctx() == mhs.ctx()) && (mhs.ctx() == rhs.ctx()))
        << "operands context mismatch";
  }
  // if out is none, allocate space
  if (out->is_none()) {
    *out = NDArray(OP::GetShape(lhs.shape(), mhs.shape(), rhs.shape()), lhs.ctx(), true);
  } else {
    // no check if both of them are on cpu
    if (lhs.ctx().dev_mask() != cpu::kDevMask ||
        out->ctx().dev_mask() != cpu::kDevMask) {
      CHECK(out->ctx() == lhs.ctx()) << "target context mismatch";
    }
    CHECK(out->shape() == OP::GetShape(lhs.shape(), mhs.shape(), rhs.shape()))
        << "target shape mismatch";
  }
  // important: callback must always capture by value
  NDArray ret = *out;
  // get the const variables
  std::vector<Engine::VarHandle> const_vars;
  if (lhs.var() != ret.var()) const_vars.push_back(lhs.var());
  if (mhs.var() != ret.var()) const_vars.push_back(mhs.var());
  if (rhs.var() != ret.var()) const_vars.push_back(rhs.var());

  // redirect everything to mshadow operations
  switch (lhs.ctx().dev_mask()) {
    case cpu::kDevMask: {
      Engine::Get()->PushSync([lhs, mhs, rhs, ret](RunContext ctx) {
          ret.CheckAndAlloc();
          TBlob tmp = ret.data();
          ndarray::Eval<cpu, OP>(lhs.data(), mhs.data(), rhs.data(), &tmp, ctx);
        }, lhs.ctx(), const_vars, { ret.var() });
      break;
    }
#if MXNET_USE_CUDA
    case gpu::kDevMask: {
      Engine::Get()->PushSync([lhs, mhs, rhs, ret](RunContext ctx) {
          ret.CheckAndAlloc();
          TBlob tmp = ret.data();
          ndarray::Eval<gpu, OP>(lhs.data(), mhs.data(), rhs.data(), &tmp, ctx);
          // Wait GPU kernel to complete
          ctx.get_stream<gpu>()->Wait();
        }, lhs.ctx(), const_vars, { ret.var() });
      break;
    }
#endif
    default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
  }
}

/*!
 * \brief run a binary operation
 * \param lhs left operand
@@ -617,6 +681,12 @@ MXNET_REGISTER_NDARRAY_FUN(choose_element_0index)
          " in lhs according to index indicated by rhs."
          " This function assume rhs uses 0-based index.");

MXNET_REGISTER_NDARRAY_FUN(fill_element_0index)
.set_function(TernaryOp<ndarray::MatFillRowElem>)
.describe("Fill one element of each line (row for python, column for R/Julia)"
          " in lhs according to the index indicated by rhs and the value indicated by mhs."
          " This function assumes rhs uses 0-based indexing.");

// register API function
// those with underscore will be registered at NDArray
MXNET_REGISTER_NDARRAY_FUN(_plus_scalar).set_function(ScalarOp<ndarray::Plus, false>);
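The fill_element_0index registration in this file is what exposes the new TernaryOp path to the frontends. As a rough sketch of the calling convention (a hypothetical helper, not part of the commit; it would have to live in src/ndarray/ndarray.cc, where TernaryOp is visible):

// Hypothetical helper placed after TernaryOp in src/ndarray/ndarray.cc.
// lhs must be 2-D and mhs/rhs 1-D with lhs.shape()[0] entries, as enforced
// by ndarray::MatFillRowElem::GetShape.
NDArray FillElement0Index(const NDArray &lhs,
                          const NDArray &mhs,
                          const NDArray &rhs) {
  NDArray out;  // empty: TernaryOp allocates it using MatFillRowElem::GetShape
  TernaryOp<ndarray::MatFillRowElem>(lhs, mhs, rhs, &out);
  return out;   // the computation has only been queued on the engine
}

This mirrors what the registered fill_element_0index function does when a frontend invokes it; readers of the result synchronize through the engine (e.g. via WaitToRead) as with any other NDArray operation.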
21 changes: 21 additions & 0 deletions src/ndarray/ndarray_function-inl.h
@@ -10,6 +10,15 @@
#include "./ndarray_function.h"
// this file will be included twice by CPU and GPU
// macro to help specialize evaluation function

#ifndef DECL_TERNARY
#define DECL_TERNARY(XPU, OP, FUN)                                          \
  template<>                                                                \
  void Eval<XPU, OP>(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs,  \
                     TBlob *ret, RunContext ctx) {                          \
    FUN<XPU, OP>(lhs, mhs, rhs, ret, ctx);                                  \
  }
#endif

#ifndef DECL_BINARY
#define DECL_BINARY(XPU, OP, FUN)                                           \
  template<>                                                                \
@@ -75,6 +84,17 @@ inline void EvalMatChooseRowElem_(const TBlob &lhs, const TBlob &rhs,
                               rhs.get<xpu, 1, real_t>(s));
}

template<typename xpu, typename OP>
inline void EvalMatFillRowElem_(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs,
                                TBlob *ret, RunContext ctx) {
  using namespace mshadow::expr;
  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
  ret->get<xpu, 2, real_t>(s)
      = mat_fill_row_element(lhs.get<xpu, 2, real_t>(s),
                             mhs.get<xpu, 1, real_t>(s),
                             rhs.get<xpu, 1, real_t>(s));
}

template<typename xpu, typename OP, bool reverse>
inline void EvalScalar_(const TBlob &lhs, const real_t &rhs,
TBlob *ret, RunContext ctx) {
@@ -181,6 +201,7 @@ void ElementwiseSum<DEVICE>(const std::vector<TBlob> source,

// declarations
DECL_BINARY(DEVICE, MatChooseRowElem, EvalMatChooseRowElem_)
DECL_TERNARY(DEVICE, MatFillRowElem, EvalMatFillRowElem_)
DECL_BINARY(DEVICE, Dot, EvalDot_)
DECL_BINARY(DEVICE, OneHotEncode, EvalOneHot_)
DECL_BINARY(DEVICE, Plus, EvalBinary_)
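Because this header is compiled once for the CPU and once for the GPU (with DEVICE resolving to cpu or gpu), the single DECL_TERNARY line above generates the device-specific specialization of the ternary Eval. Roughly, the CPU translation unit ends up with something like the following (a sketch of the macro expansion, not code from the commit):

// Approximate expansion of DECL_TERNARY(DEVICE, MatFillRowElem, EvalMatFillRowElem_)
// when DEVICE is cpu; the GPU build produces the matching Eval<gpu, MatFillRowElem>.
template<>
void Eval<cpu, MatFillRowElem>(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs,
                               TBlob *ret, RunContext ctx) {
  EvalMatFillRowElem_<cpu, MatFillRowElem>(lhs, mhs, rhs, ret, ctx);
}

This is the overload that TernaryOp reaches through ndarray::Eval<cpu, OP> (or its gpu counterpart) in ndarray.cc.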
13 changes: 13 additions & 0 deletions src/ndarray/ndarray_function.h
@@ -92,6 +92,16 @@ struct MatChooseRowElem {
  }
};

struct MatFillRowElem {
  inline static TShape GetShape(const TShape &lshape, const TShape &mshape, const TShape &rshape) {
    CHECK(lshape.ndim() == 2 && mshape.ndim() == 1 && rshape.ndim() == 1)
        << "fill_row_element only supports 2D matrix, 1D value and 1D index";
    CHECK((lshape[0] == mshape[0]) && (mshape[0] == rshape[0]))
        << "fill_row_element index vector, value vector and matrix shape mismatch";
    return lshape;
  }
};

// type holder for random number generators
struct UniformDistribution {};

@@ -101,6 +111,9 @@ template<typename Device>
void EvalClip(const TBlob &src, const real_t &a_min, const real_t &a_max,
              TBlob *ret, RunContext ctx);

template<typename Device, typename OP>
void Eval(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs, TBlob *ret, RunContext ctx);

template<typename Device, typename OP>
void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx);

