77#include < nnvm/node.h>
88#include < nnvm/op_attr_types.h>
99#include < nnvm/top/nn.h>
10- #include " ./op_common.h"
11- #include " ./elemwise_op_common.h"
10+ #include " .. /op_common.h"
11+ #include " .. /elemwise_op_common.h"
1212
1313namespace nnvm {
1414namespace top {
@@ -126,6 +126,25 @@ NNVM_REGISTER_OP(dropout)
126126// batchnorm
127127DMLC_REGISTER_PARAMETER (BatchNormParam);
128128
129+ inline bool BatchNormInferShape (const nnvm::NodeAttrs& attrs,
130+ std::vector<TShape> *in_shape,
131+ std::vector<TShape> *out_shape) {
132+ CHECK_EQ (in_shape->size (), 5U )
133+ << " Input:[data, gamma, beta, moving_mean, moving_var]" ;
134+ CHECK_EQ (out_shape->size (), 3U );
135+ const TShape &dshape = in_shape->at (0 );
136+ if (dshape.ndim () == 0 ) return false ;
137+ TShape bshape ({dshape[1 ]});
138+ NNVM_ASSIGN_INPUT_SHAPE (attrs, *in_shape, 1 , bshape);
139+ NNVM_ASSIGN_INPUT_SHAPE (attrs, *in_shape, 2 , bshape);
140+ NNVM_ASSIGN_INPUT_SHAPE (attrs, *in_shape, 3 , bshape);
141+ NNVM_ASSIGN_INPUT_SHAPE (attrs, *in_shape, 4 , bshape);
142+ NNVM_ASSIGN_OUTPUT_SHAPE (attrs, *out_shape, 0 , dshape);
143+ NNVM_ASSIGN_OUTPUT_SHAPE (attrs, *out_shape, 1 , bshape);
144+ NNVM_ASSIGN_OUTPUT_SHAPE (attrs, *out_shape, 2 , bshape);
145+ return true ;
146+ }
147+
129148NNVM_REGISTER_OP (batch_norm)
130149.describe(R"( Batch normalization layer (Ioffe and Szegedy, 2014).
131150Normalizes the input at each batch, i.e. applies a transformation
@@ -167,6 +186,8 @@ axis to be the last item in the input shape.
167186.set_num_inputs(5 )
168187.set_num_outputs(3 )
169188.set_attr_parser(ParamParser<BatchNormParam>)
189+ .set_attr<FInferShape>(" FInferShape" , BatchNormInferShape)
190+ .set_attr<FInferType>(" FInferType" , ElemwiseType<5 , 3 >)
170191.set_attr<FListInputNames>(" FListInputNames" , [](const NodeAttrs& attrs) {
171192 return std::vector<std::string>{" data" , " gamma" , " beta" , " moving_mean" , " moving_var" };
172193 })
@@ -198,8 +219,6 @@ NNVM_REGISTER_OP(softmax)
198219.set_support_level(1 );
199220
200221// log_softmax
201- DMLC_REGISTER_PARAMETER (LogSoftmaxParam);
202-
203222NNVM_REGISTER_OP (log_softmax)
204223.describe(R"code( Computes softmax.
205224
@@ -208,7 +227,23 @@ NNVM_REGISTER_OP(log_softmax)
208227)code" NNVM_ADD_FILELINE)
209228.set_num_inputs(1 )
210229.set_num_outputs(1 )
211- .set_attr_parser(ParamParser<LogSoftmaxParam>)
230+ .set_attr_parser(ParamParser<SoftmaxParam>)
231+ .set_attr<FInferShape>(" FInferShape" , ElemwiseShape<1 , 1 >)
232+ .set_attr<FInferType>(" FInferType" , ElemwiseType<1 , 1 >)
233+ .set_support_level(1 );
234+
235+ // leaky_rlu
236+ DMLC_REGISTER_PARAMETER (LeakyReLUParam);
237+
238+ NNVM_REGISTER_OP (leaky_relu)
239+ .describe(R"code( Leaky version of a Rectified Linear Unit.
240+
241+ `y = x > 0 ? x : alpha * x`
242+
243+ )code" NNVM_ADD_FILELINE)
244+ .set_num_inputs(1 )
245+ .set_num_outputs(1 )
246+ .set_attr_parser(ParamParser<LeakyReLUParam>)
212247.set_attr<FInferShape>(" FInferShape" , ElemwiseShape<1 , 1 >)
213248.set_attr<FInferType>(" FInferType" , ElemwiseType<1 , 1 >)
214249.set_support_level(1 );
0 commit comments