diff --git a/python/tvm/relay/analysis.py b/python/tvm/relay/analysis.py
index fc4a03753422..198e0a3bf9eb 100644
--- a/python/tvm/relay/analysis.py
+++ b/python/tvm/relay/analysis.py
@@ -407,3 +407,25 @@ def structural_hash(value):
         msg = ("found value of type {0} expected" +
                "relay.Expr or relay.Type").format(type(value))
         raise TypeError(msg)
+
+
+def extract_fused_functions(mod):
+    """Extract the fused primitive functions from an IRModule.
+
+    Internally this runs SimplifyInference, FuseOps(3), and the
+    ExtractFusedFunctions visitor, in that order.
+
+    Parameters
+    ----------
+    mod : tvm.relay.IRModule
+
+    Returns
+    -------
+    ret : Dict[int, tvm.relay.expr.Function]
+        A mapping from the hash of each fused primitive function to that function
+    """
+    ret_mod = _analysis.ExtractFusedFunctions()(mod)
+    ret = {}
+    for hash_, func in ret_mod.functions.items():
+        ret[hash_] = func
+    return ret
diff --git a/src/relay/analysis/extract_fused_functions.cc b/src/relay/analysis/extract_fused_functions.cc
new file mode 100644
index 000000000000..3667d8a47826
--- /dev/null
+++ b/src/relay/analysis/extract_fused_functions.cc
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file extract_fused_functions.cc
+ * \brief Apply fusion and extract fused primitive functions from an IRModule.
+ */
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+class FusedFunctionExtractorWrapper : private ExprVisitor {
+ public:
+  explicit FusedFunctionExtractorWrapper(const IRModule& mod) : mod_(mod) {}
+
+  IRModule Extract() {
+    VisitExpr(this->mod_->Lookup("main"));
+
+    auto functions = Map<GlobalVar, BaseFunc>();
+    for (auto pair : this->functions) {
+      functions.Set(GlobalVar(pair.first), pair.second);
+    }
+
+    this->mod_->functions = functions;
+    return this->mod_;
+  }
+
+ private:
+  const IRModule mod_;
+  // This is not simply Map<GlobalVar, Function> because GlobalVar doesn't
+  // have the desired equality property
+  Map<std::string, Function> functions;
+
+  void VisitExpr_(const FunctionNode* n) final {
+    if (n->HasNonzeroAttr(attr::kPrimitive)) {
+      // Add the fused function to the map, keyed by its hash string
+      Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs);
+      size_t hash_ = StructuralHash()(func);
+      this->functions.Set(std::to_string(hash_), func);
+    }
+
+    ExprVisitor::VisitExpr_(n);
+  }
+};
+
+namespace transform {
+
+Pass ExtractFusedFunctions() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule m, PassContext pc) { return FusedFunctionExtractorWrapper(m).Extract(); };
+  auto fused_function_extractor_pass = CreateModulePass(pass_func, 1, "ExtractFusedFunctions", {});
+
+  return Sequential({SimplifyInference(), FuseOps(3), fused_function_extractor_pass},
+                    "ExtractFusedFunctions");
+}
+
+TVM_REGISTER_GLOBAL("relay._analysis.ExtractFusedFunctions").set_body_typed(ExtractFusedFunctions);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_analysis_extract_fused_functions.py b/tests/python/relay/test_analysis_extract_fused_functions.py
new file mode 100644
index 000000000000..1a70ef174233
--- /dev/null
+++ b/tests/python/relay/test_analysis_extract_fused_functions.py
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test function extraction"""
+import tvm
+from tvm import relay
+from tvm.relay.testing.resnet import get_workload
+
+
+def get_conv_net():
+    """This gets the net for a case described in fuse_ops.cc:
+
+            conv2d
+            /  |  \
+           /   |   \
+         op    op   op
+          \    |    /
+           \   |   /
+         elemwise add
+              |
+    """
+    dshape = (1, 1, 5, 1)
+    x = relay.var("x", shape=dshape)
+    y = relay.nn.conv2d(x, relay.var("w1"),
+                        kernel_size=(3, 3),
+                        padding=(1, 1),
+                        channels=1)
+
+    x1 = relay.nn.conv2d(y, relay.var("w2"),
+                         kernel_size=(3, 3),
+                         padding=(1, 1),
+                         channels=1)
+    x2 = relay.nn.conv2d(y, relay.var("w3"),
+                         kernel_size=(3, 3),
+                         padding=(1, 1),
+                         channels=1)
+    x3 = relay.nn.conv2d(y, relay.var("w4"),
+                         kernel_size=(3, 3),
+                         padding=(1, 1),
+                         channels=1)
+
+    z = relay.add(x1, x2)
+    z = relay.add(x3, z)
+
+    return tvm.IRModule.from_expr(z)
+
+
+def get_conv2d():
+    x = relay.var("x", shape=(1, 56, 56, 64))
+    weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
+    y = relay.nn.conv2d(x, weight1,
+                        channels=32,
+                        kernel_size=(3, 3),
+                        padding=(1, 1),
+                        data_layout='NHWC',
+                        kernel_layout='HWIO')
+    return tvm.IRModule.from_expr(y)
+
+
+def test_extract_identity():
+    mod = get_conv2d()
+    items = relay.analysis.extract_fused_functions(mod)
+    assert len(items) == 1
+
+    mod["main"] = mod["main"].with_attr(
+        "Primitive", tvm.tir.IntImm("int32", 1))
+    relay.analysis.assert_graph_equal(list(items.values())[0], mod["main"])
+
+
+def test_extract_conv_net():
+    mod = get_conv_net()
+    items = relay.analysis.extract_fused_functions(mod)
+    functions = list(items.values())
+    assert len(functions) == 2
+    x = functions[0]
+    y = functions[1]
+
+    def is_conv(func):
+        conv2d = relay.op.op.get("nn.conv2d")
+        call_node = func.body
+        return call_node.op == conv2d
+
+    def is_conv_add(func):
+        add = relay.op.op.get("add")
+        call_node = func.body
+        maybe_conv_module = tvm.IRModule.from_expr(call_node.args[0])
+        return call_node.op == add and is_conv(maybe_conv_module["main"])
+
+    # The traversal order of the fused functions isn't guaranteed, so accept either order
+    assert (is_conv(x) and is_conv_add(y)) or (is_conv_add(x) and is_conv(y))
+
+
+def test_extract_resnet():
+    mod, _params = get_workload()
+    items = relay.analysis.extract_fused_functions(mod)
+    assert len(items) == 34
+
+
+if __name__ == '__main__':
+    test_extract_identity()
+    test_extract_conv_net()
+    test_extract_resnet()
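For reference, a minimal usage sketch of the new Python entry point (not part of the diff; it mirrors test_extract_resnet above and assumes a TVM build that includes this patch):

    import tvm
    from tvm import relay
    from tvm.relay.testing.resnet import get_workload

    # Build the ResNet test workload and extract its fused primitive functions.
    mod, _params = get_workload()
    fused = relay.analysis.extract_fused_functions(mod)

    # Each entry maps the hash assigned in extract_fused_functions.cc to one
    # fused primitive function produced by FuseOps(3).
    print(len(fused))
    for key, func in fused.items():
        print(key)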