Skip to content

Commit 118f943

Browse files
tqchendhruvaray
authored andcommitted
[REFACTOR][TE] Inline -> te/schedule/operation_inline.h (apache#5386)
Rationale: inline is a transformation used in te to rewrite its internal expressions. It is not a formal IRModule->IRModule transform pass. Also removed the python test as the test is covered by stage.compute_inline.
1 parent a52ab12 commit 118f943

File tree

6 files changed

+68
-84
lines changed

6 files changed

+68
-84
lines changed

include/tvm/tir/ir_pass.h

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -148,22 +148,6 @@ Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map);
148148
*/
149149
PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map);
150150

151-
/*!
152-
* \brief inline all calls of f in stmt.
153-
*
154-
* \param stmt The statement to apply inline optimization.
155-
* \param f The function reference to be inlined
156-
* \param args The arguments variable of the function.
157-
* \param body The definition body of the function.
158-
* \return The result stmt
159-
*
160-
* \note All the passes in this file uses SSA form and outputs SSA form.
161-
*/
162-
Stmt Inline(Stmt stmt,
163-
FunctionRef f,
164-
Array<Var> args,
165-
PrimExpr body);
166-
167151
/*!
168152
* \brief Verify if there is any argument bound to compact buffer.
169153
*

src/tir/pass/inline.cc renamed to src/te/schedule/operation_inline.cc

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,31 @@
1818
*/
1919

2020
/*!
21-
* \file inline.cc
21+
* \file operation_inline.cc
2222
*/
2323
#include <tvm/tir/expr.h>
2424
#include <tvm/tir/stmt.h>
2525
#include <tvm/tir/ir_pass.h>
2626
#include <tvm/tir/stmt_functor.h>
27+
#include <utility>
28+
#include "operation_inline.h"
2729

2830
namespace tvm {
29-
namespace tir {
31+
namespace te {
3032

3133
// inliner to inline a function
3234
// the result may not be SSA,
3335
// ConvertSSA need to be applied after this pass
34-
class IRInline final : public StmtExprMutator {
36+
class OperationInliner final : public StmtExprMutator {
3537
public:
36-
IRInline(FunctionRef f, Array<Var> args, PrimExpr body)
37-
: f_(f), args_(args), body_(body) {}
38+
OperationInliner(Operation op, Array<Var> args, PrimExpr body)
39+
: operation_(op), args_(args), body_(body) {}
3840

3941
PrimExpr VisitExpr_(const CallNode* op) final {
4042
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
4143
op = expr.as<CallNode>();
4244

43-
if (op->func == f_) {
45+
if (op->func.same_as(operation_)) {
4446
CHECK_EQ(op->value_index, 0);
4547
expr = body_;
4648
CHECK_EQ(args_.size(), op->args.size());
@@ -68,20 +70,20 @@ class IRInline final : public StmtExprMutator {
6870
}
6971

7072
private:
71-
FunctionRef f_;
73+
Operation operation_;
7274
Array<Var> args_;
7375
PrimExpr body_;
7476
};
7577

7678
Stmt Inline(Stmt stmt,
77-
FunctionRef f,
79+
Operation f,
7880
Array<Var> args,
7981
PrimExpr body) {
8082
CHECK_EQ(f->num_outputs(), 1)
8183
<< "can only inline output single value operation";
82-
Stmt ret = IRInline(f, args, body)(std::move(stmt));
84+
Stmt ret = OperationInliner(f, args, body)(std::move(stmt));
8385
if (ret.same_as(stmt)) return ret;
8486
return ConvertSSA(ret);
8587
}
86-
} // namespace tir
88+
} // namespace te
8789
} // namespace tvm

src/te/schedule/operation_inline.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
/*!
20+
* \file operation_inline.h
21+
*/
22+
#ifndef TVM_TE_SCHEDULE_OPERATION_INLINE_H_
23+
#define TVM_TE_SCHEDULE_OPERATION_INLINE_H_
24+
25+
#include <tvm/tir/expr.h>
26+
#include <tvm/tir/stmt.h>
27+
#include <tvm/te/operation.h>
28+
#include <tvm/te/tensor.h>
29+
30+
namespace tvm {
31+
namespace te {
32+
33+
/*!
34+
* \brief inline all calls of f in stmt.
35+
*
36+
* \param stmt The statement to apply inline optimization.
37+
* \param op The op to be inlined.
38+
* \param args The arguments variable of the function.
39+
* \param body The definition body of the function.
40+
* \return The result stmt
41+
*
42+
* \note All the passes in this file uses SSA form and outputs SSA form.
43+
*/
44+
Stmt Inline(Stmt stmt,
45+
Operation op,
46+
Array<Var> args,
47+
PrimExpr body);
48+
49+
} // namespace te
50+
} // namespace tvm
51+
#endif // TVM_TE_SCHEDULE_OPERATION_INLINE_H_

src/te/schedule/schedule_dataflow_rewrite.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
#include <tvm/tir/ir_pass.h>
2727
#include <unordered_set>
2828
#include "message_passing.h"
29+
#include "operation_inline.h"
30+
2931
#include "../../tir/pass/ir_util.h"
3032
#include "../../arith/compute_expr.h"
3133

@@ -583,7 +585,7 @@ void InjectInline(ScheduleNode* sch) {
583585
<< "The Reduce inputs of ComputeOp should "
584586
<< "have the same attribute except value_index";
585587
}
586-
PrimExpr new_value = tir::Inline(tir::EvaluateNode::make(new_body[j][0]),
588+
PrimExpr new_value = Inline(tir::EvaluateNode::make(new_body[j][0]),
587589
stage->op, args, body).as<tir::EvaluateNode>()->value;
588590
if (!new_value.same_as(new_body[j][0])) {
589591
changed[j] = true;
@@ -599,7 +601,7 @@ void InjectInline(ScheduleNode* sch) {
599601
}
600602
} else {
601603
for (size_t k = 0; k < new_body[j].size(); ++k) {
602-
PrimExpr new_value = tir::Inline(tir::EvaluateNode::make(new_body[j][k]),
604+
PrimExpr new_value = Inline(tir::EvaluateNode::make(new_body[j][k]),
603605
stage->op, args, body).as<tir::EvaluateNode>()->value;
604606
if (!new_value.same_as(new_body[j][k])) {
605607
new_body[j].Set(k, new_value);
@@ -611,7 +613,7 @@ void InjectInline(ScheduleNode* sch) {
611613
if (!new_hybrid_body[j].defined()) {
612614
new_hybrid_body[j] = hybrid->body;
613615
}
614-
Stmt new_stmt = tir::Inline(new_hybrid_body[j], stage->op, args, body);
616+
Stmt new_stmt = Inline(new_hybrid_body[j], stage->op, args, body);
615617
if (!new_stmt.same_as(new_hybrid_body[j])) {
616618
new_hybrid_body[j] = new_stmt;
617619
hybrid_changed[j] = true;

src/tir/pass/ffi_api.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
9797

9898
REGISTER_PASS(ConvertSSA);
9999
REGISTER_PASS(VerifySSA);
100-
REGISTER_PASS(Inline);
101100
REGISTER_PASS(IRTransform);
102101
REGISTER_PASS(VerifyGPUCode);
103102
REGISTER_PASS(DecorateDeviceScope);

tests/python/unittest/test_tir_pass_inline.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

0 commit comments

Comments
 (0)