Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Runtime][Pipeline Executor] Add the map logic of global input and subgraph input. #9751

Merged
merged 4 commits into from
Dec 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 114 additions & 24 deletions python/tvm/contrib/pipeline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,26 @@ def build(pipe_configs):
Common interface for pipeline executor factory modules.
"""
libs = {}
mod_n_configs = pipe_configs.get_config()
config = pipe_configs.get_config()
if "module_connection" not in config:
raise RuntimeError('"module_connection" is missing')
if "input_connection" not in config:
raise RuntimeError('"input_connection" is missing')

mod_n_configs = config["module_connection"]
config_len = len(mod_n_configs)
string_config = [{} for _ in range(config_len)]
module_string_config = [{} for _ in range(config_len)]
# Use hardware configurations to build backend modules for each subgraph.
for ir_mod, mod_config in mod_n_configs.items():
mconf = mod_config["pipeline"].copy()
mod_idx = mconf["mod_idx"]
pipe_config = mod_config["pipeline"].copy()
mod_idx = pipe_config["mod_idx"]
dev = mod_config["dev"]
target = mod_config["target"]
build_func = relay.build
# Check whether there is a customized build function.
# Callers may need to use a customized building function to wrap the pre-building logic
# and the backend building logic. For example, in order to support a backend which only
# can do "int8" computation, the caller may need to merge the "quantization" logic
into the building logic to create a customized building function.
if "build" in mod_config and mod_config["build"]:
build_func = mod_config["build"]

Expand All @@ -70,11 +80,20 @@ def build(pipe_configs):
mod_name=mod_config["mod_name"],
)

mconf["dev"] = "{},{}".format(dev.device_type, dev.device_id)
# Create a pipeline configuration.
string_config[mod_idx] = mconf
pipe_config["dev"] = "{},{}".format(dev.device_type, dev.device_id)
# Use "mod_idx" as the key to create a "module_connection" map which is not only
# for the module index but also for the module connection used to build the pipeline.
module_string_config[mod_idx] = pipe_config
libs[mod_idx] = {"lib": lib, "dev": dev}

# Creating a text form configuration to record the "input_connection" and the
# "module_connection" information. The "input_connection" is used to record the
# map of global input and subgraph input, and the "module_connection" is used to
# record module dependency.
string_config = {}
string_config["input_connection"] = config["input_connection"]
string_config["module_connection"] = module_string_config

return PipelineExecutorFactoryModule(libs, string_config)


Expand All @@ -94,6 +113,17 @@ def __init__(self, module):
self.module = module
# Get the packed functions from the pipeline executor.
self._get_num_outputs = self.module["get_num_outputs"]
self._get_input_pipeline_map = self.module["get_input_pipeline_map"]

def get_input_pipeline_map(self, name):
"""Using the "name" to get the corresponding subgraph index and also get the "input name"
of the corresponding subgraph interface.
Returns
-------
input map: Array[str]
Returning the index and "input name" of the subgraph.
"""
return self._get_input_pipeline_map(name)

@property
def num_outputs(self):
Expand Down Expand Up @@ -199,12 +229,48 @@ def is_pipeline_executor_interface(self):
return not isinstance(self.io_owner, PipelineConfig.ModuleWrapper)

def __repr__(self):
# Get all binding information.
ret = " |{}: ".format(self.name)
# Getting the binding information in the form of text.
str_format = " |{}: ".format(self.name)
for binding in self.bindings:
mname, dname = binding.get_name()
ret += "{0}:{1} ".format(mname, dname)
return ret
str_format += "{0}:{1} ".format(mname, dname)

return str_format

def check_binding_dict(self, connection_dict):
"""Checking the binding dictionary.
Parameters
----------
connection_dict : Dict[str, Any]
It is a dictionary of module connections.
"""
if "interface_name" not in connection_dict:
raise RuntimeError('"inteface_name" is missing in global config!"')
if "connection" not in connection_dict:
raise RuntimeError(f'"connection" is missing!"')
# The global interface mapping should be one-to-one.
if not connection_dict["connection"]:
raise RuntimeError("The global interface map is empty!")
if len(connection_dict["connection"]) > 1:
raise RuntimeError("A global interface maps multiple module interfaces!")
if "mod_idx" not in connection_dict["connection"][0]:
raise RuntimeError('"mod_idx" is missing!')

def get_binding_dict(self):
"""Returning the binding information in the form of dictionary.
Returns
-------
data : Dict[str, Any]
The binding information is in the form of dictionary.
"""
dict_format = {"interface_name": self.name, "connection": []}
for binding in self.bindings:
_, dname = binding.get_name()
midx = binding.get_owner_idx()
dict_format["connection"].append({"mod_idx": midx, "interface_name": dname})

self.check_binding_dict(dict_format)
return dict_format

def check_dag_acyclic(self, start, inputs):
"""This is to check whether the DAG containing these input interfaces is acyclic.
Expand Down Expand Up @@ -243,30 +309,34 @@ def connect(self, binding):

