Skip to content

Commit

Permalink
add fall_layout
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonfreak committed Jul 25, 2016
1 parent f5a901e commit b423e8a
Showing 1 changed file with 53 additions and 6 deletions.
59 changes: 53 additions & 6 deletions ple.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,22 +93,69 @@ def initRoot(featureNameList):
root.transform('init', newFeature)
return root

def _draw(G, root, edgeLabelDict):
def _draw(G, root, nodeLabelDict, edgeLabelDict):
nodeLabelDict[root.label] = root.name
for transform in root.transformList:
G.add_edge(root.label, transform.feature.label)
edgeLabelDict[(root.label, transform.feature.label)] = transform.label
_draw(G, transform.feature, edgeLabelDict)
_draw(G, transform.feature, nodeLabelDict, edgeLabelDict)

def _isCyclic(root, walked):
if root in walked:
return True
else:
walked.add(root)
for transform in root.transformList:
ret = _isCyclic(transform.feature, walked)
if ret:
return True
walked.remove(root)
return False

def fall_layout(root, x_space=1, y_space=1):
layout = {}
if _isCyclic(root, set()):
raise Exception('Graph is cyclic')

queue = [None, root]
nodeDict = {}
levelDict = {}
level = 0
while len(queue) > 0:
head = queue.pop()
if head is None:
if len(queue) > 0:
level += 1
queue.insert(0, None)
else:
if head in nodeDict:
levelDict[nodeDict[head]].remove(head)
nodeDict[head] = level
levelDict[level] = levelDict.get(level, []) + [head]
for transform in head.transformList:
queue.insert(0, transform.feature)

for level in levelDict.keys():
nodeList = levelDict[level]
n_nodes = len(nodeList)
offset = - n_nodes / 2
for i in range(n_nodes):
layout[nodeList[i].label] = (level * x_space, (i + offset) * y_space)

return layout

def draw(root):
G = nx.DiGraph()
nodeLabelDict = {}
edgeLabelDict = {}

_draw(G, root, edgeLabelDict)
pos=nx.spring_layout(G, iterations=150)
_draw(G, root, nodeLabelDict, edgeLabelDict)
# pos=nx.spring_layout(G, iterations=150)
pos = fall_layout(root)

nx.draw_networkx_nodes(G,pos,node_size=50, node_color="white")
nx.draw_networkx_nodes(G,pos,node_size=100, node_color="white")
nx.draw_networkx_edges(G,pos, width=1,alpha=0.5,edge_color='black')
nx.draw_networkx_labels(G,pos,font_size=10,font_family='sans-serif')
nx.draw_networkx_labels(G,pos,labels=nodeLabelDict, font_size=10,font_family='sans-serif')
nx.draw_networkx_edge_labels(G, pos, edgeLabelDict)

plt.show()

0 comments on commit b423e8a

Please sign in to comment.