Skip to content

Commit

Permalink
Softmax Activation (apache#2022)
Browse files · Browse the repository at this point in the history
* Softmax Activation
  • Loading branch information
antinucleon committed May 3, 2016
1 parent e502e2d commit 896fba9
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 15 deletions.
35 changes: 24 additions & 11 deletions src/operator/softmax_activation-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ enum SoftmaxActivationOpType {kInstance, kChannel};

struct SoftmaxActivationParam : public dmlc::Parameter<SoftmaxActivationParam> {
// use int for enumeration
int type;
int mode;
DMLC_DECLARE_PARAMETER(SoftmaxActivationParam) {
DMLC_DECLARE_FIELD(type)
DMLC_DECLARE_FIELD(mode)
.add_enum("instance", softmax_activation::kInstance)
.add_enum("channel", softmax_activation::kChannel)
.set_default(softmax_activation::kInstance)
Expand Down Expand Up @@ -63,10 +63,22 @@ class SoftmaxActivationOp : public Operator {
using namespace mshadow::expr;
CHECK_EQ(in_data.size(), 1);
CHECK_EQ(out_data.size(), 1);
// Stream<xpu> *s = ctx.get_stream<xpu>();
// Tensor<xpu, 2> data = in_data[softmax_activation::kData].FlatTo2D<xpu, real_t>(s);
// Tensor<xpu, 2> out = out_data[softmax_activation::kOut].FlatTo2D<xpu, real_t>(s);
LOG(FATAL) << "non-cuDNN version not implemented yet.";
Stream<xpu> *s = ctx.get_stream<xpu>();
if (param_.mode == softmax_activation::kInstance) {
Tensor<xpu, 2> data = in_data[softmax_activation::kData].FlatTo2D<xpu, real_t>(s);
Tensor<xpu, 2> out = out_data[softmax_activation::kOut].FlatTo2D<xpu, real_t>(s);
Softmax(out, data);
} else {
CHECK_EQ(in_data[softmax_activation::kData].ndim(), 4);
TShape src_shape = in_data[softmax_activation::kData].shape_;
Shape<3> dst_shape = Shape3(src_shape[0], src_shape[1],
src_shape[2] * src_shape[3]);
Tensor<xpu, 3> data =
in_data[softmax_activation::kData].get_with_shape<xpu, 3, real_t>(dst_shape, s);
Tensor<xpu, 3> out =
out_data[softmax_activation::kOut].get_with_shape<xpu, 3, real_t>(dst_shape, s);
Softmax(out, data);
}
}

virtual void Backward(const OpContext &ctx,
Expand All @@ -81,11 +93,12 @@ class SoftmaxActivationOp : public Operator {
CHECK_EQ(out_grad.size(), 1);
CHECK(in_data.size() == 1 && in_grad.size() == 1);
CHECK_EQ(req.size(), 1);
// Stream<xpu> *s = ctx.get_stream<xpu>();
// Tensor<xpu, 2> m_out_grad = out_grad[softmax_activation::kOut].FlatTo2D<xpu, real_t>(s);
// Tensor<xpu, 2> m_out_data = out_data[softmax_activation::kOut].FlatTo2D<xpu, real_t>(s);
// Tensor<xpu, 2> m_in_grad = in_grad[softmax_activation::kData].FlatTo2D<xpu, real_t>(s);
LOG(FATAL) << "non-cuDNN version not implemented yet.";
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 2> m_out_grad = out_grad[softmax_activation::kOut].FlatTo2D<xpu, real_t>(s);
Tensor<xpu, 2> m_out_data = out_data[softmax_activation::kOut].FlatTo2D<xpu, real_t>(s);
Tensor<xpu, 2> m_in_grad = in_grad[softmax_activation::kData].FlatTo2D<xpu, real_t>(s);
Assign(m_in_grad, req[softmax_activation::kData],
m_out_grad * m_out_data * (1.0f - m_out_data));
}

private:
Expand Down
2 changes: 0 additions & 2 deletions src/operator/softmax_activation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ namespace mxnet {
namespace op {
template<>
Operator *CreateOp<cpu>(SoftmaxActivationParam param) {
  // The generic SoftmaxActivationOp now provides a working CPU
  // implementation, so the old LOG(FATAL) guard ("only supported on GPU
  // with cuDNN") is removed: with it in place the return below was
  // unreachable (LOG(FATAL) aborts the process).
  // Ownership of the returned Operator* passes to the caller, per the
  // CreateOp factory convention used throughout this file.
  return new SoftmaxActivationOp<cpu>(param);
}

Expand Down
2 changes: 0 additions & 2 deletions src/operator/softmax_activation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ Operator *CreateOp<gpu>(SoftmaxActivationParam param) {
#if MXNET_USE_CUDNN == 1
return new CuDNNSoftmaxActivationOp(param);
#else
LOG(FATAL) << "Softmax Activation for internal layers is only supported "
"on GPU with cuDNN. Use SoftmaxOutput for loss layer.";
return new SoftmaxActivationOp<gpu>(param);
#endif // MXNET_USE_CUDNN
}
Expand Down

0 comments on commit 896fba9

Please sign in to comment.