@@ -422,12 +422,10 @@ Array<LoweredFunc> lower(Schedule sch,
422
422
return Array<LoweredFunc>({ ir::MakeAPI (stmt, name, out_arg_list, 0 , config->restricted_func ) });
423
423
}
424
424
425
- runtime::Module build (const Array<LoweredFunc>& funcs,
426
- const Target& target,
427
- const Target& target_host,
428
- const BuildConfig& config,
429
- Array<LoweredFunc>* fhost_ret,
430
- std::vector<runtime::Module>* devmod_ret) {
425
+ Array<Array<LoweredFunc> > split_dev_host_funcs (const Array<LoweredFunc>& funcs,
426
+ const Target& target,
427
+ const Target& target_host,
428
+ const BuildConfig& config) {
431
429
std::unordered_set<std::string> all_names;
432
430
for (const auto &x : funcs) {
433
431
CHECK (all_names.count (x->name ) == 0 ) << " Duplicate function name " << x->name ;
@@ -466,12 +464,6 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
466
464
}
467
465
}
468
466
469
- if (fhost_ret != nullptr ) {
470
- for (auto f : fhost) {
471
- fhost_ret->push_back (f);
472
- }
473
- }
474
-
475
467
auto keys = target->keys ();
476
468
bool target_is_gpu =
477
469
std::find (keys.begin (), keys.end (), " gpu" ) != keys.end ();
@@ -500,14 +492,25 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
500
492
func = ir::CombineContextCall (func);
501
493
fhost.Set (i, func);
502
494
}
495
+ Array<Array<LoweredFunc> > ret;
496
+ ret.push_back (fhost);
497
+ ret.push_back (fdevice);
498
+ return ret;
499
+ }
500
+
501
+ runtime::Module build (const Array<LoweredFunc>& funcs,
502
+ const Target& target,
503
+ const Target& target_host,
504
+ const BuildConfig& config) {
505
+ auto target_host_val = target_host.defined () ? target_host : DefaultTargetHost (target);
506
+ auto host_dev_funcs = split_dev_host_funcs (funcs, target, target_host, config);
507
+ auto & fhost = host_dev_funcs[0 ];
508
+ auto & fdevice = host_dev_funcs[1 ];
503
509
504
510
auto mhost = codegen::Build (fhost, target_host_val->str ());
505
511
506
512
if (fdevice.size () > 0 ) {
507
513
auto mdev = codegen::Build (fdevice, target->str ());
508
- if (devmod_ret != nullptr ) {
509
- devmod_ret->push_back (mdev);
510
- }
511
514
mhost.Import (mdev);
512
515
}
513
516
0 commit comments