36
36
import shutil
37
37
import tempfile
38
38
import multiprocessing
39
+ import logging
39
40
40
41
import tvm ._ffi
41
42
from tvm .runtime import Object , module , ndarray
50
51
call_func_with_timeout ,
51
52
check_remote ,
52
53
get_const_tuple ,
54
+ get_func_name ,
53
55
make_traceback_info ,
54
56
request_remote ,
55
57
)
58
60
deserialize_workload_registry_entry ,
59
61
)
60
62
63
+ # pylint: disable=invalid-name
64
+ logger = logging .getLogger ("auto_scheduler" )
61
65
62
66
# The time cost for measurements with errors
63
67
# We use 1e10 instead of sys.float_info.max for better readability in log
@@ -223,6 +227,7 @@ def recover_measure_input(inp, rebuild_state=False):
223
227
target_host = task .target_host ,
224
228
hardware_params = task .hardware_params ,
225
229
layout_rewrite_option = task .layout_rewrite_option ,
230
+ task_inputs = list (task .task_input_names ),
226
231
)
227
232
228
233
if rebuild_state :
@@ -719,6 +724,97 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
719
724
return results
720
725
721
726
727
# Global registry of task-input check functions, keyed by their registered name.
TASK_INPUT_CHECK_FUNC_REGISTRY = {}


def register_task_input_check_func(func_name, f=None, override=False):
    """Register a function that checks the input buffer map.

    The input function should take a list of Tensor which indicate the Input/output Tensor of a
    TVM subgraph and return a Map from the input Tensor to its buffer name.

    Parameters
    ----------
    func_name : Union[Function, str]
        Either the check function itself (bare-decorator usage; it is then registered under
        the name returned by `get_func_name`, and any value passed as `f` is ignored), or the
        name under which the check function is registered.
    f : Optional[Function]
        The check function to be registered.
    override : boolean = False
        Whether to override existing entry.

    Returns
    -------
    Function
        The registered check function when one is available (bare-decorator usage, or `f` is
        given); otherwise a decorator that registers the function it is applied to under
        `func_name` and returns that function unchanged.

    Raises
    ------
    ValueError
        If the resolved function name is not a string.
    RuntimeError
        If `func_name` has been registered already and `override` is False. In the decorator
        form this is raised when the decorator is applied, not when it is created.

    Examples
    --------
    .. code-block:: python

      @auto_scheduler.register_task_input_check_func
      def check_task_input_by_placeholder_name(args : List[Tensor]):
          tensor_input_map = {}
          for arg in args:
              if isinstance(arg.op, tvm.te.PlaceholderOp):
                  if arg.op.name != "placeholder":
                      tensor_input_map[arg] = arg.op.name
          return tensor_input_map
    """
    # No `global` statement is needed: the registry dict is only mutated, never rebound.
    if callable(func_name):
        f = func_name
        func_name = get_func_name(f)
    if not isinstance(func_name, str):
        raise ValueError("expect string function name")

    def register(myf):
        """internal register function"""
        if func_name in TASK_INPUT_CHECK_FUNC_REGISTRY and not override:
            raise RuntimeError("%s has been registered already" % func_name)
        TASK_INPUT_CHECK_FUNC_REGISTRY[func_name] = myf
        return myf

    # Compare against None instead of relying on truthiness: a callable object may legitimately
    # evaluate as falsy (e.g. it defines __bool__ or __len__) and must still be registered.
    if f is not None:
        return register(f)
    return register
776
+
777
+
778
def _prepare_input_map(args):
    """This function deals with special task inputs. Map the input Tensor of a TVM subgraph
    to a specific buffer name in the global buffer map.

    Parameters
    ----------
    args : List[Tensor]
        Input/output Tensor of a TVM subgraph.

    Returns
    -------
    Dict[Tensor, str] :
        Map from the input Tensor to its buffer name.

    Notes
    -----
    The buffer name is specially designed, and these buffer should be provided in
    `SearchTask(..., task_inputs={...})`.

    The mappings returned by the functions in `TASK_INPUT_CHECK_FUNC_REGISTRY` are merged in
    after the placeholder-name mappings, so for the same Tensor a registered check function
    takes precedence over the placeholder name, and a check function visited later in the
    registry's iteration order takes precedence over an earlier one.
    """
    # A dict that maps the input tensor arg to a buffer name
    tensor_input_map = {}

    # Case 0: Check placeholder name
    # Placeholders literally named "placeholder" are skipped (presumably this is the
    # auto-generated default name and therefore does not identify a task input -- confirm).
    for arg in args:
        if isinstance(arg.op, tvm.te.PlaceholderOp) and arg.op.name != "placeholder":
            tensor_input_map[arg] = arg.op.name

    # Case 1: Check specific tensor inputs
    # The registry is only read here, so no `global` statement is required.
    for check_func in TASK_INPUT_CHECK_FUNC_REGISTRY.values():
        tensor_input_map.update(check_func(args))

    return tensor_input_map
