Skip to content

Commit 3a33276

Browse files
saikonenobgibson
andauthored
fix: DAG rendering from code package (Netflix#290)
* inherit DAGNode from metaflow client instead of reimplementing it. * inherit most of FlowGraph functionality from metaflow client as well, overwriting only when necessary. * draft unit test for legacy dag parsing * first passing unit test for custom flowgraph * add more unit tests to cover custom flowgraph behaviour * revert client FlowGraph inheritance * codestyles * fix compatibility with imported DAGNode Co-authored-by: Brendan Gibson <brendan@outerbounds.co>
1 parent 3953b86 commit 3a33276

File tree

2 files changed

+291
-118
lines changed

2 files changed

+291
-118
lines changed

services/ui_backend_service/data/cache/custom_flowgraph.py

Lines changed: 16 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -1,110 +1,9 @@
11
import ast
2+
from metaflow.graph import deindent_docstring, DAGNode
23

3-
4-
class DAGNode(object):
5-
def __init__(self, func_ast, decos, doc):
6-
self.name = func_ast.name
7-
self.func_lineno = func_ast.lineno
8-
self.decorators = decos
9-
self.doc = doc.rstrip()
10-
11-
# these attributes are populated by _parse
12-
self.tail_next_lineno = 0
13-
self.type = None
14-
self.out_funcs = []
15-
self.has_tail_next = False
16-
self.invalid_tail_next = False
17-
self.num_args = 0
18-
self.condition = None
19-
self.foreach_param = None
20-
self._parse(func_ast)
21-
22-
# these attributes are populated by _traverse_graph
23-
self.in_funcs = set()
24-
self.split_parents = []
25-
self.matching_join = None
26-
27-
# these attributes are populated by _postprocess
28-
self.is_inside_foreach = False
29-
30-
def _expr_str(self, expr):
31-
return '%s.%s' % (expr.value.id, expr.attr)
32-
33-
def _parse(self, func_ast):
34-
35-
self.num_args = len(func_ast.args.args)
36-
tail = func_ast.body[-1]
37-
38-
# end doesn't need a transition
39-
if self.name == 'end':
40-
# TYPE: end
41-
self.type = 'end'
42-
43-
# ensure that the tail an expression
44-
if not isinstance(tail, ast.Expr):
45-
return
46-
47-
# determine the type of self.next transition
48-
try:
49-
if not self._expr_str(tail.value.func) == 'self.next':
50-
return
51-
52-
self.has_tail_next = True
53-
self.invalid_tail_next = True
54-
self.tail_next_lineno = tail.lineno
55-
self.out_funcs = [e.attr for e in tail.value.args]
56-
keywords = dict((k.arg, k.value.s) for k in tail.value.keywords)
57-
58-
if len(keywords) == 1:
59-
if 'foreach' in keywords:
60-
# TYPE: foreach
61-
self.type = 'foreach'
62-
if len(self.out_funcs) == 1:
63-
self.foreach_param = keywords['foreach']
64-
self.invalid_tail_next = False
65-
elif 'condition' in keywords:
66-
# TYPE: split-or
67-
self.type = 'split-or'
68-
if len(self.out_funcs) == 2:
69-
self.condition = keywords['condition']
70-
self.invalid_tail_next = False
71-
elif len(keywords) == 0:
72-
if len(self.out_funcs) > 1:
73-
# TYPE: split-and
74-
self.type = 'split-and'
75-
self.invalid_tail_next = False
76-
elif len(self.out_funcs) == 1:
77-
# TYPE: linear
78-
if self.num_args > 1:
79-
self.type = 'join'
80-
else:
81-
self.type = 'linear'
82-
self.invalid_tail_next = False
83-
84-
except AttributeError:
85-
return
86-
87-
def __str__(self):
88-
return """
89-
*[{0.name} {0.type} (line {0.func_lineno})]*
90-
in_funcs={in_funcs}
91-
split_parents={parents}
92-
matching_join={matching_join}
93-
is_inside_foreach={is_inside_foreach}
94-
decorators={decos}
95-
num_args={0.num_args}
96-
has_tail_next={0.has_tail_next} (line {0.tail_next_lineno})
97-
invalid_tail_next={0.invalid_tail_next}
98-
condition={0.condition}
99-
foreach_param={0.foreach_param}
100-
-> {out}"""\
101-
.format(self,
102-
matching_join=self.matching_join and '[%s]' % self.matching_join,
103-
is_inside_foreach=self.is_inside_foreach,
104-
in_funcs=', '.join('[%s]' % x for x in self.in_funcs),
105-
parents=', '.join('[%s]' % x for x in self.split_parents),
106-
decos=' | '.join(map(str, self.decorators)),
107-
out=', '.join('[%s]' % x for x in self.out_funcs))
4+
# NOTE: This is a custom implementation of the FlowGraph class from the Metaflow client
5+
# which can parse a graph out of a flow_name and a source code string, instead of relying on
6+
# importing the source code as a module.
1087

1098

1109
class StepVisitor(ast.NodeVisitor):
@@ -141,7 +40,7 @@ def _flow(n):
14140
[root] = list(filter(_flow, ast.parse(source).body))
14241
self.name = root.name
14342
doc = ast.get_docstring(root)
144-
self.doc = doc if doc else ''
43+
self.doc = deindent_docstring(doc) if doc else ''
14544
nodes = {}
14645
StepVisitor(nodes).visit(root)
14746
return nodes
@@ -151,20 +50,18 @@ def _postprocess(self):
15150
# has is_inside_foreach=True *unless* all of those foreaches
15251
# are joined by the node
15352
for node in self.nodes.values():
154-
foreaches = [p for p in node.split_parents
155-
if self.nodes[p].type == 'foreach']
156-
if [f for f in foreaches
157-
if self.nodes[f].matching_join != node.name]:
53+
foreaches = [
54+
p for p in node.split_parents if self.nodes[p].type == "foreach"
55+
]
56+
if [f for f in foreaches if self.nodes[f].matching_join != node.name]:
15857
node.is_inside_foreach = True
15958

16059
def _traverse_graph(self):
161-
16260
def traverse(node, seen, split_parents):
163-
164-
if node.type in ('split-or', 'split-and', 'foreach'):
61+
if node.type in ("split", "foreach"):
16562
node.split_parents = split_parents
16663
split_parents = split_parents + [node.name]
167-
elif node.type == 'join':
64+
elif node.type == "join":
16865
# ignore joins without splits
16966
if split_parents:
17067
self[split_parents[-1]].matching_join = node.name
@@ -182,8 +79,8 @@ def traverse(node, seen, split_parents):
18279
child.in_funcs.add(node.name)
18380
traverse(child, seen + [n], split_parents)
18481

185-
if 'start' in self:
186-
traverse(self['start'], [], [])
82+
if "start" in self:
83+
traverse(self["start"], [], [])
18784

18885
# fix the order of in_funcs
18986
for node in self.nodes.values():
@@ -199,8 +96,9 @@ def __iter__(self):
19996
return iter(self.nodes.values())
20097

20198
def __str__(self):
202-
return '\n'.join(str(n) for _, n in sorted((n.func_lineno, n)
203-
for n in self.nodes.values()))
99+
return "\n".join(
100+
str(n) for _, n in sorted((n.func_lineno, n) for n in self.nodes.values())
101+
)
204102

205103
def output_steps(self):
206104

0 commit comments

Comments
 (0)