
Commit 45a573e

[Graph Runtime] Run_individual for benchmarking individual layers
1 parent 8b1d07f

3 files changed: 78 additions, 0 deletions

python/tvm/contrib/debugger/debug_runtime.py

Lines changed: 4 additions & 0 deletions
@@ -89,6 +89,7 @@ def __init__(self, module, ctx, graph_json_str, dump_root):
         self._dump_path = None
         self._debug_run = module["debug_run"]
         self._get_output_by_layer = module["get_output_by_layer"]
+        self._run_individual = module["run_individual"]
         graph_runtime.GraphModule.__init__(self, module)
         self._create_debug_env(graph_json_str, ctx)
 
@@ -222,6 +223,9 @@ def run(self, **input_dict):
         # Step 3. Display the collected information
         self.debug_datum.display_debug_result()
 
+    def run_individual(self, number, repeat=1, min_repeat_ms=0):
+        self._run_individual(number, repeat, min_repeat_ms)
+
     def exit(self):
         """Exits the dump folder and all its contents"""
         self._remove_dump_root()
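For context, the new method is invoked from Python on a debug graph runtime module. A minimal usage sketch (not part of the commit) follows; it assumes a graph JSON string `graph`, a built library `lib`, a context `ctx`, and an input tensor `input_data` similar to what the existing test sets up, and the dump root path and input name are hypothetical:

    from tvm.contrib.debugger import debug_runtime

    # assumes `graph`, `lib`, `ctx` and `input_data` already exist, as in the
    # existing debug-runtime test; the dump_root path is hypothetical
    mod = debug_runtime.create(graph, lib, ctx, dump_root="/tmp/tvmdbg")
    mod.set_input("x", input_data)  # hypothetical input name
    # average each op over 10 runs, repeat the measurement 3 times,
    # and require every repeat to last at least 100 ms
    mod.run_individual(10, repeat=3, min_repeat_ms=100)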

src/runtime/graph/debug/graph_runtime_debug.cc

Lines changed: 70 additions & 0 deletions
@@ -38,6 +38,66 @@ class GraphRuntimeDebug : public GraphRuntime {
     return time;
   }
 
+  /*!
+   * \brief Run each operation in the graph and print out the runtime per op.
+   * \param number The number of times to run this function for taking the average.
+   * \param repeat The number of times to repeat the measurement.
+   *        In total, the function will be invoked (1 + number x repeat) times,
+   *        where the first one is a warm-up and will be discarded in case
+   *        there is lazy initialization.
+   * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds.
+   *        By default, one `repeat` contains `number` runs. If this parameter is set,
+   *        the parameter `number` will be dynamically adjusted to meet the
+   *        minimum duration requirement of one `repeat`.
+   */
+  void RunIndividual(int number, int repeat, int min_repeat_ms) {
+    // warm-up run
+    GraphRuntime::Run();
+    for (int i = 0; i < repeat; ++i) {
+      std::chrono::time_point<
+        std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend;
+
+      double duration_ms = 0.0;
+
+      std::vector<double> time_per_op(op_execs_.size(), 0);
+      do {
+        std::fill(time_per_op.begin(), time_per_op.end(), 0);
+        if (duration_ms > 0.0) {
+          number = static_cast<int>(
+              std::max((min_repeat_ms / (duration_ms / number) + 1),
+                       number * 1.618));  // 1.618 is an arbitrary growth factor
+        }
+        tbegin = std::chrono::high_resolution_clock::now();
+        for (int k = 0; k < number; k++) {
+          for (size_t index = 0; index < op_execs_.size(); ++index) {
+            if (op_execs_[index]) {
+              const TVMContext& ctx = data_entry_[entry_id(index, 0)]->ctx;
+              auto op_tbegin = std::chrono::high_resolution_clock::now();
+              op_execs_[index]();
+              TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
+              auto op_tend = std::chrono::high_resolution_clock::now();
+              double op_duration = std::chrono::duration_cast<
+                  std::chrono::duration<double> >(op_tend - op_tbegin).count();
+              time_per_op[index] += op_duration * 1000;  // ms
+            }
+          }
+        }
+        tend = std::chrono::high_resolution_clock::now();
+        duration_ms = std::chrono::duration_cast<std::chrono::duration<double> >
+            (tend - tbegin).count() * 1000;
+      } while (duration_ms < min_repeat_ms);
+
+      LOG(INFO) << "Repeat: " << i;
+      int op = 0;
+      for (size_t index = 0; index < time_per_op.size(); index++) {
+        if (op_execs_[index]) {
+          time_per_op[index] /= number;
+          LOG(INFO) << "Op #" << op++ << ": " << time_per_op[index] << " ms/iter";
+        }
+      }
+    }
+  }
+
   /*!
    * \brief Run each operation and get the output.
    * \param index The index of op which needs to be returned.

@@ -119,6 +179,16 @@ PackedFunc GraphRuntimeDebug::GetFunction(
         this->DebugGetNodeOutput(args[0], args[1]);
       }
     });
+  } else if (name == "run_individual") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        int number = args[0];
+        int repeat = args[1];
+        int min_repeat_ms = args[2];
+        CHECK_GT(number, 0);
+        CHECK_GT(repeat, 0);
+        CHECK_GE(min_repeat_ms, 0);
+        this->RunIndividual(number, repeat, min_repeat_ms);
+      });
   } else {
     return GraphRuntime::GetFunction(name, sptr_to_self);
   }
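The doc comment above describes how `number` grows when one repeat finishes in less than `min_repeat_ms`. A small Python sketch of just that adjustment, mirroring the C++ expression (the helper name is made up for illustration):

    def adjusted_number(number, duration_ms, min_repeat_ms):
        # Mirrors the C++ adjustment: grow `number` until one repeat takes
        # at least `min_repeat_ms`; 1.618 is only an arbitrary growth factor.
        if duration_ms > 0.0:
            number = int(max(min_repeat_ms / (duration_ms / number) + 1,
                             number * 1.618))
        return number

    # Example: 10 runs took 2 ms in total but min_repeat_ms is 100, so the
    # next pass uses max(100 / 0.2 + 1, 16.18) = 501 runs.
    adjusted_number(10, 2.0, 100)   # -> 501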

tests/python/unittest/test_runtime_graph_debug.py

Lines changed: 4 additions & 0 deletions
@@ -68,6 +68,9 @@ def check_verify():
         out = mod.get_output(0, tvm.nd.empty((n,)))
         np.testing.assert_equal(out.asnumpy(), a + 1)
 
+        # test individual run
+        mod.run_individual(20, 2, 1)
+
         mod.exit()
         #verify dump root delete after cleanup
         assert(not os.path.exists(directory))

@@ -94,6 +97,7 @@ def check_remote():
         mod.run(x=tvm.nd.array(a, ctx))
         out = tvm.nd.empty((n,), ctx=ctx)
         out = mod.get_output(0, out)
+        mod.run_individual(20, 2, 1)
         np.testing.assert_equal(out.asnumpy(), a + 1)
 
     check_verify()
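In both tests, mod.run_individual(20, 2, 1) averages each op over 20 runs, repeats the measurement twice, and requires every repeat to last at least 1 ms. The per-op times are emitted via LOG(INFO), so for the single-op test graph the output would look roughly like the following (the timings here are placeholders, not measured values):

    Repeat: 0
    Op #0: 0.012 ms/iter
    Repeat: 1
    Op #0: 0.011 ms/iter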
