@@ -38,6 +38,66 @@ class GraphRuntimeDebug : public GraphRuntime {
38
38
return time;
39
39
}
40
40
41
+ /* !
42
+ * \brief Run each operation in the graph and print out the runtime per op.
43
+ * \param number The number of times to run this function for taking average.
44
+ * \param repeat The number of times to repeat the measurement.
45
+ In total, the function will be invoked (1 + number x repeat) times,
46
+ where the first one is warm up and will be discarded in case
47
+ there is lazy initialization..
48
+ * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds.
49
+ By default, one `repeat` contains `number` runs. If this parameter is set,
50
+ the parameters `number` will be dynamically adjusted to meet the
51
+ minimum duration requirement of one `repeat`.
52
+ */
53
+ void RunIndividual (int number, int repeat, int min_repeat_ms) {
54
+ // warmup run
55
+ GraphRuntime::Run ();
56
+ for (int i = 0 ; i < repeat; ++i) {
57
+ std::chrono::time_point<
58
+ std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend;
59
+
60
+ double duration_ms = 0.0 ;
61
+
62
+ std::vector<double > time_per_op (op_execs_.size (), 0 );
63
+ do {
64
+ std::fill (time_per_op.begin (), time_per_op.end (), 0 );
65
+ if (duration_ms > 0.0 ) {
66
+ number = static_cast <int >(
67
+ std::max ((min_repeat_ms / (duration_ms / number) + 1 ),
68
+ number * 1.618 )); // 1.618 is chosen by random
69
+ }
70
+ tbegin = std::chrono::high_resolution_clock::now ();
71
+ for (int k = 0 ; k < number; k++) {
72
+ for (size_t index = 0 ; index < op_execs_.size (); ++index) {
73
+ if (op_execs_[index]) {
74
+ const TVMContext& ctx = data_entry_[entry_id (index, 0 )]->ctx ;
75
+ auto op_tbegin = std::chrono::high_resolution_clock::now ();
76
+ op_execs_[index]();
77
+ TVMSynchronize (ctx.device_type , ctx.device_id , nullptr );
78
+ auto op_tend = std::chrono::high_resolution_clock::now ();
79
+ double op_duration = std::chrono::duration_cast<
80
+ std::chrono::duration<double > >(op_tend - op_tbegin).count ();
81
+ time_per_op[index] += op_duration * 1000 ; // ms
82
+ }
83
+ }
84
+ }
85
+ tend = std::chrono::high_resolution_clock::now ();
86
+ duration_ms = std::chrono::duration_cast<std::chrono::duration<double > >
87
+ (tend - tbegin).count () * 1000 ;
88
+ } while (duration_ms < min_repeat_ms);
89
+
90
+ LOG (INFO) << " Repeat: " << i;
91
+ int op = 0 ;
92
+ for (size_t index = 0 ; index < time_per_op.size (); index++) {
93
+ if (op_execs_[index]) {
94
+ time_per_op[index] /= number;
95
+ LOG (INFO) << " Op #" << op++ << " : " << time_per_op[index] << " ms/iter" ;
96
+ }
97
+ }
98
+ }
99
+ }
100
+
41
101
/* !
42
102
* \brief Run each operation and get the output.
43
103
* \param index The index of op which needs to be returned.
@@ -119,6 +179,16 @@ PackedFunc GraphRuntimeDebug::GetFunction(
119
179
this ->DebugGetNodeOutput (args[0 ], args[1 ]);
120
180
}
121
181
});
182
+ } else if (name == " run_individual" ) {
183
+ return PackedFunc ([sptr_to_self, this ](TVMArgs args, TVMRetValue* rv) {
184
+ int number = args[0 ];
185
+ int repeat = args[1 ];
186
+ int min_repeat_ms = args[2 ];
187
+ CHECK_GT (number, 0 );
188
+ CHECK_GT (repeat, 0 );
189
+ CHECK_GE (min_repeat_ms, 0 );
190
+ this ->RunIndividual (number, repeat, min_repeat_ms);
191
+ });
122
192
} else {
123
193
return GraphRuntime::GetFunction (name, sptr_to_self);
124
194
}
0 commit comments