[Relay][Pass] Add submodule extraction pass (#4960)
* rebased
* fix lint
1 parent 2e913f0, commit 327891c
Showing 3 changed files with 219 additions and 0 deletions.
82 changes: 82 additions & 0 deletions
extract_fused_functions.cc
```cpp
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file extract_fused_functions.cc
 * \brief Apply fusion and extract fused primitive functions from an IRModule
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

namespace tvm {
namespace relay {

class FusedFunctionExtractorWrapper : private ExprVisitor {
 public:
  explicit FusedFunctionExtractorWrapper(const IRModule& mod) : mod_(mod) {}

  IRModule Extract() {
    VisitExpr(this->mod_->Lookup("main"));

    auto functions = Map<GlobalVar, BaseFunc>();
    for (auto pair : this->functions) {
      functions.Set(GlobalVar(pair.first), pair.second);
    }

    this->mod_->functions = functions;
    return this->mod_;
  }

 private:
  const IRModule mod_;
  // This is not simply Map<GlobalVar, Function> because GlobalVar doesn't
  // have the desired equals property
  Map<std::string, Function> functions;

  void VisitExpr_(const FunctionNode* n) final {
    if (n->HasNonzeroAttr(attr::kPrimitive)) {
      // Add function to functions, keyed by function hash string
      Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs);
      size_t hash_ = StructuralHash()(func);
      this->functions.Set(std::to_string(hash_), func);
    }

    ExprVisitor::VisitExpr_(n);
  }
};

namespace transform {

Pass ExtractFusedFunctions() {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
      [=](IRModule m, PassContext pc) { return FusedFunctionExtractorWrapper(m).Extract(); };
  auto fused_function_extractor_pass = CreateModulePass(pass_func, 1, "ExtractFusedFunctions", {});

  return Sequential({SimplifyInference(), FuseOps(3), fused_function_extractor_pass},
                    "ExtractFusedFunctions");
}

TVM_REGISTER_GLOBAL("relay._analysis.ExtractFusedFunctions").set_body_typed(ExtractFusedFunctions);

}  // namespace transform

}  // namespace relay
}  // namespace tvm
```
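The commit touches three files, but only two appear in this excerpt; the remaining additions presumably expose a Python-side helper, `relay.analysis.extract_fused_functions`, which the tests below call. Below is a minimal sketch of what such a wrapper might look like, assuming it reaches the global registered above through `tvm.get_global_func` and that the resulting module's GlobalVar names carry the structural-hash strings chosen in the C++ pass; the function name, docstring, and module placement are illustrative, not taken from the hidden file.

```python
# Hypothetical sketch of the Python wrapper the tests rely on (the actual
# wrapper is in the third changed file, which is not shown in this excerpt).
import tvm


def extract_fused_functions(mod):
    """Run SimplifyInference + FuseOps(3) and return the fused primitive functions.

    Parameters
    ----------
    mod : tvm.IRModule
        The module to extract fused functions from.

    Returns
    -------
    Dict[str, tvm.relay.Function]
        Fused functions keyed by the structural-hash string used in the C++ pass.
    """
    # The C++ side registers "relay._analysis.ExtractFusedFunctions", which
    # builds the Sequential pass defined above.
    make_pass = tvm.get_global_func("relay._analysis.ExtractFusedFunctions")
    extracted = make_pass()(mod)  # apply the pass; result holds only fused functions
    # Each GlobalVar name in the result module is the hash string set by the pass.
    return {gvar.name_hint: func for gvar, func in extracted.functions.items()}
```

Keying by structural hash also explains the ResNet test below: structurally identical fused functions collapse to a single entry, leaving 34 unique fused functions rather than one per call site.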
115 changes: 115 additions & 0 deletions
tests/python/relay/test_analysis_extract_fused_functions.py
```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test function extraction"""
import tvm
from tvm import relay
from tvm.relay.testing.resnet import get_workload


def get_conv_net():
    """This gets the net for a case described in fuse_ops.cc:

            conv2d
            /  |  \
           /   |   \
         op    op   op
          \    |    /
           \   |   /
          elemwise add
               |
    """
    dshape = (1, 1, 5, 1)
    x = relay.var("x", shape=dshape)
    y = relay.nn.conv2d(x, relay.var("w1"),
                        kernel_size=(3, 3),
                        padding=(1, 1),
                        channels=1)

    x1 = relay.nn.conv2d(y, relay.var("w2"),
                         kernel_size=(3, 3),
                         padding=(1, 1),
                         channels=1)
    x2 = relay.nn.conv2d(y, relay.var("w3"),
                         kernel_size=(3, 3),
                         padding=(1, 1),
                         channels=1)
    x3 = relay.nn.conv2d(y, relay.var("w4"),
                         kernel_size=(3, 3),
                         padding=(1, 1),
                         channels=1)

    z = relay.add(x1, x2)
    z = relay.add(x3, z)

    return tvm.IRModule.from_expr(z)


def get_conv2d():
    x = relay.var("x", shape=(1, 56, 56, 64))
    weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
    y = relay.nn.conv2d(x, weight1,
                        channels=32,
                        kernel_size=(3, 3),
                        padding=(1, 1),
                        data_layout='NHWC',
                        kernel_layout='HWIO')
    return tvm.IRModule.from_expr(y)


def test_extract_identity():
    mod = get_conv2d()
    items = relay.analysis.extract_fused_functions(mod)
    assert len(items) == 1

    mod["main"] = mod["main"].with_attr(
        "Primitive", tvm.tir.IntImm("int32", 1))
    relay.analysis.assert_graph_equal(list(items.values())[0], mod["main"])


def test_extract_conv_net():
    mod = get_conv_net()
    items = relay.analysis.extract_fused_functions(mod)
    functions = list(items.values())
    assert len(functions) == 2
    x = functions[0]
    y = functions[1]

    def is_conv(func):
        conv2d = relay.op.op.get("nn.conv2d")
        call_node = func.body
        return call_node.op == conv2d

    def is_conv_add(func):
        add = relay.op.op.get("add")
        call_node = func.body
        maybe_conv_module = tvm.IRModule.from_expr(call_node.args[0])
        return call_node.op == add and is_conv(maybe_conv_module["main"])

    # Function traversal order isn't obvious, so checking both orders is more consistent
    assert (is_conv(x) and is_conv_add(y)) or (is_conv_add(x) and is_conv(y))


def test_extract_resnet():
    mod, _params = get_workload()
    items = relay.analysis.extract_fused_functions(mod)
    assert len(items) == 34


if __name__ == '__main__':
    test_extract_identity()
    test_extract_conv_net()
    test_extract_resnet()
```
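For readers who want to see what the pass actually produces, a small inspection script can dump every extracted fused function. This is not part of the commit; it is a sketch that reuses only the APIs exercised by the tests above (the ResNet workload and `relay.analysis.extract_fused_functions`), and the printing format is arbitrary.

```python
# Hypothetical inspection script (not part of the commit): print each fused
# primitive function that the pass extracts from the ResNet test workload.
import tvm
from tvm import relay
from tvm.relay.testing.resnet import get_workload

mod, _params = get_workload()
items = relay.analysis.extract_fused_functions(mod)
print("extracted %d fused functions" % len(items))
for key, func in items.items():
    print("---", key)   # dictionary key (the structural-hash string from the C++ pass)
    print(func)         # the fused primitive relay.Function, pretty-printed
```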