Skip to content

Commit 82f2cc2

Browse files
committed
send correct nodes, scopes, input and outputs, with ops
necessary to display the full debugger graph.
1 parent 4e5d24c commit 82f2cc2

File tree

1 file changed

+100
-88
lines changed

1 file changed

+100
-88
lines changed

deepkit/pytorch.py

+100-88
Original file line numberDiff line numberDiff line change
@@ -19,64 +19,78 @@ def extract_attributes(module):
1919
return res
2020

2121

22+
scope_name_prog = re.compile(r'^([a-zA-Z0-9_\-]+)/')
2223
short_name_prog = re.compile(r'\[([a-zA-Z0-9]+)\]')
2324
is_variable = re.compile(r'/([a-zA-Z_0-9]+(?:\.[0-9]+)?)$')
2425

2526

2627
def get_layer_id(name: str):
2728
"""
28-
Takes a name like 'ResNet/Conv2d[conv1]/1504' and converts it back to
29-
the name from named_modules method, e.g. conv1
29+
Takes a name like 'ResNet/Conv2d[conv1]/1504' and converts it to a shorter version
30+
3031
Examples
3132
1. 'ResNet/Sequential[layer1]/BasicBlock[1]/Conv2d[conv2]/1658'
32-
-> layer1.1.conv2
33+
-> layer1.1.conv2/1658
3334
2. 'ResNet/Sequential[layer2]/BasicBlock[0]/BatchNorm2d[bn1]/1714'
34-
-> layer2.0.bn1
35+
-> layer2.0.bn1/1714
3536
3. 'ResNet/Sequential[layer1]/BasicBlock[0]/input.4'
3637
-> layer1.0/input.4
38+
4. 'input/input.1'
39+
-> input/input.1 (returned unchanged: the name has no bracketed groups)
40+
5. 'output/output.1'
41+
-> output/output.1 (returned unchanged: the name has no bracketed groups)
3742
"""
3843
res = short_name_prog.findall(name)
3944
var = is_variable.search(name)
4045
if not res:
41-
return name, False
46+
return name
4247
if var:
43-
return '.'.join(res) + '/' + var.group(1), True
44-
return '.'.join(res), False
48+
return '.'.join(res) + '/' + var.group(1)
49+
return '.'.join(res)
4550

4651

47-
def get_short_layer_id(name: str):
52+
def get_scope_id(name: str):
4853
"""
49-
Takes a name like 'ResNet/Conv2d[conv1]/1504' and converts it back to
50-
the name from named_modules method, e.g. conv1
54+
Takes a name like 'ResNet/Conv2d[conv1]/1504' and converts it to
55+
its scope variant, which can later be matched against the names returned by the `named_modules` method.
5156
Examples
5257
1. 'ResNet/Sequential[layer1]/BasicBlock[1]/Conv2d[conv2]/1658'
53-
-> layer1.1.conv2
58+
-> ResNet.layer1.1.conv2
5459
2. 'ResNet/Sequential[layer2]/BasicBlock[0]/BatchNorm2d[bn1]/1714'
55-
-> layer2.0.bn1
60+
-> ResNet.layer2.0.bn1
61+
2. 'ResNet/Sequential[layer2]/BasicBlock[0]/BatchNorm2d[bn1]/input.2'
62+
-> ResNet.layer2.0.bn1
5663
3. 'ResNet/Sequential[layer1]/BasicBlock[0]/input.4'
57-
-> layer1.0
64+
-> ResNet.layer1.0
65+
3. 'ResNet/x.1'
66+
-> ResNet
5867
"""
5968
res = short_name_prog.findall(name)
6069
if not res:
61-
return name
62-
return '.'.join(res)
70+
# no bracketed groups means it's something like ResNet/x.2, which we normalize to ResNet
71+
return name.split('/')[0]
72+
73+
scope = scope_name_prog.findall(name)
74+
75+
return scope[0] + '.' + ('.'.join(res))
6376

6477

6578
def get_pytorch_graph(net, x):
6679
names_from_id = dict()
67-
names_is_variable = dict()
6880
nodes_from_id = dict()
6981
names_from_debug = dict()
70-
names_short_from_debug = dict()
71-
names_to_short = dict()
82+
scopes_from_debug = dict()
83+
names_to_scope = dict()
84+
scope_nodes = dict()
85+
# names_to_scope = dict()
7286

7387
container_names = dict()
7488
known_modules_map = dict()
7589
known_modules_name_map = dict()
7690

7791
torch_graph, torch_nodes = build_graph(net, x)
7892