816
+
817
+
722
818
def _timed_eval_func (
723
819
inp_serialized ,
724
820
build_res ,
@@ -729,7 +825,11 @@ def _timed_eval_func(
729
825
enable_cpu_cache_flush ,
730
826
verbose ,
731
827
):
828
+ # pylint: disable=import-outside-toplevel
829
+ from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency
830
+
732
831
inp = MeasureInput .deserialize (inp_serialized )
832
+ task_input_names = inp .task .task_input_names
733
833
tic = time .time ()
734
834
error_no = 0
735
835
error_msg = None
@@ -758,11 +858,31 @@ def _timed_eval_func(
758
858
759
859
if error_no == 0 :
760
860
try :
761
- args = [ndarray .empty (get_const_tuple (x .shape ), x .dtype , ctx ) for x in build_res .args ]
762
861
random_fill = tvm .get_global_func ("tvm.contrib.random.random_fill" , True )
763
862
assert random_fill , "Please make sure USE_RANDOM is ON in the config.cmake"
764
- for arg in args :
765
- random_fill (arg )
863
+
864
+ tensor_input_map = _prepare_input_map (build_res .args ) if task_input_names else {}
865
+ args = []
866
+ task_inputs_count = 0
867
+ for arg in build_res .args :
868
+ if arg in tensor_input_map :
869
+ tensor_name = tensor_input_map [arg ]
870
+ if tensor_name in task_input_names :
871
+ args .append (get_task_input_buffer (inp .task .workload_key , tensor_name ))
872
+ task_inputs_count += 1
873
+ else :
874
+ raise ValueError (
875
+ "%s not found in task_inputs, " % (tensor_name )
876
+ + "should provide with `SearchTask(..., task_inputs={...})`"
877
+ )
878
+ else :
879
+ empty_array = ndarray .empty (get_const_tuple (arg .shape ), arg .dtype , ctx )
880
+ random_fill (empty_array )
881
+ args .append (empty_array )
882
+ if task_inputs_count != len (task_input_names ):
883
+ logger .warning (
884
+ "task_inputs not fully matched, check if there's any unexpected error"
885
+ )
766
886
ctx .sync ()
767
887
costs = time_f (* args ).results
768
888
# pylint: disable=broad-except
@@ -911,7 +1031,11 @@ def _timed_rpc_run(
911
1031
enable_cpu_cache_flush ,
912
1032
verbose ,
913
1033
):
1034
+ # pylint: disable=import-outside-toplevel
1035
+ from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency
1036
+
914
1037
inp = MeasureInput .deserialize (inp_serialized )
1038
+ task_input_names = inp .task .task_input_names
915
1039
tic = time .time ()
916
1040
error_no = 0
917
1041
error_msg = None
@@ -943,18 +1067,36 @@ def _timed_rpc_run(
943
1067
944
1068
if error_no == 0 :
945
1069
try :
946
- args = [ndarray .empty (get_const_tuple (x .shape ), x .dtype , ctx ) for x in build_res .args ]
947
- try :
948
- random_fill = remote .get_function ("tvm.contrib.random.random_fill" )
949
- except AttributeError :
950
- raise AttributeError (
951
- "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
1070
+ random_fill = remote .get_function ("tvm.contrib.random.random_fill" )
1071
+ assert (
1072
+ random_fill
1073
+ ), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices"
1074
+
1075
+ tensor_input_map = _prepare_input_map (build_res .args ) if task_input_names else {}
1076
+ args = []
1077
+ task_inputs_count = 0
1078
+ for arg in build_res .args :
1079
+ if arg in tensor_input_map :
1080
+ tensor_name = tensor_input_map [arg ]
1081
+ if tensor_name in task_input_names :
1082
+ args .append (get_task_input_buffer (inp .task .workload_key , tensor_name ))
1083
+ task_inputs_count += 1
1084
+ else :
1085
+ raise ValueError (
1086
+ "%s not found in task_inputs, " % (tensor_name )
1087
+ + "should provide with `SearchTask(..., task_inputs={...})`"
1088
+ )
1089
+ else :
1090
+ empty_array = ndarray .empty (get_const_tuple (arg .shape ), arg .dtype , ctx )
1091
+ random_fill (empty_array )
1092
+ args .append (empty_array )
1093
+ if task_inputs_count != len (task_input_names ):
1094
+ logger .warning (
1095
+ "task_inputs not fully matched, check if there's any unexpected error"
952
1096
)
953
- for arg in args :
954
- random_fill (arg )
955
1097
ctx .sync ()
956
-
957
1098
costs = time_f (* args ).results
1099
+
958
1100
# clean up remote files
959
1101
remote .remove (build_res .filename )
960
1102
remote .remove (os .path .splitext (build_res .filename )[0 ] + ".so" )
0 commit comments