@@ -156,6 +156,37 @@ class GraphRuntime : public ModuleNode {
156156 // control deps
157157 std::vector<uint32_t > control_deps;
158158 // JSON Loader
159+ void LoadAttrs (dmlc::JSONReader *reader, TVMOpParam* param) {
160+ int bitmask = 0 ;
161+ std::string key, value;
162+ reader->BeginObject ();
163+ while (reader->NextObjectItem (&key)) {
164+ if (key == " func_name" ) {
165+ reader->Read (&value);
166+ param->func_name = value;
167+ bitmask |= 1 ;
168+ } else if (key == " num_inputs" ) {
169+ reader->Read (&value);
170+ std::istringstream is (value);
171+ is >> param->num_inputs ;
172+ bitmask |= 2 ;
173+ } else if (key == " num_outputs" ) {
174+ reader->Read (&value);
175+ std::istringstream is (value);
176+ is >> param->num_outputs ;
177+ bitmask |= 4 ;
178+ } else if (key == " flatten_data" ) {
179+ reader->Read (&value);
180+ std::istringstream is (value);
181+ is >> param->flatten_data ;
182+ bitmask |= 8 ;
183+ } else {
184+ reader->Read (&value);
185+ }
186+ }
187+ CHECK_EQ (bitmask, 1 |2 |4 |8 ) << " invalid format" ;
188+ }
189+ // JSON Loader
159190 void Load (dmlc::JSONReader *reader) {
160191 reader->BeginObject ();
161192 std::unordered_map<std::string, std::string> dict;
@@ -172,8 +203,7 @@ class GraphRuntime : public ModuleNode {
172203 reader->Read (&inputs);
173204 bitmask |= 4 ;
174205 } else if (key == " attr" || key == " attrs" ) {
175- reader->Read (&dict);
176- param.Init (dict);
206+ this ->LoadAttrs (reader, ¶m);
177207 } else if (key == " control_deps" ) {
178208 reader->Read (&control_deps);
179209 } else {
@@ -263,6 +293,8 @@ class GraphRuntime : public ModuleNode {
263293 } else if (key == " attrs" ) {
264294 reader->Read (&attrs_);
265295 bitmask |= 16 ;
296+ } else {
297+ LOG (FATAL) << " key " << key << " is not supported" ;
266298 }
267299 }
268300 CHECK_EQ (bitmask, 1 |2 |4 |8 |16 ) << " invalid format" ;
@@ -320,7 +352,6 @@ class GraphRuntime : public ModuleNode {
320352 std::vector<std::function<void ()> > op_execs_;
321353};
322354
323- DMLC_REGISTER_PARAMETER (TVMOpParam);
324355
325356bool GraphRuntime::LoadDLTensor (dmlc::Stream* strm, DLTensor* tensor) {
326357 uint64_t header, reserved;
0 commit comments