Skip to content

Commit 4ea6712

Browse files
jcf94 authored and merrymercy committed
Update PreLoadMeasuredStates & Some bug fix (apache#27)
* Add a threading wrapper to fix the test bug
* Set default TVM_USE_AUTO_SCHEDULER to false
* Update PreLoadMeasuredStates callback
1 parent 18d44b8 commit 4ea6712

File tree

14 files changed

+178
-30
lines changed

14 files changed

+178
-30
lines changed

python/tvm/ansor/auto_schedule.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def set_verbose(self, verbose):
8181
def run_callbacks(self, callbacks):
8282
_ffi_api.SearchPolicyRunCallbacks(self, callbacks)
8383

84+
8485
@tvm._ffi.register_object("ansor.MetaTileRewritePolicy")
8586
class MetaTileRewritePolicy(SearchPolicy):
8687
""" The search policy that searches with meta tiling and random rewrite
@@ -231,7 +232,7 @@ def auto_schedule(workload, target=None,
231232
232233
Parameters
233234
----------
234-
workload : Str or SearchTask
235+
workload : Union[SearchTask, str]
235236
236237
target : Target
237238

python/tvm/ansor/relay_integration.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
99.9% copy-paste of implementation by @MerryMercy
2121
2222
"""
23+
import os
24+
os.environ['TVM_USE_AUTO_SCHEDULER'] = 'true'
25+
2326
import threading
2427
import warnings
2528
import tvm
@@ -95,7 +98,7 @@ def init_op_to_schedule_map():
9598
relay.op.nn.batch_matmul: [topi.generic.schedule_batch_matmul],
9699
}
97100

98-
def extract_from_program(mod, params, ops, target, target_host=None):
101+
def extract_from_program(mod, params, target, target_host=None, ops=None):
99102
""" Extract tuning tasks from a relay program.
100103
101104
This function is the single program version of extract_from_multiple_program.
@@ -117,9 +120,9 @@ def extract_from_program(mod, params, ops, target, target_host=None):
117120
-------
118121
workloads: Array of Tuple(wkl_key, target)
119122
"""
120-
return extract_from_multiple_program([mod], [params], ops, target, target_host)
123+
return extract_from_multiple_program([mod], [params], target, target_host, ops)
121124

