diff --git a/src/operator/softmax_output-inl.h b/src/operator/softmax_output-inl.h
index 3f9585ca..14f1f795 100644
--- a/src/operator/softmax_output-inl.h
+++ b/src/operator/softmax_output-inl.h
@@ -23,6 +23,8 @@ namespace op {
 namespace softmaxout_enum {
 enum SoftmaxOutputOpInputs {kData, kLabel};
 enum SoftmaxOutputOpOutputs {kOut};
+enum SoftmaxOutputNormType {kNull, kBatch, kValid};
+enum SoftmaxOutputOpResource {kTempSpace};
 }  // namespace softmaxout_enum

 struct SoftmaxOutputParam : public dmlc::Parameter<SoftmaxOutputParam> {
@@ -30,6 +32,7 @@ struct SoftmaxOutputParam : public dmlc::Parameter<SoftmaxOutputParam> {
   float ignore_label;
   bool multi_output;
   bool use_ignore;
+  int normalization;
   DMLC_DECLARE_PARAMETER(SoftmaxOutputParam) {
     DMLC_DECLARE_FIELD(grad_scale).set_default(1.0f)
     .describe("Scale the gradient by a float factor");
@@ -43,6 +46,14 @@ struct SoftmaxOutputParam : public dmlc::Parameter<SoftmaxOutputParam> {
     DMLC_DECLARE_FIELD(use_ignore).set_default(false)
     .describe("If set to true, the ignore_label value will not contribute "
               "to the backward gradient");
+    DMLC_DECLARE_FIELD(normalization)
+    .add_enum("null", softmaxout_enum::kNull)
+    .add_enum("batch", softmaxout_enum::kBatch)
+    .add_enum("valid", softmaxout_enum::kValid)
+    .set_default(softmaxout_enum::kNull)
+    .describe("If set to null, op will do nothing on output gradient."
+              "If set to batch, op will normalize gradient by divide batch size"
+              "If set to valid, op will normalize gradient by divide sample not ignored");
   };
 };

@@ -91,6 +102,7 @@ class SoftmaxOutputOp : public Operator {
     CHECK_GE(in_grad.size(), 1);
     CHECK_GE(req.size(), 1);
     Stream<xpu> *s = ctx.get_stream<xpu>();
+
     if (param_.multi_output) {
       int n = out_data[softmaxout_enum::kOut].size(0);
       int k = out_data[softmaxout_enum::kOut].size(1);
@@ -100,24 +112,65 @@ class SoftmaxOutputOp : public Operator {
           out_data[softmaxout_enum::kOut].get_with_shape<xpu, 3, DType>(s3, s);
       Tensor<xpu, 3, DType> grad =
           in_grad[softmaxout_enum::kData].get_with_shape<xpu, 3, DType>(s3, s);
+
+      index_t valid_cnt = label.shape_.Size();
       if (param_.use_ignore) {
         SoftmaxGrad(grad, out, label, static_cast<DType>(param_.ignore_label));
       } else {
         SoftmaxGrad(grad, out, label);
       }
-      grad *= DType(param_.grad_scale/s3[2]);
+      if (param_.normalization == softmaxout_enum::kBatch) {
+        valid_cnt = label.size(0);
+      } else if (param_.normalization == softmaxout_enum::kValid) {
+        int i_label = static_cast<int>(param_.ignore_label);
+        Tensor<cpu, 2, DType> workspace =
+          ctx.requested[softmaxout_enum::kTempSpace].get_host_space_typed<2, DType>(
+            label.shape_);
+        Copy(workspace, label, label.stream_);
+        for (index_t i = 0; i < workspace.size(0); ++i) {
+          for (index_t j = 0; j < workspace.size(1); ++j) {
+            if (static_cast<int>(workspace[i][j]) == i_label) {
+              valid_cnt--;
+            }
+          }
+        }
+        valid_cnt = valid_cnt == 0 ? 1 : valid_cnt;
+      } else {
+        valid_cnt = 1;
+      }
+      grad *= DType(param_.grad_scale /
+                    (param_.normalization == softmaxout_enum::kValid ? 1 : s3[2]) /
+                    valid_cnt);
     } else {
       const TShape& label_shape = in_data[softmaxout_enum::kLabel].shape_;
       Tensor<xpu, 1, DType> label = in_data[softmaxout_enum::kLabel].get_with_shape<xpu, 1, DType>(
           Shape1(label_shape.ProdShape(0, label_shape.ndim())), s);
       Tensor<xpu, 2, DType> out = out_data[softmaxout_enum::kOut].FlatTo2D<xpu, DType>(s);
       Tensor<xpu, 2, DType> grad = in_grad[softmaxout_enum::kData].FlatTo2D<xpu, DType>(s);
+      index_t valid_cnt = label.shape_.Size();
       if (param_.use_ignore) {
         SoftmaxGrad(grad, out, label, static_cast<DType>(param_.ignore_label));
       } else {
         SoftmaxGrad(grad, out, label);
       }
-      grad *= DType(param_.grad_scale);
+      if (param_.normalization == softmaxout_enum::kBatch) {
+        valid_cnt = label.size(0);
+      } else if (param_.normalization == softmaxout_enum::kValid) {
+        int i_label = static_cast<int>(param_.ignore_label);
+        Tensor<cpu, 1, DType> workspace =
+          ctx.requested[softmaxout_enum::kTempSpace].get_host_space_typed<1, DType>(
+            label.shape_);
+        Copy(workspace, label, label.stream_);
+        for (index_t i = 0; i < label.size(0); ++i) {
+          if (static_cast<int>(workspace[i]) == i_label) {
+            valid_cnt--;
+          }
+        }
+        valid_cnt = valid_cnt == 0 ? 1 : valid_cnt;
+      } else {
+        valid_cnt = 1;
+      }
+      grad *= DType(param_.grad_scale / valid_cnt);
     }
   }

@@ -216,6 +269,11 @@ class SoftmaxOutputProp : public OperatorProperty {
     return {{in_data[softmaxout_enum::kData], out_data[softmaxout_enum::kOut]}};
   }

+  std::vector<ResourceRequest> BackwardResource(
+      const std::vector<TShape> &in_shape) const override {
+    return {ResourceRequest::kTempSpace};
+  }
+
   Operator* CreateOperator(Context ctx) const override {
     LOG(FATAL) << "Not Implemented.";
     return NULL;
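
For review purposes, a minimal standalone sketch of the effective gradient multiplier the multi-output branch now computes; the NormType enum and GradMultiplier helper below are illustrative only and are not part of the patch:

// Hypothetical, review-only sketch of the scale factor this patch applies to
// the softmax gradient in the multi-output branch (names are illustrative).
#include <cstddef>

enum NormType { kNull, kBatch, kValid };

// grad_scale : the user-supplied grad_scale parameter
// batch_size : label.size(0)
// num_valid  : count of labels not equal to ignore_label (already clamped to >= 1)
// spatial    : s3[2], the per-sample spatial extent
inline double GradMultiplier(NormType norm, double grad_scale,
                             std::size_t batch_size, std::size_t num_valid,
                             std::size_t spatial) {
  switch (norm) {
    case kBatch:  // normalize by the spatial extent and the batch size
      return grad_scale / spatial / batch_size;
    case kValid:  // normalize only by the number of non-ignored labels
      return grad_scale / num_valid;
    default:      // kNull keeps the previous behaviour: divide by the spatial extent
      return grad_scale / spatial;
  }
}

In the single-output branch the spatial term is absent, so the multiplier reduces to grad_scale / valid_cnt.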