@@ -19,64 +19,78 @@ def extract_attributes(module):
19
19
return res
20
20
21
21
22
# Leading scope segment of a debug name,
# e.g. 'ResNet' in 'ResNet/Conv2d[conv1]/1504'.
scope_name_prog = re.compile(r'^([a-zA-Z0-9_\-]+)/')

# Every bracketed member name in a debug name,
# e.g. 'layer1', '1', 'conv2' in
# 'ResNet/Sequential[layer1]/BasicBlock[1]/Conv2d[conv2]/1658'.
short_name_prog = re.compile(r'\[([a-zA-Z0-9]+)\]')

# Trailing segment after the last slash, optionally with a numeric
# '.N' suffix, e.g. '1504' or 'input.4'.
is_variable = re.compile(r'/([a-zA-Z_0-9]+(?:\.[0-9]+)?)$')
24
25
25
26
26
27
def get_layer_id(name: str) -> str:
    """
    Takes a name like 'ResNet/Conv2d[conv1]/1504' and converts it to a
    shorter version.

    The bracketed member names are joined with '.', and the trailing
    segment after the last slash (if any) is appended after a '/'.
    Names that contain no bracketed member are returned unchanged.

    Examples
        1. 'ResNet/Sequential[layer1]/BasicBlock[1]/Conv2d[conv2]/1658'
            -> layer1.1.conv2/1658
        2. 'ResNet/Sequential[layer2]/BasicBlock[0]/BatchNorm2d[bn1]/1714'
            -> layer2.0.bn1/1714
        3. 'ResNet/Sequential[layer1]/BasicBlock[0]/input.4'
            -> layer1.0/input.4
        4. 'input/input.1'
            -> input/input.1
        5. 'output/output.1'
            -> output/output.1

    :param name: debug name of a node in the traced graph
    :return: the shortened layer id
    """
    members = short_name_prog.findall(name)
    if not members:
        # Nothing in brackets (e.g. graph inputs/outputs): keep the name as is.
        return name

    layer_path = '.'.join(members)
    trailing = is_variable.search(name)
    if trailing:
        return layer_path + '/' + trailing.group(1)
    return layer_path
45
50
46
51
47
- def get_short_layer_id (name : str ):
52
def get_scope_id(name: str) -> str:
    """
    Takes a name like 'ResNet/Conv2d[conv1]/1504' and converts it to its
    scope variant: the leading scope segment followed by the bracketed
    member names, joined with '.'. The trailing value segment is dropped.

    The result is intended to line up with the names yielded by
    `named_modules` when the network's class name is passed as prefix.

    Examples
        1. 'ResNet/Sequential[layer1]/BasicBlock[1]/Conv2d[conv2]/1658'
            -> ResNet.layer1.1.conv2
        2. 'ResNet/Sequential[layer2]/BasicBlock[0]/BatchNorm2d[bn1]/1714'
            -> ResNet.layer2.0.bn1
        3. 'ResNet/Sequential[layer2]/BasicBlock[0]/BatchNorm2d[bn1]/input.2'
            -> ResNet.layer2.0.bn1
        4. 'ResNet/Sequential[layer1]/BasicBlock[0]/input.4'
            -> ResNet.layer1.0
        5. 'ResNet/x.1'
            -> ResNet

    :param name: debug name of a node in the traced graph
    :return: the dotted scope id
    """
    members = short_name_prog.findall(name)
    if not members:
        # No bracketed groups means it is something like 'ResNet/x.2',
        # which we normalize to its first path segment ('ResNet').
        return name.split('/')[0]

    member_path = '.'.join(members)

    root = scope_name_prog.match(name)
    if root is None:
        # The name has no leading '<scope>/' segment (e.g. 'Conv2d[conv1]/1').
        # Fall back to the member path instead of raising IndexError.
        return member_path

    return root.group(1) + '.' + member_path
