Skip to content

Commit 2e24e76

Browse files
committed
address review comments.
1 parent 9712c99 commit 2e24e76

File tree

4 files changed

+131
-107
lines changed

4 files changed

+131
-107
lines changed

python/tvm/contrib/pipeline_executor.py

Lines changed: 73 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def pipeline_executor_enabled():
3232
return tvm._ffi.get_global_func("tvm.pipeline_executor.create", allow_missing=True) is not None
3333

3434

35-
def build_pipeline(mod_n_configs):
35+
def build(pipe_configs):
3636
"""build module list that can use for pipeline execution.
3737
3838
Parameters
@@ -53,9 +53,11 @@ def build_pipeline(mod_n_configs):
5353
pipeline configuration
5454
"""
5555
mods = {}
56+
mod_n_configs = pipe_configs.get_config()
5657
config_len = len(mod_n_configs)
5758
string_config = [{} for _ in range(config_len)]
58-
for _, (ir_mod, mod_config) in enumerate(mod_n_configs.items()):
59+
#for _, (ir_mod, mod_config) in enumerate(mod_n_configs.items()):
60+
for ir_mod, mod_config in mod_n_configs.items():
5961
# init lib_name and json_name params with empty
6062
lib_name = ""
6163
json_name = ""
@@ -133,15 +135,18 @@ def __init__(self, pipeline_mods, pipeline_config):
133135
self.pipeline_mods = pipeline_mods
134136
self.mod_config = pipeline_config
135137
mods, config = self.graph_executor_create(pipeline_mods, pipeline_config)
136-
137-
pipelinecreate = tvm._ffi.get_global_func("tvm.pipeline_executor.create")
138+
assert pipeline_executor_enabled(), \
139+
"Pipeline executor is not enabled. Please \
140+
re-build TVM with USE_PIPELINE_EXECUTOR=ON"
141+
pipelinecreate = tvm._ffi.get_global_func("tvm.pipeline_executor.create",
142+
allow_missing=False)
138143
assert pipelinecreate
139144
module = pipelinecreate(mods, config)
140145

141146
self.module_ = module
142147

143148
def graph_executor_create(self, pipeline_mods, mod_config):
144-
"""Create a pipeline runtime executor.
149+
"""Create graph_executor list and return string format config.
145150
146151
Parameters
147152
----------
@@ -167,19 +172,19 @@ def graph_executor_create(self, pipeline_mods, mod_config):
167172
return mods, json.dumps(mod_config)
168173

169174

170-
class PipelineModuleConfig:
175+
class PipelineConfig(object):
171176
"""Pipeline Configuration Class, in this class there are 2 internal class,
172-
first is Instance which use to represent Module, second is Interface which use
177+
first is Module which use to represent Module, second is Interface which use
173178
to represent Module input/output and Pipeline Module input/output, by setting
174179
dependency relation between Interfaces this class can build the module
175180
connection relation.
176181
177182
The class Hierarchical as following.
178-
PipelineModuleConfig ---> Pipe Instance ---> Interface(input/output)
179-
---> Module Instance ---> Interface(input/output)
183+
PipelineConfig ---> Pipeline Module ---> Interface(input/output)
184+
---> Subgraph Module ---> Interface(input/output)
180185
"""
181186

182-
class Instance:
187+
class Module:
183188
"""The class use use to represent Module and storage module index and
184189
Interface information.
185190
"""
@@ -190,7 +195,7 @@ class Interface:
190195
Parameters
191196
----------
192197
193-
owner : Instance
198+
owner : Module
194199
The class that own this interface, in such class there are
195200
Module information like index, module name
196201
@@ -251,6 +256,13 @@ def __init__(self, indx=0):
251256
self.indx_ = indx
252257
self.name_ = "mod" + str(indx) if indx else ""
253258
self.interfaces_ = {1: {}, 2: {}}
259+
self.target_host_ = None
260+
self.mod_name_ = "default"
261+
self.build_func_ = None
262+
self.params_ = None
263+
self.target_ = None
264+
self.dev_ = None
265+
254266

255267
def get_interface(self, itype, name):
256268
if name not in self.interfaces_[itype]:
@@ -264,23 +276,41 @@ def input(self, name):
264276
def output(self, index):
265277
return self.get_interface(2, index)
266278

279+
def set_target_host(self, host):
280+
self.target_host_ = host
281+
282+
def set_mod_name(self, name):
283+
self.mod_name_ = name
284+
285+
def set_build_func(self, build_func):
286+
self.build_func_ = build_func
287+
288+
def set_params(self, params):
289+
self.params_ = params
290+
291+
def set_target(self, target):
292+
self.target_ = target
293+
294+
def set_dev(self, dev):
295+
self.dev_ = dev
296+
267297
def __init__(self, mods):
268-
self.pipe_instance = self.Instance(0)
269-
self.mod_instance = {m: self.Instance(i + 1) for m, i in zip(mods, range(len(mods)))}
298+
self.pipe_module = self.Module(0)
299+
self.mod_module = {m: self.Module(i + 1) for m, i in zip(mods, range(len(mods)))}
270300

