Skip to content

Commit 3fed13c

Browse files
ZihengJiangYuchenJin
authored andcommitted
Reorganize source code. (apache#14)
1 parent 915947c commit 3fed13c

File tree

18 files changed

+438
-242
lines changed

18 files changed

+438
-242
lines changed

include/tvm/relax/builder.h renamed to include/tvm/relax/vm/exec_builder.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,20 @@
1818
*/
1919

2020
/*!
21-
* \file tvm/relax/builder.h
21+
* \file tvm/relax/vm/exec_builder.h
2222
* \brief
2323
*/
24-
#ifndef TVM_RELAX_BUILDER_H_
25-
#define TVM_RELAX_BUILDER_H_
24+
#ifndef TVM_RELAX_EXEC_BUILDER_H_
25+
#define TVM_RELAX_EXEC_BUILDER_H_
2626

2727
#include <tvm/ir/expr.h>
2828
#include <tvm/node/reflection.h>
2929
#include <tvm/node/repr_printer.h>
3030
#include <tvm/runtime/object.h>
3131
#include <tvm/runtime/registry.h>
3232

33-
#include "./vm/bytecode.h"
34-
#include "./vm/executable.h"
33+
#include "./bytecode.h"
34+
#include "./executable.h"
3535

3636
namespace tvm {
3737
namespace relax {
@@ -102,4 +102,4 @@ class ExecBuilder : public ObjectRef {
102102
} // namespace relax
103103
} // namespace tvm
104104

105-
#endif // TVM_RELAX_BUILDER_H_
105+
#endif // TVM_RELAX_EXEC_BUILDER_H_

python/tvm/relax/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from . import op
2424
from . import parser
2525
from . import analysis
26+
from . import transform
2627

2728

2829
# Expr

python/tvm/relax/analysis/analysis.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,3 @@ def post_order_visit(expr, fvisit):
3737
The visitor function to be applied.
3838
"""
3939
return _ffi_api.post_order_visit(expr, fvisit)
40-
41-
def fma_rewrite(expr):
42-
"""Perform fused multiply add rewriting in dataflow blocks.
43-
44-
Parameters
45-
----------
46-
expr : tvm.relay.Expr
47-
The input expression.
48-
"""
49-
return _ffi_api.fma_rewrite(expr)
50-
51-
def explicit_memory_rewrite(expr):
52-
"""Perform explicit memory allocation for call_dps in dataflow blocks.
53-
54-
Parameters
55-
----------
56-
expr : tvm.relay.Expr
57-
The input expression.
58-
"""
59-
return _ffi_api.explicit_memory_rewrite(expr)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=wildcard-import, redefined-builtin
18+
"""Relax IR analysis. """
19+
20+
from .transform import *
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
import tvm._ffi
17+
18+
tvm._ffi._init_api("relax.transform", __name__)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=no-else-return
18+
# pylint: disable=unidiomatic-typecheck
19+
from . import _ffi_api
20+
21+
def fma_rewrite(expr):
22+
"""Perform fused multiply add rewriting in dataflow blocks.
23+
24+
Parameters
25+
----------
26+
expr : tvm.relay.Expr
27+
The input expression.
28+
"""
29+
return _ffi_api.fma_rewrite(expr)
30+
31+
def explicit_memory_rewrite(expr):
32+
"""Perform explicit memory allocation for call_dps in dataflow blocks.
33+
34+
Parameters
35+
----------
36+
expr : tvm.relay.Expr
37+
The input expression.
38+
"""
39+
return _ffi_api.explicit_memory_rewrite(expr)
File renamed without changes.

src/relax/expr_functor.cc renamed to src/relax/ir/expr_functor.cc

Lines changed: 1 addition & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
*/
1919

2020
/*!
21-
* \file src/relay/expr_functor.cc
21+
* \file src/relax/expr_functor.cc
2222
* \brief A wrapper around ExprFunctor which functionally updates the AST.
2323
*
2424
* ExprMutator uses memoization and self return in order to amortize
@@ -29,10 +29,6 @@
2929
#include <tvm/relay/analysis.h>
3030
#include <tvm/relay/pattern_functor.h>
3131
#include <tvm/relax/type.h>
32-
#include <stack>
33-
#include <tvm/tir/op.h>
34-
35-
#include "../relay/transforms/pattern_utils.h"
3632

3733
namespace tvm {
3834
namespace relax {
@@ -415,114 +411,5 @@ Expr DataflowMutator::LookupVar(Var var) {
415411
return irbuilder_->LookupVar(var);
416412
}
417413
}
418-
419-
420-
// ==================
421-
// EwiseFMARewriter
422-
// Example:
423-
// x0 = mul(a, b)
424-
// z0 = add(x0, c)
425-
// -->
426-
// z0 = ewise_fma(a, b, c)
427-
428-
// Example 2:
429-
// Question: do we want to support this?
430-
// x0 = mul(a, add(k, b))
431-
// z0 = add(x0, c)
432-
// -->
433-
// lv0 = add(k, b)
434-
// z0 = ewise_fma(a, lv0, c)
435-
436-
class EwiseFMARewriter : public DataflowMutator {
437-
Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) override {
438-
static const Op& add_op = Op::Get("relax.add");
439-
static const Op& multiply_op = Op::Get("relax.multiply");
440-
static const Op& ewise_fma_op = Op::Get("relax.ewise_fma");
441-
442-
// TODO: shape & dtype check
443-
const CallNode* op1 = binding->value.as<CallNode>();
444-
if (op1 && (op1->op == add_op)) {
445-
Expr value = LookupVar(Downcast<Var>(op1->args[0]));
446-
const CallNode* op2 = value.as<CallNode>();
447-
if (op2 && op2->op == multiply_op) {
448-
Call fma_call = Call(ewise_fma_op, {op2->args[0], op2->args[1], op1->args[1]}, {}, {});
449-
return ir_builder->Emit(binding->var, fma_call);
450-
}
451-
}
452-
return ir_builder->Emit(binding);
453-
}
454-
};
455-
456-
Expr FMARewrite(const Expr& e) {
457-
return EwiseFMARewriter().Mutate(e);
458-
}
459-
460-
TVM_REGISTER_GLOBAL("relax.analysis.fma_rewrite")
461-
.set_body_typed([](Expr expr) {
462-
return FMARewrite(expr);
463-
});
464-
465-
// ==================
466-
// ExplicitMemMutator
467-
// Example:
468-
// y: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x))
469-
// -->
470-
// lv0 = rx.call("relax.builtin.alloc_tensor", [n, m])
471-
// rx.call_packed(op.identity, x, lv0)
472-
473-
class ExplicitMemMutator : public DataflowMutator {
474-
Expr ComputeStorageSize(const Expr& shape, const Type& type) const {
475-
DynTensorType tensor_type = Downcast<DynTensorType>(type);
476-
DataType dtype = DataType(tensor_type->dtype);
477-
// Question: what if the dtype of tensor_type is unknown?
478-
// Symbolic/static shape case
479-
if (auto* shape_expr = shape.as<ShapeExprNode>()) {
480-
PrimExpr num = PrimExpr(dtype.bits()) * PrimExpr(dtype.lanes());
481-
PrimExpr add = num + 7;
482-
PrimExpr ret = 1;
483-
for (PrimExpr dim : shape_expr->values) {
484-
ret = ret * dim;
485-
}
486-
ret = ret * (add / PrimExpr(8));
487-
return ShapeExpr({ret});
488-
}
489-
// Fully dynamic shape case
490-
// will need to dedup with ComputeStorageInRelay when we upstream
491-
Expr prod = relay::Prod(shape, Array<Integer>(nullptr), false, false);
492-
Expr num = relay::MakeConstantScalar(DataType::Int(64), dtype.bits() * dtype.lanes());
493-
Expr add = relay::Add(num, relay::MakeConstantScalar(DataType::Int(64), 7));
494-
Expr div = relay::MakeConstantScalar(DataType::Int(64), 8);
495-
Expr ret = relay::Multiply(prod, relay::Divide(add, div));
496-
return ret;
497-
}
498-
499-
Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) override {
500-
static const Op& call_dps_op = Op::Get("relax.call_dps");
501-
static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
502-
503-
const CallNode* op = binding->value.as<CallNode>();
504-
if(op && op->op == call_dps_op) {
505-
// switch current DataflowBlock to an impure BindingBlock
506-
ir_builder->is_dataflow_ = false;
507-
ShapeExpr output_shape = Downcast<ShapeExpr>(op->args[0]);
508-
Type arg_type = Downcast<Tuple>(op->args[2])->fields[0]->checked_type();
509-
Expr output_size = ComputeStorageSize(output_shape, arg_type);
510-
Var tensor = ir_builder->Emit(Call(alloc_tensor_op, {op->args[0]}));
511-
return ir_builder->Emit(binding->var, Call(op->args[1], {op->args[2], tensor}));
512-
}
513-
return ir_builder->Emit(binding);
514-
}
515-
};
516-
517-
Expr ExplicitMemRewrite(const Expr& e) {
518-
return ExplicitMemMutator().Mutate(e);
519-
}
520-
521-
TVM_REGISTER_GLOBAL("relax.analysis.explicit_memory_rewrite")
522-
.set_body_typed([](Expr expr) {
523-
return ExplicitMemRewrite(expr);
524-
});
525-
526-
527414
} // namespace relax
528415
} // namespace tvm
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)