63
76
64
77
65
78
def get_pytorch_graph (net , x ):
66
79
names_from_id = dict ()
67
- names_is_variable = dict ()
68
80
nodes_from_id = dict ()
69
81
names_from_debug = dict ()
70
- names_short_from_debug = dict ()
71
- names_to_short = dict ()
82
+ scopes_from_debug = dict ()
83
+ names_to_scope = dict ()
84
+ scope_nodes = dict ()
85
+ # names_to_scope = dict()
72
86
73
87
container_names = dict ()
74
88
known_modules_map = dict ()
75
89
known_modules_name_map = dict ()
76
90
77
91
torch_graph , torch_nodes = build_graph (net , x )
78
92
79
- for name , module in net .named_modules ():
93
+ for name , module in net .named_modules (prefix = type ( net ). __name__ ):
80
94
known_modules_map [module ] = name
81
95
known_modules_name_map [name ] = module
82
96
@@ -95,23 +109,28 @@ def get_parent(name, go_up=1) -> str:
95
109
if node .kind == 'prim::Constant' : continue
96
110
if node .kind == 'prim::GetAttr' : continue
97
111
if node .kind == 'prim::ListConstruct' : continue
112
+ if node .kind == 'aten::t' : continue
98
113
99
- layer_id , is_variable = get_layer_id (node .debugName )
100
- if is_variable :
101
- names_is_variable [layer_id ] = node
114
+ layer_id = get_layer_id (node .debugName )
115
+ scope_id = get_scope_id (node .debugName )
102
116
names_from_id [layer_id ] = node .debugName
103
117
nodes_from_id [layer_id ] = node
104
118
names_from_debug [node .debugName ] = layer_id
105
- names_short_from_debug [node .debugName ] = get_short_layer_id (node .debugName )
106
- names_to_short [layer_id ] = names_short_from_debug [node .debugName ]
119
+ scopes_from_debug [node .debugName ] = scope_id
120
+ names_to_scope [layer_id ] = scopes_from_debug [node .debugName ]
121
+ if scope_id not in scope_nodes :
122
+ scope_nodes [scope_id ] = [layer_id ]
123
+ else :
124
+ scope_nodes [scope_id ].append (layer_id )
125
+ # names_to_scope[layer_id] = get_scope_name(node.debugName)
107
126
108
127
edges = dict ()
109
128
edges_internal = dict ()
110
129
111
130
for node in torch_nodes .values ():
112
131
if node .debugName not in names_from_debug : continue
113
132
layer_id = names_from_debug [node .debugName ]
114
- short_layer_id = names_short_from_debug [node .debugName ]
133
+ short_layer_id = scopes_from_debug [node .debugName ]
115
134
116
135
print (node .debugName , '=>' , layer_id , short_layer_id , node .kind )
117
136
for parent in get_parent_names (layer_id ):
@@ -123,7 +142,7 @@ def get_parent(name, go_up=1) -> str:
123
142
124
143
if input in names_from_debug and layer_id != names_from_debug [input ] \
125
144
and short_layer_id != names_from_debug [input ]:
126
- print (' outgoing' , names_from_debug [input ], names_short_from_debug [input ], input )
145
+ print (' outgoing' , names_from_debug [input ], scopes_from_debug [input ], input )
127
146
# this node points out of itself, so create an edge
128
147
edge_to = names_from_debug [input ]
129
148
@@ -134,12 +153,12 @@ def get_parent(name, go_up=1) -> str:
134
153
135
154
def resolve_edges_to_known_layer (from_layer : str , inputs : Set [str ]) -> List [str ]:
136
155
new_inputs = set ()
137
- short_name = names_to_short [from_layer ] if from_layer in names_to_short else None
156
+ short_name = names_to_scope [from_layer ] if from_layer in names_to_scope else None
138
157
parent_name = get_parent (short_name ) if short_name else None
139
158
140
159
# parent_layer = get_parent(from_layer)
141
160
for input in inputs :
142
- input_short_name = names_to_short [input ] if input in names_to_short else None
161
+ input_short_name = names_to_scope [input ] if input in names_to_scope else None
143
162
144
163
# we skip connection where even the 2. parent is not the same or a child of from_layer.
145
164
# we could make this configurable.
@@ -161,53 +180,10 @@ def resolve_edges_to_known_layer(from_layer: str, inputs: Set[str]) -> List[str]
161
180
162
181
return list (new_inputs )
163
182
164
- # edges_resolved = dict()
165
- # shapes = dict()
166
- # short_name_to_id = dict()
167
-
168
- # # we resolve the edges only from known layers
169
- # for [name, inputs] in edges.items():
170
- # # first name=layer2.0/input.1 => layer2.0
171
- # short_name = name
172
- # if name in names_to_short:
173
- # short_name = names_to_short[name]
174
- #
175
- # if short_name not in known_modules_name_map and name not in names_is_variable: continue
176
- # # if short_name in edges_resolved: continue
177
- #
178
- # shapes[short_name] = nodes_from_id[name].tensor_size
179
- # short_name_to_id[short_name] = name
180
- # edges_resolved[short_name] = resolve_edges_to_known_layer(name, inputs)
181
- #
182
183
deepkit_nodes = []
183
184
184
- # for [name, inputs] in edges_resolved.items():
185
- # module = known_modules_name_map[name]
186
- # node = {
187
- # 'id': name,
188
- # 'label': name,
189
- # 'type': type(module).__name__,
190
- # 'input': inputs,
191
- # 'attributes': extract_attributes(module),
192
- # 'internalInputs': list(edges[short_name_to_id[name]]),
193
- # 'shape': shapes[name]
194
- # }
195
- # deepkit_nodes.append(node)
196
- #
197
- # allowed_torch_kinds = {
198
- # 'add', 'div', 'mul', 'add_', 'div_', 'mul_'
199
- # 'relu', 'sigmoid',
200
- # 'threshold'
201
- # }
202
-
203
- # for torch_node in torch_graph.outputs():
204
- # print(torch_node.debugName, torch_node.kind)
205
- # pass
206
-
207
- # we start from all outputs and go backwards, all visited input nodes are
208
- # collected. We are interested in only nodes that are in the actual graph
209
- # of outputs.
210
185
nodes_names_to_display = set ()
186
+ scopes = dict ()
211
187
212
188
def collect_inputs (inputs ):
213
189
for input in inputs :
@@ -220,51 +196,87 @@ def collect_inputs(inputs):
220
196
nodes_names_to_display .add (name )
221
197
collect_inputs (inputs )
222
198
199
+ graph_inputs = []
200
+ graph_outputs = []
201
+
223
202
for name in nodes_names_to_display :
224
203
inputs = edges [name ] if name in edges else []
225
- # for [name, inputs] in edges.items():
204
+ # for [name, inputs] in edges.items():
226
205
torch_node = nodes_from_id [name ]
227
- short_name = names_to_short [name ]
206
+ scope_name = names_to_scope [name ]
228
207
229
208
filterer_inputs = []
209
+ if name .startswith ('input/input' ):
210
+ graph_inputs .append (name )
211
+ if name .startswith ('output/output' ):
212
+ graph_outputs .append (name )
230
213
231
214
for input in inputs :
232
- second_parent = get_parent (names_to_short [input ], 2 )
233
- if second_parent and not short_name .startswith (second_parent ):
215
+ second_parent = get_parent (names_to_scope [input ], 2 )
216
+ if second_parent and not scope_name .startswith (second_parent ):
234
217
continue
235
218
if input .startswith ('input/input' ):
236
219
filterer_inputs .append (input )
237
220
continue
238
221
if input in edges : filterer_inputs .append (input )
239
222
240
223
attributes = {}
241
- node_type = torch_node .kind
242
- # that only works when we detected that short_name is unique
243
- # means that there is only one name that points to short_name. or should we always pick the latest?
244
- # dunno
245
- # if short_name in known_modules_name_map:
246
- # extract_attributes(known_modules_name_map[short_name])
247
- # node_type = type(known_modules_name_map[short_name]).__name__
224
+ node_type = str (torch_node .kind )
225
+ node_label = name
226
+ op = ''
227
+
228
+ scope_id = scope_name
229
+
230
+ if len (scope_nodes [scope_name ]) == 1 and scope_name in known_modules_name_map :
231
+ # this node is at the same time a scope, since it only has one
232
+ # node.
233
+ node_label = scope_name
234
+ module = known_modules_name_map [scope_name ]
235
+ node_type = type (module ).__name__
236
+ scope_id = get_parent (scope_name )
237
+ attributes = extract_attributes (module )
238
+ else :
239
+ if node_type .startswith ('aten::' ):
240
+ node_type = 'op'
241
+ op = node_type .replace ('aten::' , '' ).replace ('_' , '' )
248
242
249
243
attributes ['torch.debugName' ] = torch_node .debugName
250
244
attributes ['torch.kind' ] = torch_node .kind
251
245
attributes ['torch.inputs' ] = torch_node .inputs
252
246
253
247
node = {
254
248
'id' : name ,
255
- 'label' : name ,
249
+ 'label' : node_label ,
256
250
'type' : node_type ,
251
+ 'op' : op ,
257
252
'input' : filterer_inputs ,
258
253
'attributes' : attributes ,
259
- # 'internalInputs ': list(edges[short_name_to_id[name]]) ,
260
- 'shapes ' : torch_node . tensor_size ,
261
- # 'shape': shapes[name]
254
+ 'recordable ' : False ,
255
+ 'scope ' : scope_id . replace ( '.' , '/' ) ,
256
+ 'shape' : torch_node . tensor_size ,
262
257
}
263
258
deepkit_nodes .append (node )
264
259
260
+ for name , module in known_modules_name_map .items ():
261
+
262
+ # skip modules that are already added as nodes
263
+ if name in scope_nodes and len (scope_nodes [name ]) == 1 : continue
264
+
265
+ scope_id = name .replace ('.' , '/' )
266
+ scope = {
267
+ 'id' : scope_id ,
268
+ 'label' : scope_id ,
269
+ 'type' : type (module ).__name__ ,
270
+ 'recordable' : True ,
271
+ 'attributes' : extract_attributes (module )
272
+ }
273
+ scopes [scope_id ] = scope
265
274
266
275
graph = {
267
- 'nodes' : deepkit_nodes
276
+ 'nodes' : deepkit_nodes ,
277
+ 'scopes' : scopes ,
278
+ 'inputs' : graph_inputs ,
279
+ 'outputs' : graph_outputs ,
268
280
}
269
281
270
282
return graph
0 commit comments