Skip to content

Commit 1b2e5ce

Browse files
authored
[RUNTIME] Remove parameter def from runtime (#486)
1 parent b18143e commit 1b2e5ce

File tree

3 files changed

+36
-13
lines changed

3 files changed

+36
-13
lines changed

src/runtime/graph/graph_runtime.cc

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,37 @@ class GraphRuntime : public ModuleNode {
156156
// control deps
157157
std::vector<uint32_t> control_deps;
158158
// JSON Loader
159+
void LoadAttrs(dmlc::JSONReader *reader, TVMOpParam* param) {
160+
int bitmask = 0;
161+
std::string key, value;
162+
reader->BeginObject();
163+
while (reader->NextObjectItem(&key)) {
164+
if (key == "func_name") {
165+
reader->Read(&value);
166+
param->func_name = value;
167+
bitmask |= 1;
168+
} else if (key == "num_inputs") {
169+
reader->Read(&value);
170+
std::istringstream is(value);
171+
is >> param->num_inputs;
172+
bitmask |= 2;
173+
} else if (key == "num_outputs") {
174+
reader->Read(&value);
175+
std::istringstream is(value);
176+
is >> param->num_outputs;
177+
bitmask |= 4;
178+
} else if (key == "flatten_data") {
179+
reader->Read(&value);
180+
std::istringstream is(value);
181+
is >> param->flatten_data;
182+
bitmask |= 8;
183+
} else {
184+
reader->Read(&value);
185+
}
186+
}
187+
CHECK_EQ(bitmask, 1|2|4|8) << "invalid format";
188+
}
189+
// JSON Loader
159190
void Load(dmlc::JSONReader *reader) {
160191
reader->BeginObject();
161192
std::unordered_map<std::string, std::string> dict;
@@ -172,8 +203,7 @@ class GraphRuntime : public ModuleNode {
172203
reader->Read(&inputs);
173204
bitmask |= 4;
174205
} else if (key == "attr" || key == "attrs") {
175-
reader->Read(&dict);
176-
param.Init(dict);
206+
this->LoadAttrs(reader, &param);
177207
} else if (key == "control_deps") {
178208
reader->Read(&control_deps);
179209
} else {
@@ -263,6 +293,8 @@ class GraphRuntime : public ModuleNode {
263293
} else if (key == "attrs") {
264294
reader->Read(&attrs_);
265295
bitmask |= 16;
296+
} else {
297+
LOG(FATAL) << "key " << key << " is not supported";
266298
}
267299
}
268300
CHECK_EQ(bitmask, 1|2|4|8|16) << "invalid format";
@@ -320,7 +352,6 @@ class GraphRuntime : public ModuleNode {
320352
std::vector<std::function<void()> > op_execs_;
321353
};
322354

323-
DMLC_REGISTER_PARAMETER(TVMOpParam);
324355

325356
bool GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
326357
uint64_t header, reserved;

src/runtime/graph/graph_runtime.h

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#ifndef TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_
99
#define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_
1010

11-
#include <dmlc/parameter.h>
1211
#include <string>
1312

1413
namespace tvm {
@@ -20,18 +19,11 @@ constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
2019
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
2120

2221
/*! \brief operator attributes about tvm op */
23-
struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
22+
struct TVMOpParam {
2423
std::string func_name;
2524
uint32_t num_inputs;
2625
uint32_t num_outputs;
2726
uint32_t flatten_data;
28-
29-
DMLC_DECLARE_PARAMETER(TVMOpParam) {
30-
DMLC_DECLARE_FIELD(func_name);
31-
DMLC_DECLARE_FIELD(num_inputs).set_default(1);
32-
DMLC_DECLARE_FIELD(num_outputs).set_default(1);
33-
DMLC_DECLARE_FIELD(flatten_data).set_default(0);
34-
}
3527
};
3628

3729
} // namespace runtime

0 commit comments

Comments
 (0)