
Commit 0b4f669

[AutoSchedule] Sparse dense tuning support with custom sketch rule (#7313)
* Add sparse dense tuning tutorial
* Add sparse input fusion
* Update the dag to support output fusion
* Update
* Add task input to search_task
* Update
* Add search_inputs to measure
* Lint fix
* Lint fix
* Update
* Update
* Update
* Update
* Add file save load support
* Update
* Update
* Update
* Remove add_task_inputs API
* Update
* Update
* Update
* Lint fix
* Lint fix
* Lint fix
* Lint fix
* Update
* Add example ci_log
* Update
* retrigger ci
* Update
* Update
* Update
* Lint fix
* Lint fix
* Lint fix
1 parent 783be9d commit 0b4f669

15 files changed: +1109 −26 lines changed


include/tvm/auto_scheduler/measure_record.h

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@
 namespace tvm {
 namespace auto_scheduler {
 
-const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.5";  // NOLINT(*)
+const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.6";  // NOLINT(*)
 
 /*! \brief Callback for logging the input and results of measurements to file */
 class RecordToFileNode : public MeasureCallbackNode {

include/tvm/auto_scheduler/search_task.h

Lines changed: 7 additions & 1 deletion

@@ -26,6 +26,7 @@
 #define TVM_AUTO_SCHEDULER_SEARCH_TASK_H_
 
 #include <tvm/auto_scheduler/compute_dag.h>
+#include <tvm/runtime/ndarray.h>
 #include <tvm/target/target.h>
 
 namespace tvm {
@@ -120,6 +121,8 @@ class SearchTaskNode : public Object {
   HardwareParams hardware_params;
   /*! \brief The layout rewrite option used for measuring programs. */
   LayoutRewriteOption layout_rewrite_option;
+  /*! \brief Names of some user defined input data used in program measuring. */
+  Array<String> task_input_names;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("compute_dag", &compute_dag);
@@ -128,6 +131,7 @@ class SearchTaskNode : public Object {
     v->Visit("target_host", &target_host);
     v->Visit("hardware_params", &hardware_params);
     v->Visit("layout_rewrite_option", &layout_rewrite_option);
+    v->Visit("task_input_names", &task_input_names);
   }
 
   static constexpr const char* _type_key = "auto_scheduler.SearchTask";
@@ -148,9 +152,11 @@ class SearchTask : public ObjectRef {
   * \param target_host The target host device of this search task.
   * \param hardware_params Hardware parameters used in this search task.
   * \param layout_rewrite_option The layout rewrite option used for measuring programs.
+  * \param task_input_names Names of some user defined input data used in program measuring.
   */
  SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host,
-            Optional<HardwareParams> hardware_params, LayoutRewriteOption layout_rewrite_option);
+            Optional<HardwareParams> hardware_params, LayoutRewriteOption layout_rewrite_option,
+            Array<String> task_input_names);
 
  TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode);
};
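Note: on the Python side, the new `task_input_names` field is filled from a `task_inputs` argument (see `recover_measure_input` in the measure.py diff below, which passes `task_inputs=list(task.task_input_names)`). A minimal, hedged sketch of creating a task with a user-provided buffer, assuming the Python `SearchTask` wrapper added in this commit accepts a `task_inputs` dict keyed by placeholder name; the workload, shapes, and names below are illustrative, not taken from this diff:

import numpy as np
import tvm
from tvm import te, auto_scheduler

@auto_scheduler.register_workload
def dense_with_named_weight(M, N, K):
    # The weight placeholder gets a non-default name so the measurer can match
    # it against task_inputs (see _prepare_input_map in the measure.py diff).
    A = te.placeholder((M, K))  # default name "placeholder" -> random-filled at measure time
    W = te.placeholder((K, N), name="W_data")
    k = te.reduce_axis((0, K), name="k")
    C = te.compute((M, N), lambda i, j: te.sum(A[i, k] * W[k, j], axis=k), name="C")
    return [A, W, C]

w_np = np.random.rand(512, 512).astype("float32")

task = auto_scheduler.SearchTask(
    func=dense_with_named_weight,
    args=(512, 512, 512),
    target="llvm",
    task_inputs={"W_data": tvm.nd.array(w_np)},  # keys end up in task_input_names
)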

python/tvm/auto_scheduler/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@
     LocalRunner,
     RPCRunner,
     LocalRPCMeasureContext,
+    register_task_input_check_func,
 )
 from .measure_record import RecordToFile, RecordReader, load_best_record, load_records, save_records
 from .relay_integration import (
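The newly exported `register_task_input_check_func` (defined in the measure.py diff below) lets users plug in extra rules for recognizing special inputs beyond the default placeholder-name check. A small, hypothetical registration sketch; the `sparse_` name-prefix convention here is illustrative only, not part of this commit:

import tvm
from tvm import auto_scheduler

def check_sparse_placeholders(args):
    """Treat placeholders whose name starts with 'sparse_' as task inputs."""
    tensor_input_map = {}
    for arg in args:
        if isinstance(arg.op, tvm.te.PlaceholderOp) and arg.op.name.startswith("sparse_"):
            tensor_input_map[arg] = arg.op.name
    return tensor_input_map

# Register under an explicit name; override=True replaces any existing entry.
auto_scheduler.register_task_input_check_func(
    "check_sparse_placeholders", check_sparse_placeholders, override=True
)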

python/tvm/auto_scheduler/measure.py

Lines changed: 154 additions & 12 deletions

@@ -36,6 +36,7 @@
 import shutil
 import tempfile
 import multiprocessing
+import logging
 
 import tvm._ffi
 from tvm.runtime import Object, module, ndarray
@@ -50,6 +51,7 @@
     call_func_with_timeout,
     check_remote,
     get_const_tuple,
+    get_func_name,
     make_traceback_info,
     request_remote,
 )
@@ -58,6 +60,8 @@
     deserialize_workload_registry_entry,
 )
 
+# pylint: disable=invalid-name
+logger = logging.getLogger("auto_scheduler")
 
 # The time cost for measurements with errors
 # We use 1e10 instead of sys.float_info.max for better readability in log
@@ -223,6 +227,7 @@ def recover_measure_input(inp, rebuild_state=False):
         target_host=task.target_host,
         hardware_params=task.hardware_params,
         layout_rewrite_option=task.layout_rewrite_option,
+        task_inputs=list(task.task_input_names),
     )
 
     if rebuild_state:
@@ -719,6 +724,97 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
     return results
 
 
+TASK_INPUT_CHECK_FUNC_REGISTRY = {}
+
+
+def register_task_input_check_func(func_name, f=None, override=False):
+    """Register a function that checks the input buffer map.
+
+    The input function should take a list of Tensor which indicates the Input/output Tensor of a TVM
+    subgraph and return a Map from the input Tensor to its buffer name.
+
+    Parameters
+    ----------
+    func_name : Union[Function, str]
+        The check function that returns the compute declaration Tensors or its function name.
+    f : Optional[Function]
+        The check function to be registered.
+    override : boolean = False
+        Whether to override existing entry.
+
+    Examples
+    --------
+    .. code-block:: python
+
+      @auto_scheduler.register_task_input_check_func
+      def check_task_input_by_placeholder_name(args : List[Tensor]):
+          tensor_input_map = {}
+          for arg in args:
+              if isinstance(arg.op, tvm.te.PlaceholderOp):
+                  if arg.op.name != "placeholder":
+                      tensor_input_map[arg] = arg.op.name
+          return tensor_input_map
+    """
+    global TASK_INPUT_CHECK_FUNC_REGISTRY
+
+    if callable(func_name):
+        f = func_name
+        func_name = get_func_name(f)
+    if not isinstance(func_name, str):
+        raise ValueError("expect string function name")
+
+    def register(myf):
+        """internal register function"""
+        if func_name in TASK_INPUT_CHECK_FUNC_REGISTRY and not override:
+            raise RuntimeError("%s has been registered already" % func_name)
+        TASK_INPUT_CHECK_FUNC_REGISTRY[func_name] = myf
+        return myf
+
+    if f:
+        return register(f)
+    return register
+
+
+def _prepare_input_map(args):
+    """This function deals with special task inputs. Map the input Tensor of a TVM subgraph
+    to a specific buffer name in the global buffer map.
+
+    Parameters
+    ----------
+    args : List[Tensor]
+        Input/output Tensor of a TVM subgraph.
+
+    Returns
+    -------
+    Dict[Tensor, str] :
+        Map from the input Tensor to its buffer name.
+
+    Notes
+    -----
+    The buffer name is specially designed, and these buffers should be provided in
+    `SearchTask(..., task_inputs={...})`.
+    """
+    # pylint: disable=import-outside-toplevel
+
+    global TASK_INPUT_CHECK_FUNC_REGISTRY
+
+    # A dict that maps the input tensor arg to a buffer name
+    tensor_input_map = {}
+
+    # Case 0: Check placeholder name
+    for arg in args:
+        if isinstance(arg.op, tvm.te.PlaceholderOp):
+            if arg.op.name != "placeholder":
+                tensor_input_map[arg] = arg.op.name
+
+    # Case 1: Check specific tensor inputs
+    for func_name in TASK_INPUT_CHECK_FUNC_REGISTRY:
+        func = TASK_INPUT_CHECK_FUNC_REGISTRY[func_name]
+        tensor_input_map.update(func(args))
+
+    return tensor_input_map
+
+
 def _timed_eval_func(
     inp_serialized,
     build_res,
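Before the diff moves on to `_timed_eval_func`, here is a short sketch of what `_prepare_input_map` returns for a toy argument list (assuming a local TVM build; shapes are arbitrary). Placeholders with non-default names are mapped to their names, while default-named placeholders and compute outputs are skipped:

import tvm
from tvm import te
from tvm.auto_scheduler.measure import _prepare_input_map  # internal helper added above

A = te.placeholder((128, 128), name="A")
B = te.placeholder((128, 128))  # default name "placeholder" -> not treated as a task input
k = te.reduce_axis((0, 128), name="k")
C = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

print(_prepare_input_map([A, B, C]))  # {A: "A"}, plus anything matched by registered check funcs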
@@ -729,7 +825,11 @@ def _timed_eval_func(
     enable_cpu_cache_flush,
     verbose,
 ):
+    # pylint: disable=import-outside-toplevel
+    from .search_task import get_task_input_buffer  # lazily import to avoid recursive dependency
+
     inp = MeasureInput.deserialize(inp_serialized)
+    task_input_names = inp.task.task_input_names
     tic = time.time()
     error_no = 0
     error_msg = None
@@ -758,11 +858,31 @@
 
     if error_no == 0:
         try:
-            args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args]
             random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)
             assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake"
-            for arg in args:
-                random_fill(arg)
+
+            tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {}
+            args = []
+            task_inputs_count = 0
+            for arg in build_res.args:
+                if arg in tensor_input_map:
+                    tensor_name = tensor_input_map[arg]
+                    if tensor_name in task_input_names:
+                        args.append(get_task_input_buffer(inp.task.workload_key, tensor_name))
+                        task_inputs_count += 1
+                    else:
+                        raise ValueError(
+                            "%s not found in task_inputs, " % (tensor_name)
+                            + "should provide with `SearchTask(..., task_inputs={...})`"
+                        )
+                else:
+                    empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx)
+                    random_fill(empty_array)
+                    args.append(empty_array)
+            if task_inputs_count != len(task_input_names):
+                logger.warning(
+                    "task_inputs not fully matched, check if there's any unexpected error"
+                )
             ctx.sync()
             costs = time_f(*args).results
         # pylint: disable=broad-except
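The same buffer-selection logic is mirrored in `_timed_rpc_run` below. For readability, here is a hedged, condensed restatement of it as a stand-alone helper; `prepare_runner_args` is not a function introduced by this commit, and the unmatched-input warning from the diff is omitted:

from tvm.runtime import ndarray
from tvm.auto_scheduler.measure import _prepare_input_map
from tvm.auto_scheduler.search_task import get_task_input_buffer
from tvm.auto_scheduler.utils import get_const_tuple

def prepare_runner_args(inp, build_res, ctx, random_fill):
    """Build the argument list for time evaluation.

    Named task inputs reuse the user-provided buffer via get_task_input_buffer;
    every other argument is allocated empty and filled with random data.
    """
    task_input_names = inp.task.task_input_names
    tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {}
    args = []
    for arg in build_res.args:
        if arg in tensor_input_map:
            tensor_name = tensor_input_map[arg]
            if tensor_name not in task_input_names:
                raise ValueError(
                    "%s not found in task_inputs, " % tensor_name
                    + "should provide with `SearchTask(..., task_inputs={...})`"
                )
            args.append(get_task_input_buffer(inp.task.workload_key, tensor_name))
        else:
            empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx)
            random_fill(empty_array)
            args.append(empty_array)
    return args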
@@ -911,7 +1031,11 @@ def _timed_rpc_run(
     enable_cpu_cache_flush,
     verbose,
 ):
+    # pylint: disable=import-outside-toplevel
+    from .search_task import get_task_input_buffer  # lazily import to avoid recursive dependency
+
     inp = MeasureInput.deserialize(inp_serialized)
+    task_input_names = inp.task.task_input_names
     tic = time.time()
     error_no = 0
     error_msg = None
@@ -943,18 +1067,36 @@
 
     if error_no == 0:
         try:
-            args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args]
-            try:
-                random_fill = remote.get_function("tvm.contrib.random.random_fill")
-            except AttributeError:
-                raise AttributeError(
-                    "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
+            random_fill = remote.get_function("tvm.contrib.random.random_fill")
+            assert (
+                random_fill
+            ), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices"
+
+            tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {}
+            args = []
+            task_inputs_count = 0
+            for arg in build_res.args:
+                if arg in tensor_input_map:
+                    tensor_name = tensor_input_map[arg]
+                    if tensor_name in task_input_names:
+                        args.append(get_task_input_buffer(inp.task.workload_key, tensor_name))
+                        task_inputs_count += 1
+                    else:
+                        raise ValueError(
+                            "%s not found in task_inputs, " % (tensor_name)
+                            + "should provide with `SearchTask(..., task_inputs={...})`"
+                        )
+                else:
+                    empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx)
+                    random_fill(empty_array)
+                    args.append(empty_array)
+            if task_inputs_count != len(task_input_names):
+                logger.warning(
+                    "task_inputs not fully matched, check if there's any unexpected error"
                 )
-            for arg in args:
-                random_fill(arg)
             ctx.sync()
-
             costs = time_f(*args).results
+
             # clean up remote files
             remote.remove(build_res.filename)
             remote.remove(os.path.splitext(build_res.filename)[0] + ".so")