122-
def extract_from_multiple_program(mods, params, ops, target, target_host=None):
125+
def extract_from_multiple_program(mods, params, target, target_host=None, ops=None):
123126
""" Extract tuning tasks from multiple relay programs.
124127
125128
This function collects tuning tasks by building a list of programs
@@ -148,6 +151,15 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None):
148151

149152
init_op_to_schedule_map()
150153
topi_scheds = []
154+
155+
if not ops:
156+
ops = [relay.op.nn.dense, relay.op.nn.softmax, relay.op.nn.conv2d,
157+
relay.op.nn.conv2d_transpose, relay.op.nn.max_pool2d,
158+
relay.op.nn.avg_pool2d, relay.op.nn.global_max_pool2d,
159+
relay.op.nn.global_avg_pool2d, relay.op.nn.conv3d,
160+
relay.op.nn.adaptive_avg_pool3d, relay.op.nn.batch_matmul,
161+
relay.op.mean]
162+
151163
for op_name in ops:
152164
if op_name in OP_TO_SCHEDULE:
153165
topi_scheds.extend(OP_TO_SCHEDULE[op_name])

python/tvm/ansor/task_scheduler.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,17 @@ def __init__(self,
145145
self.sequential_now_task_begin_ct = 0
146146

147147
def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPolicy]] = 'default'):
148+
""" Tune tasks.
149+
150+
Notice: This method does not have a return value; make sure to set the `LogToFile`
151+
measure callback in `tune_option`.
152+
153+
Parameters
154+
----------
155+
tune_option: TuneOption
156+
157+
search_policy: Union[str, List[SearchPolicy]]
158+
"""
148159
# init members
149160
self.task_cts = [0 for _ in range(len(self.tasks))]
150161
self.task_costs_history = [[] for _ in range(len(self.tasks))]

scripts/tune_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88

99
import tvm
10-
from tvm import _ffi, relay, ansor
10+
from tvm import _ffi, ansor, relay
1111
import tvm.contrib.graph_runtime as runtime
1212
from tvm.contrib.debugger import debug_runtime
1313
from tvm.contrib import util, ndk

src/ansor/search_policy/meta_tile_rewrite_policy.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,6 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode {
103103
SplitFactorizationMemo split_memo_; // Memorize split space for Split
104104
std::mt19937 rand_gen_; // Random generator
105105
int num_measure_per_iter_; // The number of states to measure per iteration
106-
107-
// The array of already measured states.
108-
std::vector<State> measured_states_vector_;
109-
110-
// The throughputs of already measured states
111-
std::vector<float> measured_states_throughputs_;
112106
};
113107
TVM_DEFINE_MUTABLE_OBJECT_REF(MetaTileRewritePolicy, MetaTileRewritePolicyNode);
114108

src/ansor/search_policy/search_policy.cc

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,28 +37,44 @@ TVM_REGISTER_OBJECT_TYPE(PreLoadMeasuredStatesNode);
3737
void SearchPolicyNode::PreLoadMeasuredStates(const std::string& log_file) {
3838
LogReader reader = LogReaderNode::make(log_file);
3939
const auto& res = reader->ReadLines(-1);
40-
if (res.first.size()) {
40+
size_t log_size = res.first.size();
41+
CHECK_EQ(log_size, res.second.size());
42+
if (log_size) {
4143
std::vector<State> measured_states;
42-
for (const auto& inp : res.first) {
44+
std::vector<float> measured_throughputs;
45+
for (size_t i = 0; i < log_size; i++) {
46+
const auto& inp = res.first[i];
4347
if (inp->task->workload_key == cur_task_->workload_key &&
4448
inp->task->target->target_name.compare(
4549
cur_task_->target->target_name) == 0) {
4650
State state = cur_task_->compute_dag.GetInitState();
4751
state.CopyOnWrite()->transform_steps = inp->state->transform_steps;
4852
state.DoSteps(inp->state->transform_steps, cur_task_->compute_dag);
49-
measured_states.push_back(std::move(state));
53+
measured_states.emplace_back(std::move(state));
54+
measured_throughputs.push_back(res.second[i]->error_no == 0 ?
55+
(1.0 / FloatArrayMean(res.second[i]->costs)) : 0.0);
5056
}
5157
}
5258
cur_task_->compute_dag.InferBound(&measured_states);
53-
for (auto state : measured_states) {
54-
measured_states_set_.insert(state.ToStr());
59+
for (size_t i = 0; i < measured_states.size(); i ++) {
60+
auto& state = measured_states[i];
61+
const auto& state_str = state.ToStr();
62+
if (!measured_states_set_.count(state_str)) {
63+
measured_states_set_.insert(state_str);
64+
if (measured_throughputs[i] != 0.0) {
65+
measured_states_vector_.emplace_back(std::move(state));
66+
measured_states_throughputs_.emplace_back(measured_throughputs[i]);
67+
}
68+
}
5569
}
5670

5771
StdCout(verbose_) << "Measured States Set: " << measured_states_set_.size()
58-
<< " state hashes loaded from " << log_file << std::endl;
72+
<< " state hashes loaded from " << log_file
73+
<< " for " << cur_task_->workload_key << std::endl;
5974
} else {
6075
StdCout(verbose_) << "Measured States Set: no states found from "
61-
<< log_file << std::endl;
76+
<< log_file << " for " << cur_task_->workload_key
77+
<< std::endl;
6278
}
6379
}
6480

src/ansor/search_policy/search_policy.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ class SearchPolicyNode : public Object {
101101
// The set of the already measured states.
102102
// We store the string format for redundancy check
103103
std::unordered_set<std::string> measured_states_set_;
104+
// The array of already measured states.
105+
std::vector<State> measured_states_vector_;
106+
// The throughputs of already measured states
107+
std::vector<float> measured_states_throughputs_;
104108
};
105109
TVM_DEFINE_MUTABLE_OBJECT_REF(SearchPolicy, SearchPolicyNode);
106110

src/ansor/search_policy/utils.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,9 +311,10 @@ State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split
311311
CHECK(ps != nullptr);
312312
extent = GetIntImm(ps->extent);
313313
retry_ct += 1;
314-
} while (retry_ct < static_cast<int>(split_step_ids.size()) << 2 && extent == 1);
314+
} while (retry_ct < static_cast<int>(split_step_ids.size()) << 2 &&
315+
(extent == 1 || extent == 0));
315316

316-
if (extent == 1) {
317+
if (extent == 0 || extent == 1) {
317318
return State();
318319
}
319320

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
""" Test Relay Integration """
18+
19+
import tempfile
20+
import numpy as np
21+
22+
import tvm
23+
from tvm import ansor, relay
24+
import tvm.contrib.graph_runtime as runtime
25+
26+
from test_ansor_common import get_tiled_matmul
27+
28+
def dense_graph(N, dtype="float32"):
29+
ori_data = relay.var("data", shape=(N, N), dtype=dtype)
30+
weight = relay.var("weight", shape=(N, N), dtype=dtype)
31+
data = relay.multiply(ori_data, relay.const(2, dtype=dtype))
32+
dense = relay.nn.dense(data, weight, out_dtype=dtype)
33+
dense = relay.add(dense, weight)
34+
dense = relay.nn.dense(dense, weight, out_dtype=dtype)
35+
return ori_data, weight, dense
36+
37+
def test_dense_integration():
38+
N = 128
39+
data, weight, dense = dense_graph(N)
40+
mod = relay.Function([data, weight], dense)
41+
mod = tvm.IRModule.from_expr(mod)
42+
43+
ctx = tvm.context("llvm")
44+
target = tvm.target.create("llvm")
45+
d = tvm.nd.array(np.random.uniform(size=(N, N)).astype(data.type_annotation.dtype), ctx)
46+
w = tvm.nd.array(np.random.uniform(size=(N, N)).astype(weight.type_annotation.dtype), ctx)
47+
workloads, wkl_weights = ansor.extract_from_program(mod, {}, target=target)
48+
49+
assert len(workloads) == 2
50+
assert len(wkl_weights) == 2
51+
52+
tasks = []
53+
for wkl_key in workloads:
54+
dag = ansor.workload_key_to_dag(wkl_key)
55+
tasks.append(ansor.SearchTask(dag, wkl_key, target))
56+
57+
assert str(tasks[0].compute_dag) == "placeholder = PLACEHOLDER [128, 128]\n" + \
58+
"placeholder = PLACEHOLDER [128, 128]\n" + \
59+
"compute(z, y, x) += (placeholder[z, ((k*16) + x)]*placeholder[y, ((k*16) + x)])\n" + \
60+
"compute(y, x) += compute[y, x, kk]\n"
61+
62+
assert str(tasks[1].compute_dag) == "placeholder = PLACEHOLDER [128, 128]\n" + \
63+
"placeholder = PLACEHOLDER [128, 128]\n" + \
64+
"compute(z, y, x) += (placeholder[z, ((k*16) + x)]*placeholder[y, ((k*16) + x)])\n" + \
65+
"compute(y, x) += compute[y, x, kk]\n" + \
66+
"T_add(ax0, ax1) = (compute[ax0, ax1] + placeholder[ax0, ax1])\n"
67+
68+
tuner = ansor.SimpleTaskScheduler(tasks)
69+
measure_ctx = ansor.LocalRPCMeasureContext()
70+
with tempfile.NamedTemporaryFile() as fp:
71+
tuner.tune(ansor.TuneOption(n_trials=4, runner=measure_ctx.runner,
72+
measure_callbacks=[ansor.LogToFile(fp.name)]))
73+
with ansor.apply_history_best(fp.name):
74+
with relay.build_config(opt_level=3):
75+
graph, lib, opt_params = relay.build_module.build(
76+
mod, target=target)
77+
78+
m = runtime.create(graph, lib, ctx)
79+
m.set_input('data', d)
80+
m.set_input('weight', w)
81+
m.run()
82+
res = m.get_output(0)
83+
if measure_ctx:
84+
del measure_ctx
85+
86+
d = d.asnumpy()
87+
d = d * 2
88+
w = w.asnumpy()
89+
d = np.dot(d, np.transpose(w))
90+
d = d + w
91+
d = np.dot(d, np.transpose(w))
92+
93+
tvm.testing.assert_allclose(res.asnumpy(), d, rtol=1e-5)
94+
95+
if __name__ == "__main__":
96+
test_dense_integration()

tests/python/unittest/test_ansor_search_policy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import random
2121
import numpy as np
2222
import tempfile
23+
import threading
2324

2425
import tvm
2526
from tvm import ansor
@@ -73,8 +74,11 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local'
7374

7475

7576
def test_search_basic():
76-
search_common(seed=944563397)
77-
77+
# The Ansor search process with the local runner modifies thread binding;
78+
# run it in a separate thread so it does not affect other tests
79+
t = threading.Thread(target=search_common, kwargs={'seed': 944563397})
80+
t.start()
81+
t.join()
7882

7983
def test_search_xgb_model_rpc_runner():
8084
measure_ctx = ansor.LocalRPCMeasureContext()

0 commit comments

Comments
 (0)