Skip to content

Commit d50fad2

Browse files
committed
[RELAY][BYOC] Register pattern tables from external codegens
This adds utility functions to support registering and retrieving pattern tables used by MergeComposite for external codegens. Change-Id: I5be165a321440e48b15ff6aff4970e0c67496aaa
1 parent 41b8fd1 commit d50fad2

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed

python/tvm/relay/op/contrib/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,6 @@
1616
# under the License.
1717
# pylint: disable=wildcard-import
1818
"""Contrib modules."""
19+
from .register import get_pattern_table, register_pattern_table
20+
1921
from .dnnl import *
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Register utilities for external codegen."""
18+
pattern_tables = {}
19+
20+
21+
def register_pattern_table(compiler, table=None):
22+
"""Register a pattern table for an external compiler.
23+
24+
Pattern tables are used to create composite functions.
25+
See the MergeComposite pass.
26+
27+
Parameters
28+
----------
29+
compiler : str
30+
The name of compiler
31+
32+
table : function, optional
33+
A function that returns the pattern table
34+
35+
Returns
36+
-------
37+
fregister : function
38+
Register function if value is not specified.
39+
"""
40+
def _register(t):
41+
"""internal register function"""
42+
pattern_tables[compiler] = t()
43+
return t
44+
return _register(table) if table is not None else _register
45+
46+
47+
def get_pattern_table(compiler):
48+
"""Get the pattern table associated with a compiler (if it's registered)."""
49+
return pattern_tables[compiler] if compiler in pattern_tables else None
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Unit test for pattern table registry (BYOC)."""
18+
from tvm.relay.op.contrib import get_pattern_table, register_pattern_table
19+
from tvm import relay
20+
21+
22+
@register_pattern_table("test_pattern_table")
23+
def pattern_table():
24+
def _make_add_relu_pattern():
25+
x = relay.var('x')
26+
y = relay.var('y')
27+
add_node = relay.add(x, y)
28+
r = relay.nn.relu(add_node)
29+
return r
30+
31+
def _check_add_relu_pattern():
32+
return True
33+
34+
return [
35+
("test_pattern_table.add_relu", _make_add_relu_pattern(), _check_add_relu_pattern)
36+
]
37+
38+
39+
def test_retrieve_pattern_table():
40+
table = get_pattern_table("test_pattern_table")
41+
assert table[0][0] == "test_pattern_table.add_relu"
42+
43+
44+
if __name__ == "__main__":
45+
test_retrieve_pattern_table()

0 commit comments

Comments
 (0)