Skip to content

Commit 3d5798e

Browse files
author
Bing Xu
committed
asdf
1 parent 6ea243a commit 3d5798e

File tree

7 files changed

+468
-266
lines changed

7 files changed

+468
-266
lines changed

include/tvm/build_module.h

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -344,23 +344,32 @@ TVM_DLL Array<LoweredFunc> lower(Schedule sch,
344344
const std::string& name,
345345
const std::unordered_map<Tensor, Buffer>& binds,
346346
const BuildConfig& config);
347+
/*!
348+
* \brief Split the functions into host and device functions and run the necessary passes before build.
349+
* \param funcs The functions to be built.
350+
* \param target The target device to build for.
351+
* \param target_host The target for building host code. To use the default, pass Target()
352+
* \param config The build configuration.
353+
* \return An Array<Array<LoweredFunc>> with 2 elements: the first is the host function array,
354+
the second is the device function array.
355+
*/
356+
TVM_DLL Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
357+
const Target& target,
358+
const Target& target_host,
359+
const BuildConfig& config);
347360

348361
/*!
349362
* \brief Build a device and host module for a specific target from an array of lowered functions.
350363
* \param funcs The functions to be built.
351364
* \param target The target device to build for.
352365
* \param target_host The target for building host code. To use the default, pass Target()
353366
* \param config The build configuration.
354-
* \param (optional) returned host functions
355-
* \param (optional) returned dev mods
356367
* \return The built module.
357368
*/
358369
TVM_DLL runtime::Module build(const Array<LoweredFunc>& funcs,
359370
const Target& target,
360371
const Target& target_host,
361-
const BuildConfig& config,
362-
Array<LoweredFunc>* fhost_ret = nullptr,
363-
std::vector<runtime::Module>* devmod_ret = nullptr);
372+
const BuildConfig& config);
364373

365374
class GenericFuncNode;
366375

src/codegen/build_module.cc

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -422,12 +422,10 @@ Array<LoweredFunc> lower(Schedule sch,
422422
return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) });
423423
}
424424

425-
runtime::Module build(const Array<LoweredFunc>& funcs,
426-
const Target& target,
427-
const Target& target_host,
428-
const BuildConfig& config,
429-
Array<LoweredFunc>* fhost_ret,
430-
std::vector<runtime::Module>* devmod_ret) {
425+
Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
426+
const Target& target,
427+
const Target& target_host,
428+
const BuildConfig& config) {
431429
std::unordered_set<std::string> all_names;
432430
for (const auto &x : funcs) {
433431
CHECK(all_names.count(x->name) == 0) << "Duplicate function name " << x->name;
@@ -466,12 +464,6 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
466464
}
467465
}
468466

469-
if (fhost_ret != nullptr) {
470-
for (auto f : fhost) {
471-
fhost_ret->push_back(f);
472-
}
473-
}
474-
475467
auto keys = target->keys();
476468
bool target_is_gpu =
477469
std::find(keys.begin(), keys.end(), "gpu") != keys.end();
@@ -500,14 +492,25 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
500492
func = ir::CombineContextCall(func);
501493
fhost.Set(i, func);
502494
}
495+
Array<Array<LoweredFunc> > ret;
496+
ret.push_back(fhost);
497+
ret.push_back(fdevice);
498+
return ret;
499+
}
500+
501+
runtime::Module build(const Array<LoweredFunc>& funcs,
502+
const Target& target,
503+
const Target& target_host,
504+
const BuildConfig& config) {
505+
auto target_host_val = target_host.defined() ? target_host : DefaultTargetHost(target);
506+
auto host_dev_funcs = split_dev_host_funcs(funcs, target, target_host, config);
507+
auto& fhost = host_dev_funcs[0];
508+
auto& fdevice = host_dev_funcs[1];
503509

504510
auto mhost = codegen::Build(fhost, target_host_val->str());
505511

506512
if (fdevice.size() > 0) {
507513
auto mdev = codegen::Build(fdevice, target->str());
508-
if (devmod_ret != nullptr) {
509-
devmod_ret->push_back(mdev);
510-
}
511514
mhost.Import(mdev);
512515
}
513516

0 commit comments

Comments
 (0)