79-
for name, module in net.named_modules():
93+
for name, module in net.named_modules(prefix=type(net).__name__):
8094
known_modules_map[module] = name
8195
known_modules_name_map[name] = module
8296

@@ -95,23 +109,28 @@ def get_parent(name, go_up=1) -> str:
95109
if node.kind == 'prim::Constant': continue
96110
if node.kind == 'prim::GetAttr': continue
97111
if node.kind == 'prim::ListConstruct': continue
112+
if node.kind == 'aten::t': continue
98113

99-
layer_id, is_variable = get_layer_id(node.debugName)
100-
if is_variable:
101-
names_is_variable[layer_id] = node
114+
layer_id = get_layer_id(node.debugName)
115+
scope_id = get_scope_id(node.debugName)
102116
names_from_id[layer_id] = node.debugName
103117
nodes_from_id[layer_id] = node
104118
names_from_debug[node.debugName] = layer_id
105-
names_short_from_debug[node.debugName] = get_short_layer_id(node.debugName)
106-
names_to_short[layer_id] = names_short_from_debug[node.debugName]
119+
scopes_from_debug[node.debugName] = scope_id
120+
names_to_scope[layer_id] = scopes_from_debug[node.debugName]
121+
if scope_id not in scope_nodes:
122+
scope_nodes[scope_id] = [layer_id]
123+
else:
124+
scope_nodes[scope_id].append(layer_id)
125+
# names_to_scope[layer_id] = get_scope_name(node.debugName)
107126

108127
edges = dict()
109128
edges_internal = dict()
110129

111130
for node in torch_nodes.values():
112131
if node.debugName not in names_from_debug: continue
113132
layer_id = names_from_debug[node.debugName]
114-
short_layer_id = names_short_from_debug[node.debugName]
133+
short_layer_id = scopes_from_debug[node.debugName]
115134

116135
print(node.debugName, '=>', layer_id, short_layer_id, node.kind)
117136
for parent in get_parent_names(layer_id):
@@ -123,7 +142,7 @@ def get_parent(name, go_up=1) -> str:
123142

124143
if input in names_from_debug and layer_id != names_from_debug[input] \
125144
and short_layer_id != names_from_debug[input]:
126-
print(' outgoing', names_from_debug[input], names_short_from_debug[input], input)
145+
print(' outgoing', names_from_debug[input], scopes_from_debug[input], input)
127146
# this node points out of itself, so create an edge
128147
edge_to = names_from_debug[input]
129148

@@ -134,12 +153,12 @@ def get_parent(name, go_up=1) -> str:
134153

135154
def resolve_edges_to_known_layer(from_layer: str, inputs: Set[str]) -> List[str]:
136155
new_inputs = set()
137-
short_name = names_to_short[from_layer] if from_layer in names_to_short else None
156+
short_name = names_to_scope[from_layer] if from_layer in names_to_scope else None
138157
parent_name = get_parent(short_name) if short_name else None
139158

140159
# parent_layer = get_parent(from_layer)
141160
for input in inputs:
142-
input_short_name = names_to_short[input] if input in names_to_short else None
161+
input_short_name = names_to_scope[input] if input in names_to_scope else None
143162

144163
# we skip connections where even the second parent is not the same as, or a child of, from_layer.
145164
# we could make this configurable.
@@ -161,53 +180,10 @@ def resolve_edges_to_known_layer(from_layer: str, inputs: Set[str]) -> List[str]
161180

162181
return list(new_inputs)
163182

164-
# edges_resolved = dict()
165-
# shapes = dict()
166-
# short_name_to_id = dict()
167-
168-
# # we resolve the edges only from known layers
169-
# for [name, inputs] in edges.items():
170-
# # first name=layer2.0/input.1 => layer2.0
171-
# short_name = name
172-
# if name in names_to_short:
173-
# short_name = names_to_short[name]
174-
#
175-
# if short_name not in known_modules_name_map and name not in names_is_variable: continue
176-
# # if short_name in edges_resolved: continue
177-
#
178-
# shapes[short_name] = nodes_from_id[name].tensor_size
179-
# short_name_to_id[short_name] = name
180-
# edges_resolved[short_name] = resolve_edges_to_known_layer(name, inputs)
181-
#
182183
deepkit_nodes = []
183184

