Skip to content

Commit

Permalink
[Frontend][Onnx] Simplify onnx input since name accesses are not reli…
Browse files Browse the repository at this point in the history
…able. (apache#8867)

* Simplify onnx input since name accesses are no longer supported.

* move Celu importer.
  • Loading branch information
Josh Fromm authored and Andrew Zhao Luo committed Sep 1, 2021
1 parent a5cb1a9 commit d5c699c
Showing 1 changed file with 22 additions and 60 deletions.
82 changes: 22 additions & 60 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,54 +64,16 @@
}


class onnx_input:
"""Dual purpose list or dictionary access object."""

def __init__(self):
self.input_keys = []
self.input_dict = {}
class onnx_input(list):
"""A helper extension to list that returns None for out of bound indices."""

def __getitem__(self, item):
if isinstance(item, int):
if item > (len(self.input_keys) - 1):
return None
return self.input_dict[self.input_keys[item]]
if isinstance(item, str):
if item not in self.input_keys:
return None
return self.input_dict[item]
if isinstance(item, slice):
keys = self.input_keys[item]
return [self.input_dict[key] for key in keys]

raise ValueError("Only integer, string, and slice accesses allowed.")

def __setitem__(self, item, value):
indices = list(range(item.stop)[item])
return [self[i] for i in indices]
if isinstance(item, int):
self.input_dict[self.input_keys[item]] = value
elif isinstance(item, str):
self.input_keys.append(item)
self.input_dict[item] = value
else:
raise ValueError("Only integer and string indexed writes allowed.")

def keys(self):
return self.input_keys

def __len__(self):
return len(self.input_keys)

def __iter__(self):
self.n = 0
return self

def __next__(self):
if self.n < len(self.input_keys):
output = self.input_dict[self.input_keys[self.n]]
self.n += 1
return output

raise StopIteration
return list(self)[item] if item < len(self) else None
raise TypeError("list indices must be integers or slices, not %s" % type(item).__name__)


def get_numpy(tensor_proto):
Expand Down Expand Up @@ -2673,6 +2635,19 @@ def _impl_v10(cls, inputs, attr, params):
return isinf


class Celu(OnnxOpConverter):
"""Operator convereter for celu"""

@classmethod
def _impl_v12(cls, inputs, attr, params):
x = inputs[0]
dtype = infer_type(x).checked_type.dtype
alpha = _op.const(attr.get("alpha", 1.0), dtype)
zero = _op.const(0, dtype)
one = _op.const(1, dtype)
return _op.maximum(zero, x) + _op.minimum(zero, alpha * (_op.exp(x / alpha) - one))


class MaxRoiPool(OnnxOpConverter):
"""Operator converter for MaxRoiPool."""

Expand Down Expand Up @@ -3881,13 +3856,13 @@ def from_onnx(self, graph, opset, get_output_expr=False):
for node in graph.node:
op_name = node.op_type
attr = self._parse_attr(node.attribute)
# Create and populate onnx input object.
# Create and populate input list.
inputs = onnx_input()
for i in node.input:
if i != "":
inputs[i] = self._nodes[self._renames.get(i, i)]
inputs.append(self._nodes[self._renames.get(i, i)])
else:
inputs[i] = None
inputs.append(None)
i_name = self._parse_value_proto(node)
node_output = self._fix_outputs(op_name, node.output)
attr["tvm_custom"] = {}
Expand Down Expand Up @@ -4040,19 +4015,6 @@ def _fix_outputs(self, op_name, outputs):
return outputs


class Celu(OnnxOpConverter):
"""Operator convereter for celu"""

@classmethod
def _impl_v12(cls, inputs, attr, params):
x = inputs[0]
dtype = infer_type(x).checked_type.dtype
alpha = _op.const(attr.get("alpha", 1.0), dtype)
zero = _op.const(0, dtype)
one = _op.const(1, dtype)
return _op.maximum(zero, x) + _op.minimum(zero, alpha * (_op.exp(x / alpha) - one))


def from_onnx(
model, shape=None, dtype="float32", opset=None, freeze_params=False, convert_config=None
):
Expand Down

0 comments on commit d5c699c

Please sign in to comment.