Skip to content

Commit 75b00e8

Browse files
author
Ashok Sudarsanam
committed
TVM changes to support introduction and typing of new custom operations.
Merged in SIM-6711 (pull request apache#36) Approved-by: Mikael Sevenier Approved-by: Joey Chou
1 parent ea3337c commit 75b00e8

File tree

5 files changed

+822
-0
lines changed

5 files changed

+822
-0
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name, too-many-lines
18+
"""Custom operation configuration interface."""
19+
from typing import List, Dict, Callable
20+
from dataclasses import dataclass
21+
from tvm.ir import TensorType
22+
import json
23+
24+
25+
@dataclass()
26+
class CustomOpConfigInfo():
27+
"""
28+
Dataclass that contains configuration information for a custom operation.
29+
This dataclass contains the following fields:
30+
1. code: a string that contains the corresponding C code implementation.
31+
2. func_name: the name of the function in the C code that implements the
32+
custom operation.
33+
3. datatype: a string that specifies the underlying tensor datatype that
34+
is assumed by the C code implementation. Currently supported values are
35+
“int8”, “float”, and “double”.
36+
4. type_func: a Python function that returns the type of the custom opera-
37+
tion, based on the types of the input tensor(s) and relevant attributes.
38+
5. compiler_flags: a string that contains custom operation-specific flags
39+
for the target compiler.
40+
"""
41+
42+
code: str
43+
func_name: str
44+
datatype: str
45+
type_func: Callable[..., TensorType]
46+
compiler_flags: str
47+
48+
49+
class CustomOperationConfig:
50+
"""
51+
Singleton class that contains configuration information for each custom
52+
operation that exists in an ML model. This information is used during
53+
the construction and typing of custom operations.
54+
"""
55+
56+
__instance = None
57+
config_dict: Dict[str, CustomOpConfigInfo] = dict()
58+
59+
@staticmethod
60+
def get_instance():
61+
if CustomOperationConfig.__instance == None:
62+
CustomOperationConfig()
63+
return CustomOperationConfig.__instance
64+
65+
def __init__(self):
66+
if CustomOperationConfig.__instance != None:
67+
raise Exception("CustomOperationConfig class is a singleton.")
68+
else:
69+
CustomOperationConfig.__instance = self
70+
71+
def add_config_for_custom_op(self, custom_op_name: str,
72+
custom_op_config_info: CustomOpConfigInfo):
73+
self.config_dict[custom_op_name] = custom_op_config_info
74+
75+
def get_config_for_custom_op(self, custom_op_name: str) -> CustomOpConfigInfo:
76+
return self.config_dict[custom_op_name]
77+
78+
def get_custom_ops(self) -> List[str]:
79+
return list(self.config_dict.keys())

python/tvm/relay/op/nn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@
1919
from __future__ import absolute_import as _abs
2020
from .nn import *
2121
from . import _nn
22+
from . import custom_operation
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name, too-many-lines
18+
"""Neural network operations for custom ops."""
19+
from tvm.relay import expr
20+
from tvm.ir import Attrs
21+
22+
from . import _make
23+
from typing import List, Tuple
24+
from tvm.custom_operation_config import (
25+
CustomOpConfigInfo, CustomOperationConfig
26+
)
27+
import tvm._ffi
28+
import json
29+
30+
31+
MAX_TENSOR_INPUTS = 5
32+
33+
34+
def custom_op(inputs, input_types, name, code, func_name, datatype, compiler_flags):
35+
"""
36+
Create a Relay IR node for the custom operation. Specifically, a
37+
CallNode around an operator nn.custom_op_{i} is returned, where {i}
38+
denotes the total number of input tensor operands in the custom
39+
operation. The number of input tensor operands cannot exceed 5.
40+
41+
The inputs to a custom operation may also include constant values
42+
that represent attributes of the operation. Each attribute must
43+
be a string, an integer, a floating point value, a list of integers,
44+
or a list of floating-point values.
45+
46+
In the custom operation specification in the ML network, the tensor
47+
operands must appear first, followed by the constant attributes.
48+
"""
49+
50+
# Partition the inputs into tensor operands and constant attributes.
51+
tensor_inputs = []
52+
constant_attrs = []
53+
for input in inputs:
54+
if isinstance(input, tvm.relay.expr.ExprWithOp):
55+
if len(constant_attrs) == 0:
56+
tensor_inputs.append(input)
57+
else:
58+
raise AssertionError("Tensor operands must precede constant attributes.")
59+
elif is_valid_attribute(input):
60+
constant_attrs.append(input)
61+
else:
62+
raise AssertionError(f"Input {input} is neither a tensor nor a constant attribute.")
63+
64+
# Store all attributes of the custom operation in a dictionary.
65+
# The following string attributes are common to all custom
66+
# operations:
67+
# 1. Custom operation name.
68+
# 2. C code implementation.
69+
# 3. C code function name.
70+
# 4. C code datatype.
71+
# 5. Operation-specific compiler flags.
72+
#
73+
# A custom operation may also have constant attributes that are
74+
# specific to it.
75+
custom_op_attrs = {
76+
"name": name,
77+
"code": code,
78+
"func_name": func_name,
79+
"datatype": datatype,
80+
"compiler_flags": compiler_flags,
81+
"constant_attrs": constant_attrs
82+
}
83+
84+
custom_op_attr_str = json.dumps(custom_op_attrs)
85+
86+
if len(tensor_inputs) == 1:
87+
return _make.custom_op_1(*tensor_inputs, custom_op_attr_str)
88+
elif len(tensor_inputs) == 2:
89+
return _make.custom_op_2(*tensor_inputs, custom_op_attr_str)
90+
elif len(tensor_inputs) == 3:
91+
return _make.custom_op_3(*tensor_inputs, custom_op_attr_str)
92+
elif len(tensor_inputs) == 4:
93+
return _make.custom_op_4(*tensor_inputs, custom_op_attr_str)
94+
elif len(tensor_inputs) == 5:
95+
return _make.custom_op_5(*tensor_inputs, custom_op_attr_str)
96+
else:
97+
msg = "Unsupported number of input tensor arguments (%d)." % (len(tensor_inputs))
98+
raise AssertionError(msg)
99+
100+
101+
def is_valid_attribute(input):
102+
"""
103+
Returns True if the input operand is a string, an integer, a floating
104+
point number, a list of integers, or a list of floating-point numbers.
105+
"""
106+
107+
input_type = type(input)
108+
if input_type == str or input_type == int or input_type == float:
109+
return True
110+
111+
if input_type == list and type(input[0]) in [int, float]:
112+
for elem in input:
113+
if type(elem) != type(input[0]):
114+
return False
115+
return True
116+
117+
return False
118+
119+
120+
@tvm._ffi.register_func("relay.op.nn.custom_op_type_func")
121+
def custom_op_type_func(types, num_inputs, attrs):
122+
"""
123+
Return the type of the specified custom operation, based on the
124+
input types and constant attribute values. This function is
125+
invoked by the registered add_type_rel() function in the C++ code.
126+
"""
127+
128+
custom_op_attrs = json.loads(attrs.custom_op_attrs)
129+
custom_op_name = custom_op_attrs["name"]
130+
constant_attrs = custom_op_attrs["constant_attrs"]
131+
132+
# Get the typing function associated with the custom operation.
133+
custom_op_config = CustomOperationConfig.get_instance()
134+
config_info = custom_op_config.get_config_for_custom_op(custom_op_name)
135+
type_func = config_info.type_func
136+
137+
msg = f"Unsupported number of input tensor arguments {num_inputs} (max = {MAX_TENSOR_INPUTS})"
138+
assert 0 < num_inputs <= MAX_TENSOR_INPUTS, msg
139+
140+
input_args = tuple([types[i] for i in range(num_inputs)])
141+
return type_func(*input_args, *constant_attrs)
142+
143+
144+
@tvm._ffi.register_object("relay.attrs.CustomOpAttrs")
145+
class CustomOpAttrs(Attrs):
146+
"""Attributes for nn custom operations"""
147+
148+
149+
def make_custom_op(name, code, func_name, datatype, compiler_flags):
150+
def custom_op_func(inputs, input_types):
151+
return custom_op(inputs, input_types, name,
152+
code, func_name, datatype,
153+
compiler_flags)
154+
155+
return custom_op_func
156+
157+
158+
def get_convert_map_from_custom_op_config():
159+
"""
160+
Construct a mapping from custom operation name to Relay IR
161+
creation function. This mapping will get inserted into
162+
the front-end's operator conversion map.
163+
"""
164+
165+
convert_map = {}
166+
custom_op_config = CustomOperationConfig.get_instance()
167+
custom_op_names = custom_op_config.get_custom_ops()
168+
169+
for custom_op_name in custom_op_names:
170+
config_info = custom_op_config.get_config_for_custom_op(custom_op_name)
171+
code = config_info.code
172+
func_name = config_info.func_name
173+
datatype = config_info.datatype
174+
compiler_flags = config_info.compiler_flags
175+
176+
convert_map[custom_op_name] = make_custom_op(custom_op_name, code,
177+
func_name, datatype,
178+
compiler_flags)
179+
180+
return convert_map

0 commit comments

Comments
 (0)