1
1
import ast
2
+ from metaflow .graph import deindent_docstring , DAGNode
2
3
3
-
4
- class DAGNode (object ):
5
- def __init__ (self , func_ast , decos , doc ):
6
- self .name = func_ast .name
7
- self .func_lineno = func_ast .lineno
8
- self .decorators = decos
9
- self .doc = doc .rstrip ()
10
-
11
- # these attributes are populated by _parse
12
- self .tail_next_lineno = 0
13
- self .type = None
14
- self .out_funcs = []
15
- self .has_tail_next = False
16
- self .invalid_tail_next = False
17
- self .num_args = 0
18
- self .condition = None
19
- self .foreach_param = None
20
- self ._parse (func_ast )
21
-
22
- # these attributes are populated by _traverse_graph
23
- self .in_funcs = set ()
24
- self .split_parents = []
25
- self .matching_join = None
26
-
27
- # these attributes are populated by _postprocess
28
- self .is_inside_foreach = False
29
-
30
- def _expr_str (self , expr ):
31
- return '%s.%s' % (expr .value .id , expr .attr )
32
-
33
- def _parse (self , func_ast ):
34
-
35
- self .num_args = len (func_ast .args .args )
36
- tail = func_ast .body [- 1 ]
37
-
38
- # end doesn't need a transition
39
- if self .name == 'end' :
40
- # TYPE: end
41
- self .type = 'end'
42
-
43
- # ensure that the tail an expression
44
- if not isinstance (tail , ast .Expr ):
45
- return
46
-
47
- # determine the type of self.next transition
48
- try :
49
- if not self ._expr_str (tail .value .func ) == 'self.next' :
50
- return
51
-
52
- self .has_tail_next = True
53
- self .invalid_tail_next = True
54
- self .tail_next_lineno = tail .lineno
55
- self .out_funcs = [e .attr for e in tail .value .args ]
56
- keywords = dict ((k .arg , k .value .s ) for k in tail .value .keywords )
57
-
58
- if len (keywords ) == 1 :
59
- if 'foreach' in keywords :
60
- # TYPE: foreach
61
- self .type = 'foreach'
62
- if len (self .out_funcs ) == 1 :
63
- self .foreach_param = keywords ['foreach' ]
64
- self .invalid_tail_next = False
65
- elif 'condition' in keywords :
66
- # TYPE: split-or
67
- self .type = 'split-or'
68
- if len (self .out_funcs ) == 2 :
69
- self .condition = keywords ['condition' ]
70
- self .invalid_tail_next = False
71
- elif len (keywords ) == 0 :
72
- if len (self .out_funcs ) > 1 :
73
- # TYPE: split-and
74
- self .type = 'split-and'
75
- self .invalid_tail_next = False
76
- elif len (self .out_funcs ) == 1 :
77
- # TYPE: linear
78
- if self .num_args > 1 :
79
- self .type = 'join'
80
- else :
81
- self .type = 'linear'
82
- self .invalid_tail_next = False
83
-
84
- except AttributeError :
85
- return
86
-
87
- def __str__ (self ):
88
- return """
89
- *[{0.name} {0.type} (line {0.func_lineno})]*
90
- in_funcs={in_funcs}
91
- split_parents={parents}
92
- matching_join={matching_join}
93
- is_inside_foreach={is_inside_foreach}
94
- decorators={decos}
95
- num_args={0.num_args}
96
- has_tail_next={0.has_tail_next} (line {0.tail_next_lineno})
97
- invalid_tail_next={0.invalid_tail_next}
98
- condition={0.condition}
99
- foreach_param={0.foreach_param}
100
- -> {out}""" \
101
- .format (self ,
102
- matching_join = self .matching_join and '[%s]' % self .matching_join ,
103
- is_inside_foreach = self .is_inside_foreach ,
104
- in_funcs = ', ' .join ('[%s]' % x for x in self .in_funcs ),
105
- parents = ', ' .join ('[%s]' % x for x in self .split_parents ),
106
- decos = ' | ' .join (map (str , self .decorators )),
107
- out = ', ' .join ('[%s]' % x for x in self .out_funcs ))
4
+ # NOTE: This is a custom implementation of the FlowGraph class from the Metaflow client
5
+ # which can parse a graph out of a flow_name and a source code string, instead of relying on
6
+ # importing the source code as a module.
108
7
109
8
110
9
class StepVisitor (ast .NodeVisitor ):
@@ -141,7 +40,7 @@ def _flow(n):
141
40
[root ] = list (filter (_flow , ast .parse (source ).body ))
142
41
self .name = root .name
143
42
doc = ast .get_docstring (root )
144
- self .doc = doc if doc else ''
43
+ self .doc = deindent_docstring ( doc ) if doc else ''
145
44
nodes = {}
146
45
StepVisitor (nodes ).visit (root )
147
46
return nodes
@@ -151,20 +50,18 @@ def _postprocess(self):
151
50
# has is_inside_foreach=True *unless* all of those foreaches
152
51
# are joined by the node
153
52
for node in self .nodes .values ():
154
- foreaches = [p for p in node . split_parents
155
- if self .nodes [p ].type == ' foreach' ]
156
- if [ f for f in foreaches
157
- if self .nodes [f ].matching_join != node .name ]:
53
+ foreaches = [
54
+ p for p in node . split_parents if self .nodes [p ].type == " foreach"
55
+ ]
56
+ if [ f for f in foreaches if self .nodes [f ].matching_join != node .name ]:
158
57
node .is_inside_foreach = True
159
58
160
59
def _traverse_graph (self ):
161
-
162
60
def traverse (node , seen , split_parents ):
163
-
164
- if node .type in ('split-or' , 'split-and' , 'foreach' ):
61
+ if node .type in ("split" , "foreach" ):
165
62
node .split_parents = split_parents
166
63
split_parents = split_parents + [node .name ]
167
- elif node .type == ' join' :
64
+ elif node .type == " join" :
168
65
# ignore joins without splits
169
66
if split_parents :
170
67
self [split_parents [- 1 ]].matching_join = node .name
@@ -182,8 +79,8 @@ def traverse(node, seen, split_parents):
182
79
child .in_funcs .add (node .name )
183
80
traverse (child , seen + [n ], split_parents )
184
81
185
- if ' start' in self :
186
- traverse (self [' start' ], [], [])
82
+ if " start" in self :
83
+ traverse (self [" start" ], [], [])
187
84
188
85
# fix the order of in_funcs
189
86
for node in self .nodes .values ():
@@ -199,8 +96,9 @@ def __iter__(self):
199
96
return iter (self .nodes .values ())
200
97
201
98
def __str__ (self ):
202
- return '\n ' .join (str (n ) for _ , n in sorted ((n .func_lineno , n )
203
- for n in self .nodes .values ()))
99
+ return "\n " .join (
100
+ str (n ) for _ , n in sorted ((n .func_lineno , n ) for n in self .nodes .values ())
101
+ )
204
102
205
103
def output_steps (self ):
206
104
0 commit comments