Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-798] Fix the dtype cast from non float32 in Gradient computation #12290

Merged
merged 14 commits into from
Sep 14, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Rename variables
  • Loading branch information
apeforest committed Sep 12, 2018
commit 9c09a2e21c4372bedbd71db2380bfd07bea3ee71
4 changes: 2 additions & 2 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ nnvm::NodeEntry AggregateGradient(std::vector<nnvm::NodeEntry>&& v) {

if (v.empty()) {
nnvm::NodePtr ng = nnvm::Node::Create();
ng->attrs.op = Op::Get("_zeros_no_default");
ng->attrs.name = "zeros_no_default";
ng->attrs.op = Op::Get("_zeros_without_dtype");
ng->attrs.name = "zeros_without_dtype";
ng->attrs.op->attr_parser(&(ng->attrs));
return nnvm::NodeEntry{ng, 0, 0};
}
Expand Down
17 changes: 9 additions & 8 deletions src/operator/tensor/init_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,22 @@ namespace op {

DMLC_REGISTER_PARAMETER(InitOpParam);
DMLC_REGISTER_PARAMETER(InitOpWithScalarParam);
DMLC_REGISTER_PARAMETER(InitOpNoDefaultParam);
DMLC_REGISTER_PARAMETER(InitOpWithoutDTypeParam);
DMLC_REGISTER_PARAMETER(RangeParam);
DMLC_REGISTER_PARAMETER(EyeParam);

NNVM_REGISTER_OP(_zeros_no_default)
.describe("fill target with zeros with no default type")
NNVM_REGISTER_OP(_zeros_without_dtype)
.describe("fill target with zeros without default dtype")
.set_num_inputs(0)
.set_num_outputs(1)
.set_attr_parser(ParamParser<InitOpNoDefaultParam>)
.set_attr<nnvm::FInferShape>("FInferShape", InitShape<InitOpNoDefaultParam>)
.set_attr<nnvm::FInferType>("FInferType", InitType<InitOpNoDefaultParam>)
.set_attr<FInferStorageType>("FInferStorageType", InitStorageType<InitOpNoDefaultParam, true, true>)
.set_attr_parser(ParamParser<InitOpWithoutDTypeParam>)
.set_attr<nnvm::FInferShape>("FInferShape", InitShape<InitOpWithoutDTypeParam>)
.set_attr<nnvm::FInferType>("FInferType", InitType<InitOpWithoutDTypeParam>)
.set_attr<FInferStorageType>("FInferStorageType",
InitStorageType<InitOpWithoutDTypeParam, true, true>)
.set_attr<FCompute>("FCompute<cpu>", FillCompute<cpu, 0>)
.set_attr<FComputeEx>("FComputeEx<cpu>", FillComputeZerosEx<cpu>)
.add_arguments(InitOpNoDefaultParam::__FIELDS__());
.add_arguments(InitOpWithoutDTypeParam::__FIELDS__());

NNVM_REGISTER_OP(_zeros)
.describe("fill target with zeros")
Expand Down
4 changes: 2 additions & 2 deletions src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ struct InitOpParam : public dmlc::Parameter<InitOpParam> {
}
};

struct InitOpNoDefaultParam : public dmlc::Parameter<InitOpNoDefaultParam> {
struct InitOpWithoutDTypeParam : public dmlc::Parameter<InitOpWithoutDTypeParam> {
TShape shape;
std::string ctx;
int dtype;
DMLC_DECLARE_PARAMETER(InitOpNoDefaultParam) {
DMLC_DECLARE_PARAMETER(InitOpWithoutDTypeParam) {
DMLC_DECLARE_FIELD(shape)
.set_default(TShape())
.describe("The shape of the output");
Expand Down