|
| 1 | +/*! |
| 2 | + * Copyright (c) 2016 by Contributors |
| 3 | + * \file module.h |
| 4 | + * \brief Low level IR module, |
| 5 | + * Contains lowered function information. |
| 6 | + */ |
| 7 | +#ifndef TVM_MODULE_H_ |
| 8 | +#define TVM_MODULE_H_ |
| 9 | + |
| 10 | +#include <tvm/container.h> |
| 11 | +#include <ir/FunctionBase.h> |
| 12 | +#include <string> |
| 13 | + |
| 14 | +#include "./base.h" |
| 15 | +#include "./expr.h" |
| 16 | +#include "./tensor.h" |
| 17 | + |
| 18 | +namespace tvm { |
| 19 | + |
| 20 | +// Internal node container of lowered function. |
| 21 | +class LoweredFuncNode; |
| 22 | + |
| 23 | +// Internal node container of module. |
| 24 | +class ModuleNode; |
| 25 | + |
| 26 | +/*! |
| 27 | + * \brief LoweredFunc represents function after lowering. |
| 28 | + * This is the final IR representation before codegen. |
| 29 | + */ |
| 30 | +class LoweredFunc : public FunctionRef { |
| 31 | + public: |
| 32 | + LoweredFunc() {} |
| 33 | + explicit LoweredFunc(std::shared_ptr<Node> n) : FunctionRef(n) {} |
| 34 | + /*! |
| 35 | + * \brief access the internal node container |
| 36 | + * \return the pointer to the internal node container |
| 37 | + */ |
| 38 | + inline const LoweredFuncNode* operator->() const; |
| 39 | + /*! \brief specify container node */ |
| 40 | + using ContainerType = LoweredFuncNode; |
| 41 | +}; |
| 42 | + |
| 43 | +/*! \brief Node container of LoweredFunc */ |
| 44 | +class LoweredFuncNode : public FunctionBaseNode { |
| 45 | + public: |
| 46 | + /*! \brief The name of the function */ |
| 47 | + std::string name; |
| 48 | + /*! |
| 49 | + * \brief The arguments of the function |
| 50 | + * This function can only take pod type(int, float) and void* as arguments. |
| 51 | + */ |
| 52 | + Array<Var> args; |
| 53 | + /*! |
| 54 | + * \brief The IterVar axis of threads |
| 55 | + * Each axis need host function to specify a size. |
| 56 | + * \note Calling convention into LoweredFunc |
| 57 | + * |
| 58 | + * Assume we have a LoweredFunc f, a call into f |
| 59 | + * Call(f, arg1, arg2, ..., arg_n, |
| 60 | + * size_axis_1, size_axis_2, ... size_axis_m) |
| 61 | + * |
| 62 | + * Here n = len(args), m = len(thread_axis) |
| 63 | + * |
| 64 | + * The CodeGen should take this and translate this call |
| 65 | + * to corresponding API specific kernel launchs or function calls. |
| 66 | + */ |
| 67 | + Array<IterVar> thread_axis; |
| 68 | + /*! |
| 69 | + * \brief The hint data type of Var handles defined in LetStmt |
| 70 | + * Can be used as hint when generating type signiture. |
| 71 | + * The creation rule is given by |
| 72 | + * handle_data_type[var_handle] = make_const(the_type, 0); |
| 73 | + * |
| 74 | + * \note Expr is used instead Type, because Type cannot be hold by Map. |
| 75 | + * constant Expr of given type is used. |
| 76 | + */ |
| 77 | + Map<Var, Expr> handle_data_type; |
| 78 | + /*! \brief The body statment of the function */ |
| 79 | + Stmt body; |
| 80 | + /*! \return name of the operation */ |
| 81 | + const std::string& func_name() const final { |
| 82 | + return name; |
| 83 | + } |
| 84 | + // there is no return value, but return 1 |
| 85 | + // to enable Call into this function. |
| 86 | + int num_outputs() const final { |
| 87 | + return 1; |
| 88 | + } |
| 89 | + void VisitAttrs(AttrVisitor* v) final { |
| 90 | + v->Visit("name", &name); |
| 91 | + v->Visit("args", &args); |
| 92 | + v->Visit("thread_axis", &thread_axis); |
| 93 | + v->Visit("handle_data_type", &handle_data_type); |
| 94 | + v->Visit("body", &body); |
| 95 | + } |
| 96 | + |
| 97 | + static constexpr const char* _type_key = "LoweredFunc"; |
| 98 | + TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode); |
| 99 | +}; |
| 100 | + |
| 101 | +// Implementations of inline functions |
| 102 | +inline const LoweredFuncNode* LoweredFunc::operator->() const { |
| 103 | + return static_cast<const LoweredFuncNode*>(node_.get()); |
| 104 | +} |
| 105 | + |
| 106 | +} // namespace tvm |
| 107 | + |
| 108 | +#endif // TVM_MODULE_H_ |
0 commit comments