# Check whether the binding setting is correct or not.
if self.io_owner == binding.io_owner:
raise RuntimeError(f"Can not bind itself.")
raise RuntimeError("Can not bind itself.")

if not self.is_pipeline_executor_interface() and self.io_type == "input":
raise RuntimeError(f"Module can only bind from output interface!")
raise RuntimeError("Module can only bind from output interface!")

if (
not self.is_pipeline_executor_interface()
and not binding.is_pipeline_executor_interface()
and binding.io_type == "output"
):
raise RuntimeError(f"Can not bind module output with another module output!")
raise RuntimeError("Can not bind module output with another module output!")

if (
not self.is_pipeline_executor_interface()
and binding.is_pipeline_executor_interface()
and binding.io_type == "input"
):
raise RuntimeError(f"Can not bind module output with pipeline input!")
raise RuntimeError("Can not bind module output with pipeline input!")

if self.is_pipeline_executor_interface() and self.io_type == "output":
raise RuntimeError(f"Global output can not be used as binding start point.")
raise RuntimeError("Global output can not be used as binding start point.")

if self.is_pipeline_executor_interface() and binding.io_type != "input":
raise RuntimeError(f"Global input can only bind with module input.")
if (
self.is_pipeline_executor_interface()
and self.io_type == "input"
and binding.io_type != "input"
):
raise RuntimeError("Global input can only bind with module input.")

self.bindings.append(binding)
if not self.is_pipeline_executor_interface():
Expand All @@ -288,7 +358,7 @@ def connect(self, binding):
if not self.check_dag_acyclic(
binding.io_owner, self.io_owner.input_bindings.bindings
):
raise RuntimeError(f"Illegal connection: Cause a cycle!")
raise RuntimeError("Illegal connection: Cause a cycle!")

class BindingList:
"""Container for bindings(input or output interface).
Expand Down Expand Up @@ -357,7 +427,9 @@ def __getitem__(self, key):
if key == "output":
return self.output_bindings

raise RuntimeError(f"{key} not found!")
raise RuntimeError(f"{key} not found!")

raise RuntimeError('The data type of "key" is not supported!')

def get_data_type(self, key, interface_type):
"""Get the module interface data type according to the key value and interface type.
Expand Down Expand Up @@ -468,6 +540,8 @@ def get_config(self):
# Use topological sort to get the correct order of modules.
self.dag_topology_sort()
mconfig = {}
module_connection = {}
input_connection = {}
for mod in self.mod_wrapper:
# Generate pipeline configuration.
mconf = {}
Expand Down Expand Up @@ -495,7 +569,7 @@ def get_config(self):
mconf["mod_idx"] = module.idx
mconf["output"] = output_conf

