|  | 
|  | 1 | +/* | 
|  | 2 | + * Licensed to the Apache Software Foundation (ASF) under one | 
|  | 3 | + * or more contributor license agreements.  See the NOTICE file | 
|  | 4 | + * distributed with this work for additional information | 
|  | 5 | + * regarding copyright ownership. The ASF licenses this file | 
|  | 6 | + * to you under the Apache License, Version 2.0 (the | 
|  | 7 | + * "License"); you may not use this file except in compliance | 
|  | 8 | + * with the License.  You may obtain a copy of the License at | 
|  | 9 | + * | 
|  | 10 | + *   http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 11 | + * | 
|  | 12 | + * Unless required by applicable law or agreed to in writing, | 
|  | 13 | + * software distributed under the License is distributed on an | 
|  | 14 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | 
|  | 15 | + * KIND, either express or implied.  See the License for the | 
|  | 16 | + * specific language governing permissions and limitations | 
|  | 17 | + * under the License. | 
|  | 18 | + */ | 
|  | 19 | + | 
|  | 20 | +/*! | 
|  | 21 | + * \file split_host_device.cc | 
|  | 22 | + * \brief Split device function from host. | 
|  | 23 | + */ | 
|  | 24 | +#include <tvm/ffi/function.h> | 
|  | 25 | +#include <tvm/ffi/reflection/registry.h> | 
|  | 26 | +#include <tvm/ir/global_var_supply.h> | 
|  | 27 | +#include <tvm/ir/transform.h> | 
|  | 28 | +#include <tvm/target/target.h> | 
|  | 29 | +#include <tvm/tir/analysis.h> | 
|  | 30 | +#include <tvm/tir/builtin.h> | 
|  | 31 | +#include <tvm/tir/expr.h> | 
|  | 32 | +#include <tvm/tir/op.h> | 
|  | 33 | +#include <tvm/tir/stmt_functor.h> | 
|  | 34 | +#include <tvm/tir/transform.h> | 
|  | 35 | + | 
|  | 36 | +#include "tir/analysis/var_use_def_analysis.h" | 
|  | 37 | + | 
|  | 38 | +namespace tvm { | 
|  | 39 | +namespace tl { | 
|  | 40 | + | 
|  | 41 | +namespace tir = tvm::tir; | 
|  | 42 | + | 
|  | 43 | +class HostDeviceSplitter : public tir::StmtMutator { | 
|  | 44 | + public: | 
|  | 45 | +  explicit HostDeviceSplitter(IRModule* device_mod, std::function<GlobalVar()> var_supply) | 
|  | 46 | +      : device_mod_(device_mod), var_supply_(std::move(var_supply)) {} | 
|  | 47 | + | 
|  | 48 | +  tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) final { | 
|  | 49 | +    if (op->attr_key == tvm::attr::kTarget) { | 
|  | 50 | +      found_device_region_ = true; | 
|  | 51 | +      auto device_target = op->node.as<tvm::Target>().value().WithoutHost(); | 
|  | 52 | +      return SplitDeviceFunc(op->body, device_target); | 
|  | 53 | +    } | 
|  | 54 | +    return tir::StmtMutator::VisitStmt_(op); | 
|  | 55 | +  } | 
|  | 56 | + | 
|  | 57 | +  tir::Stmt ForceSplit(tir::Stmt body, tvm::Target device_target) { | 
|  | 58 | +    return SplitDeviceFunc(std::move(body), std::move(device_target)); | 
|  | 59 | +  } | 
|  | 60 | + | 
|  | 61 | +  bool found_device_region() const { return found_device_region_; } | 
|  | 62 | + | 
|  | 63 | + private: | 
|  | 64 | +  bool found_device_region_{false}; | 
|  | 65 | + | 
|  | 66 | +  tir::Stmt SplitDeviceFunc(tir::Stmt body, tvm::Target device_target) { | 
|  | 67 | +    auto [params, buffers_to_declare] = | 
|  | 68 | +        [&]() -> std::tuple<Array<tir::Var>, Array<tir::Buffer>> { | 
|  | 69 | +      tir::VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/true); | 
|  | 70 | +      use_def(body); | 
|  | 71 | + | 
|  | 72 | +      // Sort first by variable type, then by variable name | 
|  | 73 | +      std::vector<tir::Var> params{use_def.undefined_.begin(), use_def.undefined_.end()}; | 
|  | 74 | +      std::sort(params.begin(), params.end(), [](const tir::Var& a, const tir::Var& b) { | 
|  | 75 | +        auto sort_key = [](const tir::Var& var) { | 
|  | 76 | +          return std::tuple{ | 
|  | 77 | +              !var->dtype.is_handle(), | 
|  | 78 | +              var->name_hint, | 
|  | 79 | +          }; | 
|  | 80 | +        }; | 
|  | 81 | +        return sort_key(a) < sort_key(b); | 
|  | 82 | +      }); | 
|  | 83 | +      return {params, use_def.undefined_buffers_}; | 
|  | 84 | +    }(); | 
|  | 85 | + | 
|  | 86 | +    // CodeGenCPU is used for some device-side targets, such as | 
|  | 87 | +    // "ext_dev", and expects to be able to return a int32_t status | 
|  | 88 | +    // code. | 
|  | 89 | + | 
|  | 90 | +    bool can_propagate_errors = [&]() { | 
|  | 91 | +      auto kind = device_target->GetTargetDeviceType(); | 
|  | 92 | +      return kind == kDLCPU || kind == kDLExtDev || kind == kDLHexagon; | 
|  | 93 | +    }(); | 
|  | 94 | +    IntImm success(DataType::Int(32), 0); | 
|  | 95 | +    Type kernel_ret_type; | 
|  | 96 | +    if (can_propagate_errors) { | 
|  | 97 | +      kernel_ret_type = PrimType(DataType::Int(32)); | 
|  | 98 | +      body = tir::SeqStmt::Flatten(body, tir::Evaluate(ret(success))); | 
|  | 99 | +    } else { | 
|  | 100 | +      kernel_ret_type = VoidType(); | 
|  | 101 | +    } | 
|  | 102 | + | 
|  | 103 | +    for (tir::Buffer buf : buffers_to_declare) { | 
|  | 104 | +      body = tir::DeclBuffer(buf, std::move(body)); | 
|  | 105 | +    } | 
|  | 106 | +    tir::PrimFunc device_func(params, body, kernel_ret_type); | 
|  | 107 | +    device_func = WithAttrs( | 
|  | 108 | +        std::move(device_func), | 
|  | 109 | +        {{tvm::attr::kTarget, device_target}, | 
|  | 110 | +         {tir::attr::kNoAlias, true}, | 
|  | 111 | +         {tir::attr::kIsGlobalFunc, true}}); | 
|  | 112 | + | 
|  | 113 | +    GlobalVar kernel_symbol_global = var_supply_(); | 
|  | 114 | +    (*device_mod_)->Add(kernel_symbol_global, device_func); | 
|  | 115 | +    Array<PrimExpr> args = params.Map([](const tir::Var& var) -> PrimExpr { return var; }); | 
|  | 116 | + | 
|  | 117 | +    if (can_propagate_errors) { | 
|  | 118 | +      tir::Var kernel_error_code("kernel_error_code", success->dtype); | 
|  | 119 | +      tir::Call kernel_call(success->dtype, kernel_symbol_global, args); | 
|  | 120 | +      tir::AssertStmt assert_success( | 
|  | 121 | +          kernel_error_code == success, tir::StringImm("Error executing compute kernel"), | 
|  | 122 | +          tir::Evaluate(0)); | 
|  | 123 | +      tir::LetStmt let_check(kernel_error_code, kernel_call, assert_success); | 
|  | 124 | + | 
|  | 125 | +      return let_check; | 
|  | 126 | + | 
|  | 127 | +    } else { | 
|  | 128 | +      return tir::Evaluate(tir::Call(DataType::Void(), kernel_symbol_global, args)); | 
|  | 129 | +    } | 
|  | 130 | +  } | 
|  | 131 | + | 
|  | 132 | +  // target ir module | 
|  | 133 | +  IRModule* device_mod_; | 
|  | 134 | +  // Generate new GlobalVar for the kernel | 
|  | 135 | +  std::function<GlobalVar()> var_supply_; | 
|  | 136 | +}; | 
|  | 137 | + | 
|  | 138 | +tir::PrimFunc SplitHostDevice(tir::PrimFunc func, IRModule* device_mod, | 
|  | 139 | +                         std::function<GlobalVar()> var_supply) { | 
|  | 140 | +  HostDeviceSplitter splitter(device_mod, std::move(var_supply)); | 
|  | 141 | + | 
|  | 142 | +  if (auto body = splitter(func->body); !body.same_as(func->body)) { | 
|  | 143 | +    func.CopyOnWrite()->body = body; | 
|  | 144 | +  } else if (!splitter.found_device_region()) { | 
|  | 145 | +    if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) { | 
|  | 146 | +      auto device_target = target.value().WithoutHost(); | 
|  | 147 | +      if (device_target.defined() && func->HasNonzeroAttr(tir::attr::kIsEntryFunc) && | 
|  | 148 | +          tir::is_no_op(func->body)) { | 
|  | 149 | +        if (auto forced = splitter.ForceSplit(func->body, device_target); | 
|  | 150 | +            !forced.same_as(func->body)) { | 
|  | 151 | +          func.CopyOnWrite()->body = forced; | 
|  | 152 | +        } | 
|  | 153 | +      } | 
|  | 154 | +    } | 
|  | 155 | +  } | 
|  | 156 | + | 
|  | 157 | +  return func; | 
|  | 158 | +} | 
|  | 159 | + | 
|  | 160 | +namespace transform { | 
|  | 161 | + | 
|  | 162 | +tvm::transform::Pass SplitHostDevice() { | 
|  | 163 | +  auto pass_func = [](IRModule mod, tvm::transform::PassContext ctx) { | 
|  | 164 | +    tvm::GlobalVarSupply global_var_supply(mod); | 
|  | 165 | + | 
|  | 166 | +    IRModule device_mod = IRModule(Map<GlobalVar, BaseFunc>({})); | 
|  | 167 | +    IRModule updates = IRModule(Map<GlobalVar, BaseFunc>({})); | 
|  | 168 | + | 
|  | 169 | +    for (const auto& [gvar, base_func] : mod->functions) { | 
|  | 170 | +      if (auto opt = base_func.as<tir::PrimFunc>()) { | 
|  | 171 | +        tir::PrimFunc func = opt.value(); | 
|  | 172 | + | 
|  | 173 | +        auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol); | 
|  | 174 | +        auto name_prefix = global_symbol.value_or(gvar->name_hint); | 
|  | 175 | +        auto kernel_name = name_prefix + "_kernel"; | 
|  | 176 | +        auto var_supply = [&global_var_supply, &kernel_name]() -> GlobalVar { | 
|  | 177 | +          return global_var_supply->FreshGlobal(kernel_name, false); | 
|  | 178 | +        }; | 
|  | 179 | + | 
|  | 180 | +        func = ::tvm::tl::SplitHostDevice(std::move(func), &device_mod, var_supply); | 
|  | 181 | +        if (!func.same_as(base_func)) { | 
|  | 182 | +          updates->Add(gvar, func); | 
|  | 183 | +        } | 
|  | 184 | +      } | 
|  | 185 | +    } | 
|  | 186 | + | 
|  | 187 | +    mod->Update(updates); | 
|  | 188 | +    mod->Update(device_mod); | 
|  | 189 | +    return tir::transform::ConvertSSA()(mod); | 
|  | 190 | +  }; | 
|  | 191 | + | 
|  | 192 | +  return tvm::transform::CreateModulePass(pass_func, 0, "tl.SplitHostDevice", {}); | 
|  | 193 | +} | 
|  | 194 | + | 
|  | 195 | +TVM_FFI_STATIC_INIT_BLOCK({ | 
|  | 196 | +  namespace refl = tvm::ffi::reflection; | 
|  | 197 | +  refl::GlobalDef().def("tl.transform.SplitHostDevice", SplitHostDevice); | 
|  | 198 | +}); | 
|  | 199 | + | 
|  | 200 | +}  // namespace transform | 
|  | 201 | +}  // namespace tl | 
|  | 202 | +}  // namespace tvm | 
0 commit comments