|
2 | 2 | * Copyright (c) 2018 by Contributors |
3 | 3 | * \file tvm/relay/pass.h |
4 | 4 | * \brief The set of Relay passes written in C++. |
| 5 | + * |
| 6 | + * This file also implements a pass manager. The pass manager manages a sequence |
| 7 | + * of Relay-to-Relay transformation passes over a particlar unit of AST. The |
| 8 | + * design is largely inspired from LLVM's pass manager and modern deep learning |
| 9 | + * frameworks that perform tensor->tensor transformations. |
| 10 | + * |
| 11 | + * The responsibilities of a traditional compiler pass manager usually involves: |
| 12 | + * - Organizing the execution order of optimization passes though not |
| 13 | + * necessarily in the optimal sequence. |
| 14 | + * - Collecting required analysis information and keep them up-to-date. |
| 15 | + * - Reducing the effort required to implement new passes for compiler |
| 16 | + * developers, etc. |
| 17 | + * |
| 18 | + * Similar to LLVM's pass manager, we designed the Relay pass manager to work |
| 19 | + * different granularity, i.e. module level, function level, and even sequential |
| 20 | + * passe that contains a host of passes. |
| 21 | + * |
| 22 | + * However, we also extend the functionality of the traditional pass manager |
| 23 | + * with the consideration of requirements/convention from deep learning |
| 24 | + * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass |
| 25 | + * manager performs the Relay.Module -> Relay.Module transformation. All |
| 26 | + * different types of passes, including the sequential-level pass object, are |
| 27 | + * essentially pass objects. This design, therefore, effectively provides users |
| 28 | + * a consistent and convenient interface, i.e. Pass, to play with. It offers a |
| 29 | + * means to ease the development and testing of Relay passes. For example, with |
| 30 | + * the pass manager, external users will be able to have custom passes correctly |
| 31 | + * scheduled without having to modify a single handcrafted pass order. |
| 32 | + * |
| 33 | + * In the future we need to describe constraints between passes. For example, |
| 34 | + * we may want to preserve dependencies between different passes and validate |
| 35 | + * them on the completion of a certain pass. |
| 36 | + * |
| 37 | + * We also need to store side information and import the error reporting system. |
5 | 38 | */ |
6 | 39 | #ifndef TVM_RELAY_PASS_H_ |
7 | 40 | #define TVM_RELAY_PASS_H_ |
8 | 41 |
|
| 42 | +#include <tvm/ir.h> |
| 43 | +#include <tvm/packed_func_ext.h> |
| 44 | +#include <tvm/relay/error.h> |
9 | 45 | #include <tvm/relay/expr.h> |
10 | 46 | #include <tvm/relay/module.h> |
11 | 47 | #include <tvm/relay/op_attr_types.h> |
| 48 | +#include <tvm/relay/type.h> |
| 49 | + |
12 | 50 | #include <string> |
| 51 | +#include <vector> |
13 | 52 |
|
14 | 53 | namespace tvm { |
15 | 54 | namespace relay { |
16 | 55 |
|
| 56 | +namespace pass { |
| 57 | + |
| 58 | +/* |
| 59 | + * \brief The context of pass. |
| 60 | + */ |
| 61 | +class PassContext; |
| 62 | + |
| 63 | +/*! |
| 64 | + * \brief PassContextNode contains the information that a pass can rely on, such as |
| 65 | + * analysis results. |
| 66 | + */ |
| 67 | +class PassContextNode : public RelayNode { |
| 68 | + public: |
| 69 | + /*! |
| 70 | + * \brief The error reporter used to notify users why an optimization fails. |
| 71 | + */ |
| 72 | + ErrorReporter err_reporter; |
| 73 | + |
| 74 | + PassContextNode() = default; |
| 75 | + |
| 76 | + void VisitAttrs(tvm::AttrVisitor* v) final { |
| 77 | + } |
| 78 | + |
| 79 | + TVM_DLL static PassContext make(); |
| 80 | + |
| 81 | + static constexpr const char* _type_key = "relay.PassContext"; |
| 82 | + TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode); |
| 83 | +}; |
| 84 | + |
| 85 | +TVM_DEFINE_NODE_REF(PassContext, PassContextNode) |
| 86 | + |
| 87 | +/* |
| 88 | + * \brief The meta data of a pass. |
| 89 | + * |
| 90 | + * PassInfo can be extended conveniently in the future if more meta information |
| 91 | + * is needed. |
| 92 | + */ |
| 93 | +class PassInfo; |
| 94 | + |
| 95 | +/*! |
| 96 | + * \brief PassInfoNode contains meta data that will be used to help optimization |
| 97 | + * and analysis. |
| 98 | + */ |
| 99 | +class PassInfoNode : public RelayNode { |
| 100 | + public: |
| 101 | + /*! \brief The minimal optimization level that this pass will be enabled. */ |
| 102 | + int opt_level; |
| 103 | + |
| 104 | + /*! \brief The name of an optimization/analysis pass. */ |
| 105 | + std::string name; |
| 106 | + |
| 107 | + /*! \brief The passes that are required to perform the current pass. */ |
| 108 | + tvm::Array<tvm::Expr> required; |
| 109 | + |
| 110 | + PassInfoNode() = default; |
| 111 | + |
| 112 | + void VisitAttrs(tvm::AttrVisitor* v) final { |
| 113 | + v->Visit("opt_level", &opt_level); |
| 114 | + v->Visit("name", &name); |
| 115 | + v->Visit("required", &required); |
| 116 | + } |
| 117 | + |
| 118 | + TVM_DLL static PassInfo make(int opt_level, std::string name, |
| 119 | + tvm::Array<tvm::Expr> required); |
| 120 | + |
| 121 | + static constexpr const char* _type_key = "relay.PassInfo"; |
| 122 | + TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode); |
| 123 | +}; |
| 124 | + |
| 125 | +TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode) |
| 126 | + |
| 127 | +class Pass; |
| 128 | + |
| 129 | +/*! |
| 130 | + * \brief PassNode is the base type of differnt types of optimization passes. |
| 131 | + * It is designed as a pure class and implemented by different pass subclasses |
| 132 | + * at different granularity of Relay nodes. |
| 133 | + */ |
| 134 | +class PassNode : public RelayNode { |
| 135 | + public: |
| 136 | + /* |
| 137 | + * \brief Get the pass information/meta data. */ |
| 138 | + virtual PassInfo Info() const = 0; |
| 139 | + |
| 140 | + /*! |
| 141 | + * \brief Set the context information for a pass. |
| 142 | + * |
| 143 | + * \param pass_ctx The context information for a certain pass. |
| 144 | + */ |
| 145 | + virtual void SetContext(const PassContext& pass_ctx) = 0; |
| 146 | + |
| 147 | + /*! |
| 148 | + * \brief Execute the optimization pass using a functor. |
| 149 | + * |
| 150 | + * \param mod The module that an optimization pass runs on. |
| 151 | + * |
| 152 | + * \return The updated module. |
| 153 | + */ |
| 154 | + virtual Module operator()(const Module& mod) const = 0; |
| 155 | + |
| 156 | + void VisitAttrs(tvm::AttrVisitor* v) override {} |
| 157 | + |
| 158 | + static constexpr const char* _type_key = "relay.Pass"; |
| 159 | + TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode); |
| 160 | +}; |
| 161 | + |
| 162 | +class Pass : public NodeRef { |
| 163 | + public: |
| 164 | + Pass() = default; |
| 165 | + explicit Pass(NodePtr<tvm::Node> p) : NodeRef(p) {} |
| 166 | + |
| 167 | + PassNode* operator->() const { |
| 168 | + return static_cast<PassNode*>(this->node_.get()); |
| 169 | + } |
| 170 | + |
| 171 | + using ContainerType = PassNode; |
| 172 | +}; |
| 173 | + |
| 174 | +/* |
| 175 | + * \brief Create a module pass. |
| 176 | + * |
| 177 | + * \param pass_func The packed function that contains the optimization. |
| 178 | + * \param opt_level The optimization level of the module pass. |
| 179 | + * \param name The name of the module pass. |
| 180 | + * \param required The list of the passes that the module pass is dependent on. |
| 181 | + * |
| 182 | + * \return The created module pass. |
| 183 | + */ |
| 184 | +Pass CreateModulePass( |
| 185 | + const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func, |
| 186 | + int opt_level, |
| 187 | + const std::string& name, |
| 188 | + const tvm::Array<tvm::Expr>& required); |
| 189 | + |
| 190 | +/* |
| 191 | + * \brief Create a function pass. |
| 192 | + * |
| 193 | + * \param pass_func The packed function that contains the optimization. |
| 194 | + * \param opt_level The optimization level of the function pass. |
| 195 | + * \param name The name of the function pass. |
| 196 | + * \param required The list of the passes that the function pass is dependent on. |
| 197 | + * |
| 198 | + * \return The created function pass. |
| 199 | + */ |
| 200 | +Pass CreateFunctionPass( |
| 201 | + const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func, |
| 202 | + int opt_level, |
| 203 | + const std::string& name, |
| 204 | + const tvm::Array<tvm::Expr>& required); |
| 205 | +/* |
| 206 | + * \brief Create a sequential pass. |
| 207 | + * |
| 208 | + * \param passes The optimization passes will be performed. |
| 209 | + * \param opt_level The optimization level of the sequential pass. |
| 210 | + * \param name The name of the sequential pass. |
| 211 | + * \param required The list of the passes that the sequential pass is dependent on. |
| 212 | + * \param disabled The disabled passes. |
| 213 | + * |
| 214 | + * \return The created sequential pass. |
| 215 | + */ |
| 216 | +Pass CreateSequentialPass(const tvm::Array<Pass>& passes, |
| 217 | + int opt_level, |
| 218 | + const std::string& name, |
| 219 | + const tvm::Array<tvm::Expr>& required, |
| 220 | + const tvm::Array<tvm::Expr>& disabled); |
| 221 | + |
| 222 | +} // namespace pass |
| 223 | + |
17 | 224 | /*! |
18 | 225 | * \brief Infer the type of an expression. |
19 | 226 | * |
|
0 commit comments