mconfig[mod] = {
module_connection[mod] = {
"pipeline": mconf,
"target_host": module.target_host,
"mod_name": "default",
Expand All @@ -505,6 +579,22 @@ def get_config(self):
"dev": module.dev,
}

# Create a map of pipeline input and subgraph input.
input_connection = []
for input_name in self.input_bindings.bindings:
input_dict = self.input_bindings.bindings[input_name].get_binding_dict()
if "interface_name" not in input_dict["connection"][0]:
raise RuntimeError("interface_name is missing in connection config!")
# Creating the map of global interface and subgraph interface.
input_map = {
"global_interface_name": input_dict["interface_name"],
"mod_idx": input_dict["connection"][0]["mod_idx"],
"module_interface_name": input_dict["connection"][0]["interface_name"],
}
input_connection.append(input_map)

mconfig["module_connection"] = module_connection
mconfig["input_connection"] = input_connection
return mconfig

def dag_topology_sort(self):
Expand Down Expand Up @@ -601,11 +691,11 @@ def export_library(self, directory_path):
Export the files to this directory.
"""
if not self.pipeline_mods:
raise RuntimeError(f"The pipeline executor has not been initialized.")
raise RuntimeError("The pipeline executor has not been initialized.")

# Check if the directory_path exists.
if not os.path.exists(directory_path):
raise RuntimeError(f"The directory {directory_path} does not exist.")
raise RuntimeError("The directory {directory_path} does not exist.")
# Create an load configuration.
load_config_file_name = "{}/load_config".format(directory_path)
pipeline_config_file_name = "{}/pipeline_config".format(directory_path)
Expand Down
25 changes: 22 additions & 3 deletions src/runtime/pipeline/pipeline_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,32 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
if (name == "get_num_outputs") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); });
} else if (name == "get_input_pipeline_map") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (String::CanConvertFrom(args[0])) {
*rv = this->GetInputPipeplineMapping(args[0].operator String());
} else {
LOG(FATAL) << "Function only support the input name value in the form of string";
}
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc();
}
return nullptr;
}

/*!
* \brief Use the global input name to get the index, and also get the input interface name
of the corresponding subgraph from the input connection configuration.
* \param input_name The global input name.
* \return The index and the input interface name of the corresponding subgraph.
*/
Array<String> PipelineExecutor::GetInputPipeplineMapping(std::string input_name) {
std::pair<int, std::string> map = input_connection_config[input_name];
return {std::to_string(map.first), map.second};
}

/*!
* \brief Use the mod_config information to create a graph runtime list.
* \param mod_config The config information that generates by the export library function call.
Expand Down Expand Up @@ -108,11 +127,11 @@ void PipelineExecutor::Init(const std::vector<Module>& modules, const std::strin
// Use JSONReader to load pipeline configuration.
std::istringstream is(pipeline_json);
dmlc::JSONReader reader(&is);
PipelineConfig& pipeline_config = this->LoadPipelineConfig(&reader);
ICHECK(!pipeline_config.Empty()) << "The pipeline config information is empty.";
this->LoadConfig(&reader);
ICHECK(!pipeline_config_.Empty()) << "The pipeline config information is empty.";
// Initialize the pipeline function class used for pipeline thread pool management
// and schedule etc. This function returns the number of output.
num_outputs_ = pipeline_scheduler_.PipelineInit(modules, pipeline_config);
num_outputs_ = pipeline_scheduler_.PipelineInit(modules, pipeline_config_);
return;
}

Expand Down
48 changes: 23 additions & 25 deletions src/runtime/pipeline/pipeline_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@
#ifndef TVM_RUNTIME_PIPELINE_PIPELINE_EXECUTOR_H_
#define TVM_RUNTIME_PIPELINE_PIPELINE_EXECUTOR_H_

#include <tvm/relay/expr.h>
#include <tvm/runtime/registry.h>

#include <array>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "pipeline_scheduler.h"
Expand Down Expand Up @@ -67,7 +69,13 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
* \return The corresponding packed function.
*/
virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);

/*!
* \brief Use the global input name to get the index, and also get the input interface name
of the corresponding subgraph from the input connection configuration.
* \param input_name The global input name.
* \return The index and the input interface name of the corresponding subgraph.
*/
Array<String> GetInputPipeplineMapping(std::string input_name);
/*!
* \brief Get the number of outputs.
*
Expand Down Expand Up @@ -115,37 +123,27 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
/*!\brief The class used to execute and schedule the pipeline logic.*/
PipelineScheduler pipeline_scheduler_;
/*!\brief The dependency information of each graph runtime module of the pipeline.*/
PipelineConfig pipeline_config_;
ConfigPipelineExecution pipeline_config_;
/*!\brief The map of global input and subgraph input.*/
InputConnectionConfig input_connection_config;
/*!\brief The module information used to create the graph runtimes.*/
ModuleConfig mod_config_;
/*!\brief How many outputs are in this pipeline executor.*/
size_t num_outputs_ = 0;
/*!\brief Json loader.*/
PipelineConfig& LoadPipelineConfig(dmlc::JSONReader* reader) {
reader->BeginArray();
while (reader->NextArrayItem()) {
std::string key;
reader->BeginObject();
int mod_idx = -1;
OutputMap output;
std::string dev;
while (reader->NextObjectItem(&key)) {
if (key == "mod_idx") {
reader->Read(&mod_idx);
} else if (key == "dev") {
reader->Read(&dev);
} else if (key == "output") {
reader->Read(&output);
} else {
LOG(FATAL) << "do not support key " << key;
}
void LoadConfig(dmlc::JSONReader* reader) {
reader->BeginObject();
std::string key;
while (reader->NextObjectItem(&key)) {
if (key == "module_connection") {
reader->Read(&pipeline_config_);
} else if (key == "input_connection") {
reader->Read(&input_connection_config);
} else {
LOG(FATAL) << "do not support key " << key;
}
ICHECK(mod_idx >= 0) << "Invalid mod_idx value " << mod_idx;
// Check if the output is successfully read.
ICHECK(!output.Empty()) << "Invalid output binding result.";
pipeline_config_.Insert(mod_idx, output);
}
return pipeline_config_;
return;
}
};
} // namespace runtime
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/pipeline/pipeline_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace runtime {
* \param pipeline_conf The dependency information of each graph executor module.
*/
size_t PipelineScheduler::PipelineInit(const std::vector<Module>& modules,
const PipelineConfig& pipeline_config) {
const ConfigPipelineExecution& pipeline_config) {
graph_modules_ = modules;
int num_output = pipeline_config.GetGlobalOutputNum();
return num_output;
Expand Down
3 changes: 2 additions & 1 deletion src/runtime/pipeline/pipeline_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ class PipelineScheduler {
* \param modules The list of graph executor module.
* \param pipeline_config The dependency information of each graph executor module.
*/
size_t PipelineInit(const std::vector<Module>& modules, const PipelineConfig& pipeline_config);
size_t PipelineInit(const std::vector<Module>& modules,
const ConfigPipelineExecution& pipeline_config);

private:
/*!\brief The list of graph executors.*/
Expand Down
Loading