
Commit 1ffe511

merrymercy authored and comaniac committed
[TUTORIAL][ANSOR] Using the template-free auto-scheduler on CPU (apache#6488)
* add tutorial
* add tutorial
* update
* Apply suggestions from code review
  Co-authored-by: Cody Yu <comaniac0422@gmail.com>
* address comments
* fix bugs
* add the example for resuming the search
* fix lint

Co-authored-by: Cody Yu <comaniac0422@gmail.com>
1 parent 34465f3 commit 1ffe511

File tree: 10 files changed, +269 −25 lines changed

docs/api/python/auto_scheduler.rst

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
..  Licensed to the Apache Software Foundation (ASF) under one
    or more contributor license agreements.  See the NOTICE file
    distributed with this work for additional information
    regarding copyright ownership.  The ASF licenses this file
    to you under the Apache License, Version 2.0 (the
    "License"); you may not use this file except in compliance
    with the License.  You may obtain a copy of the License at

..    http://www.apache.org/licenses/LICENSE-2.0

..  Unless required by applicable law or agreed to in writing,
    software distributed under the License is distributed on an
    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    KIND, either express or implied.  See the License for the
    specific language governing permissions and limitations
    under the License.

tvm.auto_scheduler
------------------
.. automodule:: tvm.auto_scheduler

tvm.auto_scheduler.auto_schedule
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.auto_scheduler.auto_schedule

.. autoclass:: tvm.auto_scheduler.auto_schedule.SearchTask

.. autoclass:: tvm.auto_scheduler.auto_schedule.TuningOptions

.. autofunction:: tvm.auto_scheduler.auto_schedule.create_task

.. autofunction:: tvm.auto_scheduler.auto_schedule.auto_schedule

docs/api/python/autotvm.rst

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 tvm.autotvm
 -----------
 .. automodule:: tvm.autotvm
-.. automodule:: tvm.autotvm.apply_history_best
+.. autofunction:: tvm.autotvm.apply_history_best

 tvm.autotvm.measure
 ~~~~~~~~~~~~~~~~~~~

docs/api/python/index.rst

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ Python API
    relay/dataflow_pattern
    relay/testing
    autotvm
+   auto_scheduler
    rpc
    micro
    contrib

docs/conf.py

Lines changed: 1 addition & 0 deletions
@@ -193,6 +193,7 @@
     "../tutorials/language",
     "../tutorials/optimize",
     "../tutorials/autotvm",
+    "../tutorials/auto_scheduler",
     "../tutorials/dev",
     "../tutorials/topi",
     "../tutorials/deployment",

python/tvm/auto_scheduler/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 from . import feature

 # Shortcut
-from .auto_schedule import SearchTask, TuningOptions, HardwareParams, auto_schedule
+from .auto_schedule import SearchTask, TuningOptions, HardwareParams, create_task, auto_schedule
 from .compute_dag import ComputeDAG
 from .cost_model import RandomModel, XGBModel
 from .measure import (

python/tvm/auto_scheduler/auto_schedule.py

Lines changed: 51 additions & 19 deletions
@@ -31,7 +31,10 @@
 import tvm._ffi
 from tvm.runtime import Object
 from .measure import LocalBuilder, LocalRunner
-from .search_policy import EmptyPolicy
+from .workload_registry import make_workload_key
+from .compute_dag import ComputeDAG
+from .cost_model import XGBModel
+from .search_policy import SketchPolicy
 from . import _ffi_api


@@ -89,26 +92,26 @@ class TuningOptions(Object):
     Parameters
     ----------
     num_measure_trials: int = 0
-      The number of measurement trials.
-      The search policy measures `num_measure_trials` schedules in total and returns the best one
-      among them.
-      With `num_measure_trials` == 0, the policy will do the schedule search but won't involve
-      measurement. This can be used to get a runnable schedule quickly without auto-tuning.
+        The number of measurement trials.
+        The search policy measures `num_measure_trials` schedules in total and returns the best one
+        among them.
+        With `num_measure_trials` == 0, the policy will do the schedule search but won't involve
+        measurement. This can be used to get a runnable schedule quickly without auto-tuning.
     early_stopping: Optional[int]
-      Stop the tuning early if getting no improvement after n measurements.
+        Stop the tuning early if getting no improvement after n measurements.
     num_measures_per_round: int = 64
-      The number of schedules to be measured at each search round.
-      The whole schedule search process will try a total number of `num_measure_trials` in several
-      rounds.
+        The number of schedules to be measured at each search round.
+        The whole schedule search process will try a total number of `num_measure_trials` in several
+        rounds.
     verbose: int = 1
-      Verbosity level. 0 for silent, 1 to output information during schedule search.
+        Verbosity level. 0 for silent, 1 to output information during schedule search.
     builder: Union[ProgramBuilder, str] = 'local'
-      ProgramBuilder which builds the program.
+        ProgramBuilder which builds the program.
     runner: Union[ProgramRunner, str] = 'local'
-      ProgramRunner which runs the program and measures time costs.
+        ProgramRunner which runs the program and measures time costs.
     measure_callbacks: Optional[List[MeasureCallback]]
-      Callback functions called after each measurement.
-      Candidates:
+        Callback functions called after each measurement.
+        Candidates:
       - auto_scheduler.RecordToFile
     """

@@ -156,16 +159,41 @@ def __init__(
         )


+def create_task(func, args, target, target_host=None, hardware_params=None):
+    """Create a search task
+
+    Parameters
+    ----------
+    func : Union[Function, str]
+        The function that returns the compute declaration Tensors.
+        Can be a function or the function name.
+    args : Union[Tuple[Any, ...], List[Any]]
+        The args of the function.
+    target : tvm.target.Target
+        The target device of this search task.
+    target_host : Optional[tvm.target.Target]
+        The target host device of this search task.
+    hardware_params : Optional[HardwareParams]
+        Hardware parameters used in this search task.
+
+    Returns
+    -------
+    SearchTask: the created task
+    """
+    workload_key = make_workload_key(func, args)
+    dag = ComputeDAG(workload_key)
+    return SearchTask(dag, workload_key, target, target_host, hardware_params)
+
+
 def auto_schedule(task, search_policy=None, tuning_options=TuningOptions()):
-    """Do auto scheduling for a computation declaration.
+    """Run auto scheduling search for a task

     Parameters
     ----------
     task : SearchTask
         The SearchTask for the computation declaration.
     search_policy : Optional[SearchPolicy]
-        The search policy to be used for schedule search. Use EmptyPolicy as default, which always
-        returns an empty schedule.
+        The search policy to be used for schedule search.
     tuning_options : Optional[TuningOptions]
         Tuning and measurement options.

@@ -178,5 +206,9 @@ def auto_schedule(task, search_policy=None, tuning_options=TuningOptions()):
             "Invalid task: " + task + " . `auto_scheduler.auto_schedule` expects a SearchTask."
         )

-    sch, tensors = _ffi_api.AutoSchedule(search_policy or EmptyPolicy(task), tuning_options)
+    if search_policy is None:
+        cost_model = XGBModel()
+        search_policy = SketchPolicy(task, cost_model)
+
+    sch, tensors = _ffi_api.AutoSchedule(search_policy, tuning_options)
     return sch, tensors
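
Taken together, these changes mean a user no longer has to construct a SearchTask or choose a search policy by hand. Below is a minimal sketch of the intended flow, mirroring the new tutorial added later in this commit; the `matmul_add` workload, the shapes, and the `matmul.json` log file name are illustrative choices, not part of this diff:

import tvm
from tvm import te, auto_scheduler


# Register a workload that returns the input/output tensors of the computation.
@auto_scheduler.register_workload
def matmul_add(N, L, M, dtype):
    A = te.placeholder((N, L), name="A", dtype=dtype)
    B = te.placeholder((L, M), name="B", dtype=dtype)
    C = te.placeholder((N, M), name="C", dtype=dtype)
    k = te.reduce_axis((0, L), name="k")
    matmul = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="matmul")
    out = te.compute((N, M), lambda i, j: matmul[i, j] + C[i, j], name="out")
    return [A, B, C, out]


# create_task builds the workload key and ComputeDAG and wraps them in a SearchTask.
task = auto_scheduler.create_task(matmul_add, (128, 128, 128, "float32"), tvm.target.Target("llvm"))

tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,
    measure_callbacks=[auto_scheduler.RecordToFile("matmul.json")],
)

# With search_policy omitted, auto_schedule now falls back to SketchPolicy + XGBModel
# instead of the old EmptyPolicy default.
sch, args = auto_scheduler.auto_schedule(task, tuning_options=tune_option)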

src/auto_scheduler/search_policy/sketch_policy_rules.cc

Lines changed: 2 additions & 2 deletions
@@ -593,7 +593,7 @@ PopulationGenerationRule::ResultKind MutateComputeLocationCommon(SketchPolicyNod

 PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode* policy,
                                                                       State* state) const {
-  return MutateComputeLocationCommon(policy, state, false);
+  return MutateComputeLocationCommon(policy, state, true);
 }

 PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy,

@@ -1059,7 +1059,7 @@ PopulationGenerationRule::ResultKind MutateMaxUnrollFactor::Apply(SketchPolicyNo

 PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy,
                                                                   State* state) const {
-  return MutateComputeLocationCommon(policy, state, true);
+  return MutateComputeLocationCommon(policy, state, false);
 }

 PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy,
tutorials/auto_scheduler/README.txt

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
AutoScheduler : Template-free Auto Scheduling
---------------------------------------------
Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Auto-scheduling matrix multiplication for CPU
=============================================
**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_, \
            `Chengfan Jia <https://github.com/jcf94/>`_

Unlike the existing :ref:`autotvm <tutorials-autotvm-sec>`, which relies on
manual templates to define the search space, the auto-scheduler is template-free:
users only need to write the computation declaration, without any schedule
commands or templates. The auto-scheduler can automatically generate a large
search space and find a good schedule in it.

We use matrix multiplication as an example in this tutorial.
"""

import numpy as np
import tvm
from tvm import te, testing, auto_scheduler

######################################################################
# Define the computation
# ^^^^^^^^^^^^^^^^^^^^^^
# To begin with, we define the computation of a matmul with bias add.
# The function should return the list of input/output tensors.
# From these tensors, the auto-scheduler can get the whole computational graph.


@auto_scheduler.register_workload
def matmul_add(N, L, M, dtype):
    A = te.placeholder((N, L), name="A", dtype=dtype)
    B = te.placeholder((L, M), name="B", dtype=dtype)
    C = te.placeholder((N, M), name="C", dtype=dtype)

    k = te.reduce_axis((0, L), name="k")
    matmul = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="matmul")
    out = te.compute((N, M), lambda i, j: matmul[i, j] + C[i, j], name="out")

    return [A, B, C, out]


######################################################################
# Create the search task
# ^^^^^^^^^^^^^^^^^^^^^^
# We then create a search task with N=L=M=128 and dtype="float32".

target = tvm.target.Target("llvm")
task = auto_scheduler.create_task(matmul_add, (128, 128, 128, "float32"), target)

# Inspect the computational graph
print(task.compute_dag)

######################################################################
# Next, we set parameters for the auto-scheduler.
#
# * `num_measure_trials` is the number of measurement trials we can use during the search.
#   We only make 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a
#   good value for the search to converge. You can do more trials according to your time budget.
# * In addition, we use `RecordToFile` to dump measurement records into a file `matmul.json`.
#   The measurement records can be used to query the history best, resume the search,
#   and do more analyses later.
# * See :any:`auto_schedule.TuningOptions` for more parameters.

tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10, measure_callbacks=[auto_scheduler.RecordToFile("matmul.json")]
)

######################################################################
# Run the search
# ^^^^^^^^^^^^^^
# Now we have all the inputs ready. Pretty simple, isn't it?
# We can kick off the search and let the auto-scheduler do its magic.
# After some measurement trials, it will return the best schedule it found.

sch, args = auto_scheduler.auto_schedule(task, tuning_options=tune_option)

######################################################################
# We can lower the schedule to see the IR after auto-scheduling.
# The auto-scheduler correctly performs optimizations including multi-level tiling,
# parallelization, vectorization, unrolling, and fusion.

print(tvm.lower(sch, args, simple_mode=True))

######################################################################
# Check correctness
# ^^^^^^^^^^^^^^^^^
# We build the binary and check its correctness.

func = tvm.build(sch, args)
a_np = np.random.uniform(size=(128, 128)).astype(np.float32)
b_np = np.random.uniform(size=(128, 128)).astype(np.float32)
c_np = np.random.uniform(size=(128, 128)).astype(np.float32)
d_np = a_np.dot(b_np) + c_np

d_tvm = tvm.nd.empty(d_np.shape)
func(tvm.nd.array(a_np), tvm.nd.array(b_np), tvm.nd.array(c_np), d_tvm)

tvm.testing.assert_allclose(d_np, d_tvm.asnumpy(), rtol=1e-3)

######################################################################
# Using the record file
# ^^^^^^^^^^^^^^^^^^^^^
# During the search, all measurement records are dumped into the record
# file "matmul.json". The measurement records can be used to re-apply search results,
# resume the search, and perform other analyses.

######################################################################
# Here is an example where we load the best schedule from a file,
# print the equivalent python schedule API, and build the binary again.

# Load the measurement record for the best schedule
inp, res = auto_scheduler.load_best("matmul.json", task.workload_key)

# Print equivalent python schedule API. This can be used for debugging and
# learning the behavior of the auto-scheduler.
print(task.compute_dag.print_python_code_from_state(inp.state))

# Rebuild the binary. This shows how you can apply the best schedule from a
# log file without rerunning the search.
sch, args = task.compute_dag.apply_steps_from_state(inp.state)
func = tvm.build(sch, args)

######################################################################
# A more complicated example is to resume the search.
# In this case, we need to create the search policy and cost model by ourselves,
# and resume the status of the search policy and cost model from the log file.
# In the example below, we resume the status and do 5 more trials.


def resume_search(task, log_file):
    cost_model = auto_scheduler.XGBModel()
    cost_model.update_from_file(log_file)
    search_policy = auto_scheduler.SketchPolicy(
        task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
    )
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)]
    )
    sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options=tune_option)


# resume_search(task, "matmul.json")

######################################################################
# .. note::
#   We cannot run the `resume_search` call above because of a conflict between
#   python's multiprocessing and tvm's thread pool.
#   After running a tvm-generated binary (such as the call to `func` in the
#   correctness check), python's multiprocessing library will hang forever.
#   You have to make sure that you don't run any tvm-generated binaries before
#   calling the auto-scheduler's search. To run the `resume_search` call above,
#   comment out the correctness-check block.
#
#   You should be careful about this problem in your applications.
#   There are other workarounds for it as well.
#   For example, you can start a new thread/process (with the built-in python library
#   threading or multiprocessing) and run the tvm binaries in the new thread/process.
#   This provides isolation and avoids the conflict in the main thread/process.
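
As a rough sketch of that last workaround, assuming the `task`, imports, and `matmul.json` log from the tutorial above, the tvm-generated binary can be run in a child process so the main process never touches tvm's thread pool before calling the search:

import multiprocessing


def check_in_child_process():
    # Rebuild and run the best schedule entirely inside the child process, so the
    # parent process remains free to call the auto-scheduler's search afterwards.
    inp, _ = auto_scheduler.load_best("matmul.json", task.workload_key)
    sch, args = task.compute_dag.apply_steps_from_state(inp.state)
    func = tvm.build(sch, args)
    a_np = np.random.uniform(size=(128, 128)).astype(np.float32)
    b_np = np.random.uniform(size=(128, 128)).astype(np.float32)
    c_np = np.random.uniform(size=(128, 128)).astype(np.float32)
    d_tvm = tvm.nd.empty((128, 128))
    func(tvm.nd.array(a_np), tvm.nd.array(b_np), tvm.nd.array(c_np), d_tvm)
    tvm.testing.assert_allclose(a_np.dot(b_np) + c_np, d_tvm.asnumpy(), rtol=1e-3)


p = multiprocessing.Process(target=check_in_child_process)
p.start()
p.join()

This is an outline, not a drop-in replacement: it relies on the child inheriting `task` and the imports from the parent (the fork start method); on platforms that spawn child processes, the target function must live in an importable module.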

tutorials/autotvm/README.txt

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
 .. _tutorials-autotvm-sec:

-Auto tuning
------------
+AutoTVM : Template-based Auto Tuning
+------------------------------------