184-
# for [name, inputs] in edges_resolved.items():
185-
# module = known_modules_name_map[name]
186-
# node = {
187-
# 'id': name,
188-
# 'label': name,
189-
# 'type': type(module).__name__,
190-
# 'input': inputs,
191-
# 'attributes': extract_attributes(module),
192-
# 'internalInputs': list(edges[short_name_to_id[name]]),
193-
# 'shape': shapes[name]
194-
# }
195-
# deepkit_nodes.append(node)
196-
#
197-
# allowed_torch_kinds = {
198-
# 'add', 'div', 'mul', 'add_', 'div_', 'mul_'
199-
# 'relu', 'sigmoid',
200-
# 'threshold'
201-
# }
202-
203-
# for torch_node in torch_graph.outputs():
204-
# print(torch_node.debugName, torch_node.kind)
205-
# pass
206-
207-
# we start from all outputs and go backwards, all visited input nodes are
208-
# collected. We are interested in only nodes that are in the actual graph
209-
# of outputs.
210185
nodes_names_to_display = set()
186+
scopes = dict()
211187

212188
def collect_inputs(inputs):
213189
for input in inputs:
@@ -220,51 +196,87 @@ def collect_inputs(inputs):
220196
nodes_names_to_display.add(name)
221197
collect_inputs(inputs)
222198

199+
graph_inputs = []
200+
graph_outputs = []
201+
223202
for name in nodes_names_to_display:
224203
inputs = edges[name] if name in edges else []
225-
# for [name, inputs] in edges.items():
204+
# for [name, inputs] in edges.items():
226205
torch_node = nodes_from_id[name]
227-
short_name = names_to_short[name]
206+
scope_name = names_to_scope[name]
228207

229208
filterer_inputs = []
209+
if name.startswith('input/input'):
210+
graph_inputs.append(name)
211+
if name.startswith('output/output'):
212+
graph_outputs.append(name)
230213

231214
for input in inputs:
232-
second_parent = get_parent(names_to_short[input], 2)
233-
if second_parent and not short_name.startswith(second_parent):
215+
second_parent = get_parent(names_to_scope[input], 2)
216+
if second_parent and not scope_name.startswith(second_parent):
234217
continue
235218
if input.startswith('input/input'):
236219
filterer_inputs.append(input)
237220
continue
238221
if input in edges: filterer_inputs.append(input)
239222

240223
attributes = {}
241-
node_type = torch_node.kind
242-
# that only works when we detected that short_name is unique
243-
# means that there is only one name that points to short_name. or should we always pick the latest?
244-
# dunno
245-
# if short_name in known_modules_name_map:
246-
# extract_attributes(known_modules_name_map[short_name])
247-
# node_type = type(known_modules_name_map[short_name]).__name__
224+
node_type = str(torch_node.kind)
225+
node_label = name
226+
op = ''
227+
228+
scope_id = scope_name
229+
230+
if len(scope_nodes[scope_name]) == 1 and scope_name in known_modules_name_map:
231+
# this node is at the same time a scope, since it only has one
232+
# node.
233+
node_label = scope_name
234+
module = known_modules_name_map[scope_name]
235+
node_type = type(module).__name__
236+
scope_id = get_parent(scope_name)
237+
attributes = extract_attributes(module)
238+
else:
239+
if node_type.startswith('aten::'):
240+
node_type = 'op'
241+
op = node_type.replace('aten::', '').replace('_', '')
248242

249243
attributes['torch.debugName'] = torch_node.debugName
250244
attributes['torch.kind'] = torch_node.kind
251245
attributes['torch.inputs'] = torch_node.inputs
252246

253247
node = {
254248
'id': name,
255-
'label': name,
249+
'label': node_label,
256250
'type': node_type,
251+
'op': op,
257252
'input': filterer_inputs,
258253
'attributes': attributes,
259-
# 'internalInputs': list(edges[short_name_to_id[name]]),
260-
'shapes': torch_node.tensor_size,
261-
# 'shape': shapes[name]
254+
'recordable': False,
255+
'scope': scope_id.replace('.', '/'),
256+
'shape': torch_node.tensor_size,
262257
}
263258
deepkit_nodes.append(node)
264259

260+
for name, module in known_modules_name_map.items():
261+
262+
# skip modules that are already added as nodes
263+
if name in scope_nodes and len(scope_nodes[name]) == 1: continue
264+
265+
scope_id = name.replace('.', '/')
266+
scope = {
267+
'id': scope_id,
268+
'label': scope_id,
269+
'type': type(module).__name__,
270+
'recordable': True,
271+
'attributes': extract_attributes(module)
272+
}
273+
scopes[scope_id] = scope
265274

266275
graph = {
267-
'nodes': deepkit_nodes
276+
'nodes': deepkit_nodes,
277+
'scopes': scopes,
278+
'inputs': graph_inputs,
279+
'outputs': graph_outputs,
268280
}
269281

270282
return graph

0 commit comments

Comments
 (0)