Skip to content

Commit 5b81413

Browse files
committed
[Compatible] Fix apache/tvm#12382
1 parent d79d0d9 commit 5b81413

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

src/op/dialect/cutlass/cutlass_fusion.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,14 @@ OpEnv* Tune(const op::CallValues& call, OpEnv* op_env) {
7878
} else {
7979
std::vector<std::shared_ptr<TunableConfig>> tunable = env->ListTunableConfigs();
8080
const int number = 10, repeat = 1, min_repeat_ms = 0, cooldown_interval_ms = 0,
81-
repeats_to_cooldow = 1;
81+
repeats_to_cooldown = 1, limit_zero_time_iterations = 100;
8282
double min_time = std::numeric_limits<double>::max();
8383
for (auto& config : tunable) {
8484
env->SetTunableConfig(config);
8585
env->Init(call);
86-
Array<FloatValue> result =
87-
TimeEvaluator(TypedPackedFunc<void()>([&]() { env->Execute(call); }), call->device,
88-
number, repeat, min_repeat_ms, cooldown_interval_ms, repeats_to_cooldow)();
86+
Array<FloatValue> result = TimeEvaluator(
87+
TypedPackedFunc<void()>([&]() { env->Execute(call); }), call->device, number, repeat,
88+
min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown)();
8989
CHECK_EQ(result.size(), 1U);
9090
if (result[0]->value < min_time) {
9191
min_time = result[0]->value;

src/op/dialect/cutlass/timer.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ tvm::runtime::Module MakeCutlassModule(PackedFunc pf) {
2222
}
2323

2424
PackedFunc TimeEvaluator(PackedFunc pf, Device dev, int number, int repeat, int min_repeat_ms,
25-
int cooldown_interval_ms, int repeats_to_cooldown) {
25+
int limit_zero_time_iterations, int cooldown_interval_ms,
26+
int repeats_to_cooldown) {
2627
tvm::Device tvm_dev = dev;
2728
auto wrapper = [=](TVMArgs args, TVMRetValue* rv) mutable {
2829
const static PackedFunc rpv_eval = registry::GetPackedFunc("runtime.RPCTimeEvaluator");
2930
PackedFunc timer =
3031
rpv_eval(MakeCutlassModule(pf), "main", (int)tvm_dev.device_type, (int)tvm_dev.device_id,
31-
number, repeat, min_repeat_ms, cooldown_interval_ms, repeats_to_cooldown, "");
32+
number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms,
33+
repeats_to_cooldown, "");
3234
TVMRetValue timer_rv;
3335
timer.CallPacked(args, &timer_rv);
3436
const double* speed = reinterpret_cast<const double*>(timer_rv.operator std::string().data());

src/op/dialect/cutlass/timer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ tvm::runtime::Module MakeCutlassModule(registry::PackedFunc pf);
6262
* minimum duration requirement of one `repeat`.
6363
* i.e., When the run time of one `repeat` falls below this time, the `number`
6464
* parameter will be automatically increased.
65+
* \param limit_zero_time_iterations The maximum number of repeats when
66+
* measured time is equal to 0. It helps to avoid hanging
67+
* during measurements.
6568
* \param cooldown_interval_ms The cooldown interval in milliseconds between the number of
6669
* repeats defined by `repeats_to_cooldown`.
6770
* \param repeats_to_cooldown The number of repeats before the cooldown is activated.
@@ -70,6 +73,7 @@ tvm::runtime::Module MakeCutlassModule(registry::PackedFunc pf);
7073
*/
7174
registry::PackedFunc TimeEvaluator(registry::PackedFunc pf, Device dev, int number = 10,
7275
int repeat = 1, int min_repeat_ms = 0,
76+
int limit_zero_time_iterations = 100,
7377
int cooldown_interval_ms = 0, int repeats_to_cooldown = 1);
7478

7579
} // namespace cutlass

0 commit comments

Comments
 (0)