Skip to content

Commit db5c6fb

Browse files
committed
fixed race condition in reconnect and added PoC of Pytorch debugger support
1 parent e464a72 commit db5c6fb

File tree

7 files changed

+746
-11
lines changed

7 files changed

+746
-11
lines changed

deepkit/client.py

+11
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,14 @@ async def _connect_job(self, host: str, port: int, id: str, token: str):
361361
return
362362

363363
self.loop.create_task(self.handle_messages(self.connection))
364+
365+
# on reconnect this sends all patches and queued messages as well before authentication
366+
# is done, so we temp save those, do the authentication, and then let the client
367+
# sync the queued stuff again
368+
old_queue = self.queue.copy()
369+
old_patches = self.patches.copy()
370+
self.queue = []
371+
self.patches = {}
364372
self.loop.create_task(self.send_messages(self.connection))
365373

366374
res = await self._message({
@@ -375,6 +383,9 @@ async def _connect_job(self, host: str, port: int, id: str, token: str):
375383
if not res['result'] or res['result'] is not True:
376384
raise Exception('Job token invalid')
377385

386+
self.queue = old_queue
387+
self.patches = old_patches
388+
378389
self.connecting.set_result(True)
379390
if self.connections > 0:
380391
print("Deepkit: Reconnected.")

deepkit/deepkit_keras.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,15 @@
22
from __future__ import division
33

44
import base64
5+
import io
56
import math
67
import os
78
import sys
89
import time
910
from struct import pack
10-
from uuid import uuid4
1111

1212
import PIL.Image
1313
import numpy as np
14-
import six
1514

1615
if 'keras' in sys.modules:
1716
import keras
@@ -289,7 +288,7 @@ def make_image_from_dense(self, neurons):
289288
return img
290289

291290
def pil_image_to_jpeg(self, image):
292-
buffer = six.BytesIO()
291+
buffer = io.BytesIO()
293292

294293
image.save(buffer, format="JPEG", optimize=True, quality=70)
295294
return buffer.getvalue()

deepkit/pytorch.py

+201
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
import re
2+
from typing import Set, List
3+
4+
from deepkit.pytorch_graph import build_graph
5+
6+
blacklist_attributes = {'weight', 'dump_patches'}
7+
8+
9+
def extract_attributes(module):
10+
res = {}
11+
for attr in dir(module):
12+
if attr in blacklist_attributes: continue
13+
if attr.startswith('_'): continue
14+
val = getattr(module, attr)
15+
if not isinstance(val, (str, bool, int, float, list, tuple)):
16+
continue
17+
res[attr] = val
18+
19+
return res
20+
21+
22+
short_name_prog = re.compile(r'\[([a-zA-Z0-9]+)\]')
23+
is_variable = re.compile(r'/([a-zA-Z_]+\.[0-9]+)')
24+
25+
26+
def get_layer_id(name: str):
27+
"""
28+
Takes a name like 'ResNet/Conv2d[conv1]/1504' and converts it back to
29+
the name from named_modules method, e.g. conv1
30+
Examples
31+
1. 'ResNet/Sequential[layer1]/BasicBlock[1]/Conv2d[conv2]/1658'
32+
-> layer1.1.conv2
33+
2. 'ResNet/Sequential[layer2]/BasicBlock[0]/BatchNorm2d[bn1]/1714'
34+
-> layer2.0.bn1
35+
3. 'ResNet/Sequential[layer1]/BasicBlock[0]/input.4'
36+
-> layer1.0/input.4
37+
"""
38+
res = short_name_prog.findall(name)
39+
var = is_variable.search(name)
40+
if not res:
41+
return name
42+
if var:
43+
return '.'.join(res) + '/' + var.group(1)
44+
return '.'.join(res)
45+
46+
47+
def get_short_layer_id(name: str):
48+
"""
49+
Takes a name like 'ResNet/Conv2d[conv1]/1504' and converts it back to
50+
the name from named_modules method, e.g. conv1
51+
Examples
52+
1. 'ResNet/Sequential[layer1]/BasicBlock[1]/Conv2d[conv2]/1658'
53+
-> layer1.1.conv2
54+
2. 'ResNet/Sequential[layer2]/BasicBlock[0]/BatchNorm2d[bn1]/1714'
55+
-> layer2.0.bn1
56+
3. 'ResNet/Sequential[layer1]/BasicBlock[0]/input.4'
57+
-> layer1.0
58+
"""
59+
res = short_name_prog.findall(name)
60+
if not res:
61+
return name
62+
return '.'.join(res)
63+
64+
65+
def get_pytorch_graph(net, x):
66+
names_from_id = dict()
67+
nodes_from_id = dict()
68+
names_from_debug = dict()
69+
names_short_from_debug = dict()
70+
names_to_short = dict()
71+
72+
container_names = dict()
73+
generated_names_counter = dict()
74+
known_modules_map = dict()
75+
known_modules_name_map = dict()
76+
77+
tf_nodes = build_graph(net, x)
78+
79+
for name, module in net.named_modules():
80+
known_modules_map[module] = name
81+
known_modules_name_map[name] = module
82+
83+
def get_parent_names(name):
84+
t = ''
85+
for i in name.split('.')[:-1]:
86+
if t:
87+
t += '.'
88+
t += i
89+
yield t
90+
91+
def get_parent(name, go_up=1) -> str:
92+
return '.'.join(name.split('.')[:go_up * -1])
93+
94+
def gen_new_layer_id(name):
95+
if name in generated_names_counter:
96+
generated_names_counter[name] += 1
97+
else:
98+
generated_names_counter[name] = 1
99+
100+
return name + '-' + str(generated_names_counter[name])
101+
102+
for node in tf_nodes.values():
103+
# if node.kind == 'prim::Constant': continue
104+
105+
layer_id = get_layer_id(node.debugName)
106+
names_from_id[layer_id] = node.debugName
107+
nodes_from_id[layer_id] = node
108+
names_from_debug[node.debugName] = layer_id
109+
names_short_from_debug[node.debugName] = get_short_layer_id(node.debugName)
110+
names_to_short[layer_id] = names_short_from_debug[node.debugName]
111+
112+
edges = dict()
113+
114+
for node in tf_nodes.values():
115+
if node.debugName not in names_from_debug: continue
116+
layer_id = names_from_debug[node.debugName]
117+
short_layer_id = names_short_from_debug[node.debugName]
118+
119+
print(node.debugName, '=>', layer_id, short_layer_id, node.kind)
120+
for parent in get_parent_names(layer_id):
121+
container_names[parent] = True
122+
123+
for input in node.inputs:
124+
if input in names_from_debug and layer_id != names_from_debug[input] \
125+
and short_layer_id != names_from_debug[input]:
126+
print(' outgoing', names_from_debug[input], names_short_from_debug[input], input)
127+
# this node points out of itself, so create an edge
128+
edge_to = names_from_debug[input]
129+
130+
if layer_id in edges:
131+
edges[layer_id].add(edge_to)
132+
else:
133+
edges[layer_id] = set([edge_to])
134+
135+
def resolve_edges_to_known_layer(from_layer: str, inputs: Set[str]) -> List[str]:
136+
new_inputs = set()
137+
short_name = names_to_short[from_layer] if from_layer in names_to_short else None
138+
parent_name = get_parent(short_name) if short_name else None
139+
140+
# parent_layer = get_parent(from_layer)
141+
for input in inputs:
142+
input_short_name = names_to_short[input] if input in names_to_short else None
143+
144+
# we skip connection where even the 2. parent is not the same or a child of from_layer
145+
# we could make this configurable
146+
second_parent = get_parent(input_short_name, 2)
147+
if second_parent and short_name and not short_name.startswith(second_parent):
148+
continue
149+
150+
if input_short_name and short_name and short_name != input_short_name and input_short_name in known_modules_name_map:
151+
if not parent_name or (parent_name != input_short_name):
152+
new_inputs.add(input_short_name)
153+
continue
154+
155+
if input in edges:
156+
for i in resolve_edges_to_known_layer(from_layer, edges[input]):
157+
new_inputs.add(i)
158+
else:
159+
# we let it as is
160+
new_inputs.add(input)
161+
162+
return list(new_inputs)
163+
164+
edges_resolved = dict()
165+
shapes = dict()
166+
short_name_to_id = dict()
167+
168+
# we resolve the edges only from known layers
169+
for [name, inputs] in edges.items():
170+
# first name=layer2.0/input.1 => layer2.0
171+
short_name = name
172+
if name in names_to_short:
173+
short_name = names_to_short[name]
174+
175+
if short_name not in known_modules_name_map: continue
176+
# if short_name in edges_resolved: continue
177+
178+
shapes[short_name] = nodes_from_id[name].tensor_size
179+
short_name_to_id[short_name] = name
180+
edges_resolved[short_name] = resolve_edges_to_known_layer(name, inputs)
181+
182+
deepkit_nodes = []
183+
184+
for [name, inputs] in edges_resolved.items():
185+
module = known_modules_name_map[name]
186+
node = {
187+
'id': name,
188+
'label': name,
189+
'type': type(module).__name__,
190+
'input': inputs,
191+
'attributes': extract_attributes(module),
192+
'internalInputs': list(edges[short_name_to_id[name]]),
193+
'shape': shapes[name]
194+
}
195+
deepkit_nodes.append(node)
196+
197+
graph = {
198+
'nodes': deepkit_nodes
199+
}
200+
201+
return graph

0 commit comments

Comments
 (0)