Skip to content

Commit 84d1eec

Browse files
authored
[RELAY][BYOC] Register pattern tables from external codegens (#5262)
* [RELAY][BYOC] Register pattern tables from external codegens This adds utility functions to support registering and retrieving pattern tables used by MergeComposite for external codegens. Change-Id: I5be165a321440e48b15ff6aff4970e0c67496aaa * Updated DNNL tests to use pattern table mechanism * Removed pattern table standalone test * Change reg to _op
1 parent 6e36da3 commit 84d1eec

File tree

4 files changed

+78
-18
lines changed

4 files changed

+78
-18
lines changed

python/tvm/relay/op/contrib/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,6 @@
1616
# under the License.
1717
# pylint: disable=wildcard-import
1818
"""Contrib modules."""
19+
from .register import get_pattern_table, register_pattern_table
20+
1921
from .dnnl import *

python/tvm/relay/op/contrib/dnnl.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232
- The other way is to implement the function by themselves to
3333
check the attributes of the op and decide if it should be offloaded to DNNL.
3434
"""
35-
from ... import op as reg
35+
from ... import expr as _expr
36+
from ... import op as _op
37+
from .register import register_pattern_table
3638

3739

3840
def _register_external_op_helper(op_name, supported=True):
@@ -49,7 +51,7 @@ def _register_external_op_helper(op_name, supported=True):
4951
f : callable
5052
A function that returns if the operator is supported by DNNL.
5153
"""
52-
@reg.register(op_name, "target.dnnl")
54+
@_op.register(op_name, "target.dnnl")
5355
def _func_wrapper(attrs, args):
5456
return supported
5557

@@ -63,3 +65,23 @@ def _func_wrapper(attrs, args):
6365
_register_external_op_helper("add")
6466
_register_external_op_helper("subtract")
6567
_register_external_op_helper("multiply")
68+
69+
70+
def make_pattern(with_bias=True):
71+
data = _expr.var("data")
72+
weight = _expr.var("weight")
73+
bias = _expr.var("bias")
74+
conv = _op.nn.conv2d(data, weight)
75+
if with_bias:
76+
conv_out = _op.add(conv, bias)
77+
else:
78+
conv_out = conv
79+
return _op.nn.relu(conv_out)
80+
81+
82+
@register_pattern_table("dnnl")
83+
def pattern_table():
84+
conv2d_bias_relu_pat = ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True))
85+
conv2d_relu_pat = ("dnnl.conv2d_relu", make_pattern(with_bias=False))
86+
dnnl_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]
87+
return dnnl_patterns
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Register utilities for external codegen."""
18+
_PATTERN_TABLES = {}
19+
20+
21+
def register_pattern_table(compiler, table=None):
22+
"""Register a pattern table for an external compiler.
23+
24+
Pattern tables are used to create composite functions.
25+
See the MergeComposite pass.
26+
27+
Parameters
28+
----------
29+
compiler : str
30+
The name of compiler
31+
32+
table : function, optional
33+
A function that returns the pattern table
34+
35+
Returns
36+
-------
37+
fregister : function
38+
Register function if value is not specified.
39+
"""
40+
def _register(t):
41+
"""internal register function"""
42+
_PATTERN_TABLES[compiler] = t()
43+
return t
44+
return _register(table) if table is not None else _register
45+
46+
47+
def get_pattern_table(compiler):
48+
"""Get the pattern table associated with a compiler (if it's registered)."""
49+
return _PATTERN_TABLES[compiler] if compiler in _PATTERN_TABLES else None

tests/python/relay/test_pass_partition_graph.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import sys
2020

2121
import numpy as np
22-
import pytest
2322

2423
import tvm
2524
import tvm.relay.testing
@@ -31,6 +30,7 @@
3130
from tvm.relay.backend import compile_engine
3231
from tvm.relay.expr_functor import ExprMutator
3332
from tvm.relay.op.annotation import compiler_begin, compiler_end
33+
from tvm.relay.op.contrib.register import get_pattern_table
3434
from tvm.relay.build_module import bind_params_by_name
3535

3636

@@ -832,21 +832,8 @@ def expected():
832832

833833

834834
def test_dnnl_fuse():
835-
def make_pattern(with_bias=True):
836-
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
837-
weight = relay.var("weight")
838-
bias = relay.var("bias")
839-
conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
840-
channels=8, padding=(1, 1))
841-
if with_bias:
842-
conv_out = relay.add(conv, bias)
843-
else:
844-
conv_out = conv
845-
return relay.nn.relu(conv_out)
846-
847-
conv2d_bias_relu_pat = ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True))
848-
conv2d_relu_pat = ("dnnl.conv2d_relu", make_pattern(with_bias=False))
849-
dnnl_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]
835+
dnnl_patterns = get_pattern_table("dnnl")
836+
conv2d_bias_relu_pat, conv2d_relu_pat = dnnl_patterns
850837

851838
def get_blocks(prefix, data, in_channel, out_channel,
852839
include_bn=True, include_sigmoid=False):

0 commit comments

Comments
 (0)