forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_onnx_graph.py
64 lines (54 loc) · 1.87 KB
/
_onnx_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.node_def_pb2 import NodeDef
from tensorboard.compat.proto.versions_pb2 import VersionDef
from tensorboard.compat.proto.attr_value_pb2 import AttrValue
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
def load_onnx_graph(fname):
    """Load an ONNX model and convert its graph via ``parse``.

    Args:
        fname: Passed unchanged to ``onnx.load``.

    Returns:
        Whatever ``parse`` returns for the loaded model's ``graph`` attribute.
    """
    # Imported inside the function, so ``onnx`` is only required when this
    # function is actually called.
    import onnx

    return parse(onnx.load(fname).graph)
def parse(graph):
    """Convert an ONNX-style graph into a TensorBoard ``GraphDef``.

    Every entry of ``graph.input`` and ``graph.output`` becomes a ``NodeDef``
    with op ``"Variable"`` carrying ``dtype`` and ``shape`` attributes. Every
    entry of ``graph.node`` becomes a ``NodeDef`` named after its first output,
    with its attributes flattened into a single ``parameters`` byte string.

    Args:
        graph: Object exposing ``input``, ``output`` and ``node`` sequences
            (presumably an ``onnx.GraphProto`` -- confirm against callers).

    Returns:
        A ``GraphDef`` holding the input/output nodes followed by the operator
        nodes, with ``versions.producer`` set to 22.
    """
    import itertools
    import logging

    # Node names are reported through logging at DEBUG level so that library
    # callers decide whether they are shown.
    log = logging.getLogger(__name__)
    nodes = []

    # Graph inputs and outputs are both rendered as "Variable" nodes.
    for value in itertools.chain(graph.input, graph.output):
        log.debug("%s", value.name)
        tensor_type = value.type.tensor_type
        shapeproto = TensorShapeProto(
            dim=[
                TensorShapeProto.Dim(size=d.dim_value)
                for d in tensor_type.shape.dim
            ]
        )
        nodes.append(
            NodeDef(
                name=value.name.encode(encoding="utf_8"),
                op="Variable",
                input=[],
                attr={
                    # NOTE(review): elem_type is passed through unchanged as
                    # the AttrValue type -- confirm both enums share the same
                    # numbering.
                    "dtype": AttrValue(type=tensor_type.elem_type),
                    "shape": AttrValue(shape=shapeproto),
                },
            )
        )

    for node in graph.node:
        # Each attribute is rendered as its populated field values joined by
        # " = "; attributes are then joined by ", ".
        attr = ", ".join(
            " = ".join(str(field[1]) for field in attribute.ListFields())
            for attribute in node.attribute
        ).encode(encoding="utf_8")
        # The first output doubles as the node name, so every operator node
        # is expected to have at least one output.
        output_name = node.output[0]
        log.debug("%s", output_name)
        nodes.append(
            NodeDef(
                name=output_name.encode(encoding="utf_8"),
                op=node.op_type,
                input=node.input,
                attr={"parameters": AttrValue(s=attr)},
            )
        )

    return GraphDef(node=nodes, versions=VersionDef(producer=22))