Skip to content

feature/add graph #87

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
293 changes: 169 additions & 124 deletions visualdl/server/graph.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import os
import json

from google.protobuf.json_format import MessageToJson

import onnx
import graphviz_graph as gg
from PIL import Image


def debug_print(json_obj):
print(json.dumps(json_obj, sort_keys=True, indent=4, separators=(',', ': ')))
print(json.dumps(
json_obj, sort_keys=True, indent=4, separators=(',', ': ')))


def reorganize_inout(json_obj, key):
Expand Down Expand Up @@ -78,15 +82,16 @@ def get_links(model_json):
name = input['name']
for node in model_json['node']:
if name in node['input']:
links.append({'source': name,
"target": node['name']})
links.append({'source': name, "target": node['name']})

for source_node in model_json['node']:
for output in source_node['output']:
for target_node in model_json['node']:
if output in target_node['input']:
links.append({'source': source_node['name'],
'target': target_node['name']})
links.append({
'source': source_node['name'],
'target': target_node['name']
})

return links

Expand Down Expand Up @@ -189,8 +194,6 @@ def get_level_to_all(node_links, model_json):
level_to_nodes[level] = list()
level_to_nodes[level].append(idx)
# debug_print(level_to_nodes)


"""
input_to_level {idx -> level}
level_to_inputs {level -> [input1, input2]}
Expand Down Expand Up @@ -231,7 +234,8 @@ def get_level_to_all(node_links, model_json):
if out_level not in output_to_level:
output_to_level[out_idx] = out_level
else:
raise Exception("output " + out_name + "have multiple source")
raise Exception(
"output " + out_name + "have multiple source")
level_to_outputs = dict()
for out_idx in output_to_level:
level = output_to_level[out_idx]
Expand All @@ -243,7 +247,12 @@ def get_level_to_all(node_links, model_json):

def init_level(level):
if level not in level_to_all:
level_to_all[level] = {'nodes': list(), 'inputs': list(), 'outputs': list()}
level_to_all[level] = {
'nodes': list(),
'inputs': list(),
'outputs': list()
}

# merge all levels
for level in level_to_nodes:
init_level(level)
Expand Down Expand Up @@ -321,116 +330,6 @@ def add_edges(json_obj):
return json_obj


def transform_for_echars(model_json):
opItemStyle = {
"normal": {
"color": '#d95f02'
}
}

paraterItemStyle = {
"normal": {
"color": '#1b9e77'
}
};

paraSymbolSize = [12, 6]
paraSymbol = 'rect'
opSymbolSize = [5, 5]

option = {
"title": {
"text": 'Default Graph Name'
},
"tooltip": {
"show": False
},
"animationDurationUpdate": 1500,
"animationEasingUpdate": 'quinticInOut',
"series": [
{
"type": "graph",
"layout": "none",
"symbolSize": 8,
"roam": True,
"label": {
"normal": {
"show": True,
"color": 'black'
}
},
"edgeSymbol": ['none', 'arrow'],
"edgeSymbolSize": [0, 10],
"edgeLabel": {
"normal": {
"textStyle": {
"fontSize": 20
}
}
},
"lineStyle": {
"normal": {
"opacity": 0.9,
"width": 2,
"curveness": 0
}
},
"data": [],
"links": []
}
]
}

option['title']['text'] = model_json['name']

rename_model(model_json)
node_links = get_node_links(model_json)
add_level_to_node_links(node_links)
level_to_all = get_level_to_all(node_links, model_json)
node_to_coordinate, input_to_coordinate, output_to_coordinate = level_to_coordinate(level_to_all)

inputs = model_json['input']
nodes = model_json['node']
outputs = model_json['output']

echars_data = list()

for in_idx in range(len(inputs)):
input = inputs[in_idx]
data = dict()
data['name'] = input['name']
data['x'] = input_to_coordinate[in_idx]['x']
data['y'] = input_to_coordinate[in_idx]['y']
data['symbol'] = paraSymbol
data['itemStyle'] = paraterItemStyle
data['symbolSize'] = paraSymbolSize
echars_data.append(data)
for node_idx in range(len(nodes)):
node = nodes[node_idx]
data = dict()
data['name'] = node['name']
data['x'] = node_to_coordinate[node_idx]['x']
data['y'] = node_to_coordinate[node_idx]['y']
data['itemStyle'] = opItemStyle
data['symbolSize'] = opSymbolSize
echars_data.append(data)
for out_idx in range(len(outputs)):
output = outputs[out_idx]
data = dict()
data['name'] = output['name']
data['x'] = output_to_coordinate[out_idx]['x']
data['y'] = output_to_coordinate[out_idx]['y']
data['symbol'] = paraSymbol
data['itemStyle'] = paraterItemStyle
data['symbolSize'] = paraSymbolSize
echars_data.append(data)

option['series'][0]['data'] = echars_data
option['series'][0]['links'] = get_links(model_json)

return option


def to_IR_json(model_pb_path):
model = onnx.load(model_pb_path)
graph = model.graph
Expand All @@ -446,14 +345,160 @@ def to_IR_json(model_pb_path):

def load_model(model_pb_path):
model_json = to_IR_json(model_pb_path)
options = transform_for_echars(model_json)
return options
model_json = add_edges(model_json)
return model_json


class GraphPreviewGenerator(object):
def __init__(self, model_json):
#self.model = json.loads(model_json)
self.model = model_json
# init graphviz graph
self.graph = gg.Graph(
self.model['name'],
layout="dot",
#resolution=200,
concentrate="true",
# rankdir="LR"
rankdir="TB",
)

self.op_rank = self.graph.rank_group('same', 2)
self.param_rank = self.graph.rank_group('same', 1)
self.arg_rank = self.graph.rank_group('same', 0)

def __call__(self, path='temp.dot'):
self.nodes = {}
self.params = set()
self.ops = set()
self.args = set()

for item in self.model['input'] + self.model['output']:
node = self.add_param(**item)
print 'name', item['name']
self.nodes[item['name']] = node
self.params.add(item['name'])

for id, item in enumerate(self.model['node']):
node = self.add_op(**item)
name = "node_" + str(id)
print 'name', name
self.nodes[name] = node
self.ops.add(name)

for item in self.model['edges']:
source = item['source']
target = item['target']

if source not in self.nodes:
self.nodes[source] = self.add_arg(source)
self.args.add(source)
if target not in self.nodes:
self.nodes[target] = self.add_arg(target)
self.args.add(target)

if source in self.args or target in self.args:
edge = self.add_edge(
style="dashed,bold", color="#aaaaaa", **item)
else:
edge = self.add_edge(style="bold", color="#aaaaaa", **item)

self.graph.display(path)

def add_param(self, name, data_type, shape):
label = '\n'.join([
'<<table cellpadding="5">',
' <tr>',
' <td bgcolor="#eeeeee">',
name,
' </td>'
' </tr>',
' <tr>',
' <td>',
data_type,
' </td>'
' </tr>',
' <tr>',
' <td>',
'[%s]' % 'x'.join(shape),
' </td>'
' </tr>',
'</table>>',
])
return self.graph.node(
label,
prefix="param",
shape="none",
# rank=self.param_rank,
style="rounded,filled,bold",
width="1.3",
#color="#ffa0a0",
color="#8cc7ff",
fontname="Arial")

def add_op(self, opType, **kwargs):
return self.graph.node(
gg.crepr(opType),
# rank=self.op_rank,
prefix="op",
shape="box",
style="rounded, filled, bold",
fillcolor="#8cc7cd",
#fillcolor="#8cc7ff",
fontname="Arial",
width="1.3",
height="0.84",
)

def add_arg(self, name):
return self.graph.node(
gg.crepr(name),
prefix="arg",
# rank=self.arg_rank,
shape="box",
style="rounded,filled,bold",
fontname="Arial",
color="grey")

def add_edge(self, source, target, label, **kwargs):
source = self.nodes[source]
target = self.nodes[target]
return self.graph.edge(source, target, **kwargs)


def draw_graph(model_pb_path, image_dir):
json_str = load_model(model_pb_path)
best_image = None
min_width = None
for i in range(10):
# randomly generate dot images and select the one with minimum width.
g = GraphPreviewGenerator(json_str)
dot_path = os.path.join(image_dir, "temp-%d.dot" % i)
image_path = os.path.join(image_dir, "temp-%d.jpg" % i)
g(dot_path)

try:
im = Image.open(image_path)
if min_width is None or im.size[0] < min_width:
min_width = im.size
best_image = image_path
except:
pass
return best_image


if __name__ == '__main__':
import os
import sys
current_path = os.path.abspath(os.path.dirname(sys.argv[0]))
# json_str = load_model(current_path + "/mock/inception_v1_model.pb")
json_str = load_model(current_path + "/mock/squeezenet_model.pb")
print(json_str)
json_str = load_model(current_path + "/mock/inception_v1_model.pb")
#json_str = load_model(current_path + "/mock/squeezenet_model.pb")
# json_str = load_model('./mock/shufflenet/model.pb')
debug_print(json_str)
assert json_str

g = GraphPreviewGenerator(json_str)
g('./temp.dot')
# for i in range(10):
# g = GraphPreviewGenerator(json_str)
# g('./temp-%d.dot' % i)
Loading