|
20 | 20 | /*! |
21 | 21 | * \file tvm/tir/stmt_functor.h |
22 | 22 | * |
23 | | - * \brief Functors for tir stmts. |
| 23 | + * \brief Functors for tir stmts |
| 24 | + * utility functions to call common functors. |
24 | 25 | */ |
25 | 26 | #ifndef TVM_TIR_STMT_FUNCTOR_H_ |
26 | 27 | #define TVM_TIR_STMT_FUNCTOR_H_ |
27 | 28 |
|
28 | 29 | #include <tvm/node/functor.h> |
| 30 | +#include <tvm/node/container.h> |
29 | 31 | #include <tvm/tir/expr.h> |
30 | 32 | #include <tvm/tir/stmt.h> |
31 | 33 | #include <tvm/tir/expr_functor.h> |
32 | 34 |
|
33 | 35 | #include <utility> |
| 36 | +#include <unordered_map> |
34 | 37 |
|
35 | 38 | namespace tvm { |
36 | 39 | namespace tir { |
@@ -318,33 +321,86 @@ class StmtExprMutator : |
318 | 321 | }; |
319 | 322 |
|
320 | 323 | /*! |
321 | | - * \brief recursively visit the ir in post DFS order node, and transform it |
| 324 | + * \brief recursively visit the ir nodes in post DFS order, and transform it |
322 | 325 | * |
323 | | - * \param node The ir to be transformed. |
| 326 | + * \param stmt The ir to be transformed. |
324 | 327 | * \param preorder The function called in before recursive mutation |
325 | 328 | * If preorder returns None, then the transform will proceed to recursive call. |
326 | 329 | * If preorder returns a not None Stmt/Expr, the transformer will simply return it and |
327 | 330 | * won't do further recursion. |
328 | 331 | * \param postorder The function called after recursive mutation. |
329 | 332 | * The recursive mutation result is passed to postorder for further mutation. |
330 | 333 | * \param only_enable List of runtime::String. |
331 | | - * If it is empty, all IRNode will call preorder/postorder |
332 | | - * If it is not empty, preorder/postorder will only be called |
| 334 | + * If it is null, all IRNode will call preorder/postorder |
| 335 | + * If it is not null, preorder/postorder will only be called |
333 | 336 | * when the IRNode's type key is in the list. |
334 | 337 | */ |
335 | | -TVM_DLL Stmt IRTransform(Stmt node, |
| 338 | +TVM_DLL Stmt IRTransform(Stmt stmt, |
336 | 339 | const runtime::PackedFunc& preorder, |
337 | 340 | const runtime::PackedFunc& postorder, |
338 | | - const Array<runtime::String>& only_enable = {}); |
| 341 | + Optional<Array<String>> only_enable = NullOpt); |
339 | 342 |
|
340 | 343 | /*! |
341 | | - * \brief recursively visit the ir in post DFS order node, apply fvisit |
| 344 | + * \brief Recursively visit the ir in post DFS order node, apply fvisit |
342 | 345 | * Each node is guaranteed to be visited only once. |
343 | 346 | * \param node The ir to be visited. |
344 | 347 | * \param fvisit The visitor function to be applied. |
345 | 348 | */ |
346 | 349 | TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit); |
347 | 350 |
|
| 351 | +/*! |
| 352 | + * \brief Substitute the var specified by vmap. |
| 353 | + * \param stmt The source statement to be substituted |
| 354 | + * \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr. |
| 355 | + * \return The converted form. |
| 356 | + */ |
| 357 | +TVM_DLL Stmt Substitute(Stmt stmt, |
| 358 | + std::function<Optional<PrimExpr>(const Var& var)> vmap); |
| 359 | + |
| 360 | +/*! |
| 361 | + * \brief Substitute the var specified by vmap. |
| 362 | + * \param expr The source statement to be substituted |
| 363 | + * \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr. |
| 364 | + * \return The result. |
| 365 | + */ |
| 366 | +TVM_DLL PrimExpr Substitute(PrimExpr expr, |
| 367 | + std::function<Optional<PrimExpr>(const Var& var)> vmap); |
| 368 | + |
| 369 | +/*! |
| 370 | + * \brief Sugar for substitute via a given map. |
| 371 | + * \param input The input to be updated. |
| 372 | + * \param value_map The map of new values. |
| 373 | + * \return The result. |
| 374 | + * \tparam T the input type, can be PrimExpr or Stmt. |
| 375 | + */ |
| 376 | +template<typename T> |
| 377 | +inline T Substitute(T input, const Map<Var, PrimExpr>& value_map) { |
| 378 | + auto vmap = [&](const Var& var) -> Optional<PrimExpr> { |
| 379 | + auto it = value_map.find(var); |
| 380 | + if (it != value_map.end()) return (*it).second; |
| 381 | + return Optional<PrimExpr>(nullptr); |
| 382 | + }; |
| 383 | + return Substitute(std::move(input), vmap); |
| 384 | +} |
| 385 | + |
| 386 | +/*! |
| 387 | + * \brief Sugar for substitute via a given map. |
| 388 | + * \param input The input to be updated. |
| 389 | + * \param value_map The map of new values. |
| 390 | + * \return The result. |
| 391 | + * \tparam T the input type, can be PrimExpr or Stmt. |
| 392 | + */ |
| 393 | +template<typename T> |
| 394 | +inline T Substitute(T input, |
| 395 | + const std::unordered_map<const VarNode*, PrimExpr>& value_map) { |
| 396 | + auto vmap = [&](const Var& var) -> Optional<PrimExpr> { |
| 397 | + auto it = value_map.find(var.get()); |
| 398 | + if (it != value_map.end()) return (*it).second; |
| 399 | + return Optional<PrimExpr>(nullptr); |
| 400 | + }; |
| 401 | + return Substitute(std::move(input), vmap); |
| 402 | +} |
| 403 | + |
348 | 404 | } // namespace tir |
349 | 405 | } // namespace tvm |
350 | 406 |
|
|
0 commit comments