271301
def __str__(self):
272302
""" Get configuration in string type"""
273303
# get input
274304
input_dump = "Inputs\n"
275-
for input_name in self.pipe_instance.interfaces_[1]:
276-
inf = self.pipe_instance.interfaces_[1][input_name]
305+
for input_name in self.pipe_module.interfaces_[1]:
306+
inf = self.pipe_module.interfaces_[1][input_name]
277307
input_dump += " |" + input_name + ": " + inf.get_dependent_str() + "\n"
278308

279309
# get connections
280310
output = {}
281311
connections_dump = "\nconnections\n"
282-
for mod in self.mod_instance:
283-
for _, interface in self.mod_instance[mod].interfaces_[2].items():
312+
for mod in self.mod_module:
313+
for _, interface in self.mod_module[mod].interfaces_[2].items():
284314
if interface.dependent_:
285315
mname, dname = interface.get_name()
286316
iname = mname + ".output(" + dname + ")->"
@@ -302,11 +332,12 @@ def __str__(self):
302332
def get_config(self):
303333
""" Get configuration in dictionary format."""
304334
mconfig = {}
305-
for mod in self.mod_instance:
335+
for mod in self.mod_module:
336+
# get pipeline configure
306337
mconf = {}
307338
output_conf = []
308-
instance = self.mod_instance[mod]
309-
for _, interface in instance.interfaces_[2].items():
339+
module = self.mod_module[mod]
340+
for _, interface in module.interfaces_[2].items():
310341
dep_conf = []
311342
output = {}
312343
if interface.dependent_:
@@ -317,29 +348,43 @@ def get_config(self):
317348
dep_item["input_name"] = dname
318349
dep_conf.append(dep_item)
319350

320-
# in configuration the ouput_indx start from 0.
351+
# ouput_indx start from 0.
321352

322353
output["output_indx"] = int(interface.name_)
323354
output["dependent"] = dep_conf
324355
output_conf.append(output)
325-
mconf["mod_indx"] = instance.indx_
356+
mconf["mod_indx"] = module.indx_
326357
mconf["output"] = output_conf
327-
mconfig[mod] = {"pipeline": mconf}
358+
359+
# build module configuration with pipeline and other parameters.
360+
mconfig[mod] = {"pipeline": mconf,
361+
"target_host": module.target_host_,
362+
"mod_name": module.mod_name_,
363+
"build": module.build_func_,
364+
"params": module.params_,
365+
"target": module.target_,
366+
"dev": module.dev_,
367+
}
328368

329369
return mconfig
330370

331371
def __getitem__(self, key):
332-
return self.mod_instance[key]
372+
return self.mod_module[key]
333373

334374
def get_mod_indx(self, mod):
335-
indx = self.mod_instance[mod].indx_
375+
indx = self.mod_module[mod].indx_
336376
return indx
337377

338378
def pipe_input(self, name):
339-
return self.pipe_instance.input(name)
379+
return self.pipe_module.input(name)
340380

341381
def pipe_output(self, index):
342-
return self.pipe_instance.output(index)
382+
return self.pipe_module.output(index)
343383

344-
def connect(self, left: Instance.Interface, right: Instance.Interface):
384+
def connect(self, left: Module.Interface, right: Module.Interface):
345385
left.add_dependent(right)
386+
387+
388+
def PipeModuleConfig(object):
389+
def __init__(self):
390+
return

src/runtime/pipeline/pipeline_executor.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,6 @@ void SubGraphRuntime::Init(const Array<tvm::runtime::Module>& modules,
3030
return;
3131
}
3232

33-
PackedFunc SubGraphRuntime::GetFunction(const std::string& name,
34-
const ObjectPtr<Object>& sptr_to_self) {
35-
return PackedFunc();
36-
}
37-
3833
Module PipelineRuntimeCreate(const Array<tvm::runtime::Module>& m,
3934
const std::string& pipeline_json) {
4035
auto exec = make_object<SubGraphRuntime>();

src/runtime/pipeline/pipeline_executor.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,6 @@ class TVM_DLL SubGraphRuntime : public ModuleNode {
5959
* which is not compatible with RPCModules.
6060
*/
6161
void Init(const Array<tvm::runtime::Module>& modules, const std::string& pipeline_json);
62-
/*!
63-
* \brief Get member function to front-end
64-
* \param name The name of the function.
65-
* \param sptr_to_self The pointer to the module node.
66-
* \return The corresponding member function.
67-
*/
68-
virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
6962
};
7063
} // namespace runtime
7164
} // namespace tvm

0 commit comments

Comments
 (0)