@@ -59,65 +59,127 @@ def __init__(self, *chain):
59
59
self .edges = {BEGIN : set ()}
60
60
self .named = {}
61
61
self .nodes = []
62
- self .add_chain (* chain )
62
+ if len (chain ):
63
+ self .add_chain (* chain )
63
64
64
65
def __iter__ (self ):
65
66
yield from self .nodes
66
67
67
68
def __len__ (self ):
68
- """ Node count.
69
+ """Node count.
69
70
"""
70
71
return len (self .nodes )
71
72
72
73
def __getitem__ (self , key ):
73
74
return self .nodes [key ]
74
75
75
76
def get_cursor (self , ref = BEGIN ):
76
- return GraphCursor (self , last = self ._resolve_index (ref ))
77
+ return GraphCursor (self , last = self .index_of (ref ))
77
78
78
- def outputs_of (self , idx , create = False ):
79
- """ Get a set of the outputs for a given node index.
79
+ def index_of (self , mixed ):
80
80
"""
81
- if create and not idx in self .edges :
82
- self .edges [idx ] = set ()
83
- return self .edges [idx ]
81
+ Find the index based on various strategies for a node, probably an input or output of chain. Supported
82
+ inputs are indexes, node values or names.
83
+
84
+ """
85
+ if mixed is None :
86
+ return None
87
+
88
+ if type (mixed ) is int or mixed in self .edges :
89
+ return mixed
90
+
91
+ if isinstance (mixed , str ) and mixed in self .named :
92
+ return self .named [mixed ]
93
+
94
+ if mixed in self .nodes :
95
+ return self .nodes .index (mixed )
84
96
85
- def add_node (self , c ):
86
- """ Add a node without connections in this graph and returns its index.
97
+ raise ValueError ("Cannot find node matching {!r}." .format (mixed ))
98
+
99
+ def outputs_of (self , idx_or_node , create = False ):
100
+ """Get a set of the outputs for a given node, node index or name.
101
+ """
102
+ idx_or_node = self .index_of (idx_or_node )
103
+
104
+ if create and not idx_or_node in self .edges :
105
+ self .edges [idx_or_node ] = set ()
106
+ return self .edges [idx_or_node ]
107
+
108
+ def add_node (self , c , * , _name = None ):
109
+ """Add a node without connections in this graph and returns its index.
110
+ If _name is specified, name this node (string reference for further usage).
87
111
"""
88
112
idx = len (self .nodes )
89
113
self .edges [idx ] = set ()
90
114
self .nodes .append (c )
115
+
116
+ if _name :
117
+ if _name in self .named :
118
+ raise KeyError ("Duplicate name {!r} in graph." .format (_name ))
119
+ self .named [_name ] = idx
120
+
91
121
return idx
92
122
93
123
def add_chain (self , * nodes , _input = BEGIN , _output = None , _name = None ):
94
- """ Add a chain in this graph.
124
+ """Add `nodes` as a chain in this graph.
125
+
126
+ **Input rules**
127
+
128
+ * By default, this chain will be connected to `BEGIN`, a.k.a the special node that kickstarts transformations.
129
+ * If `_input` is set to `None`, then this chain won't receive any input unless you connect it manually to
130
+ something.
131
+ * If `_input` is something that can resolve to another node using `index_of` rules, then the chain will
132
+ receive the output stream of referenced node.
133
+
134
+ **Output rules**
135
+
136
+ * By default, this chain won't send its output anywhere. This is, most of the time, what you want.
137
+ * If `_output` is set to something (that can resolve to a node), then the last node in the chain will send its
138
+ outputs to the given node. This means you can provide an object, a name, or an index.
139
+
140
+ **Naming**
141
+
142
+ * If a `_name` is given, the first node in the chain will be named this way (same effect as providing a `_name`
143
+ to add_node).
144
+
145
+ **Special cases**
146
+
147
+ * You can use this method to connect two other chains (in fact, two nodes) by not giving any `nodes`, but
148
+ still providing values to `_input` and `_output`.
149
+
95
150
"""
96
- if len (nodes ):
97
- _input = self ._resolve_index (_input )
98
- _output = self ._resolve_index (_output )
99
- _first = None
100
- _last = None
101
-
102
- for i , node in enumerate (nodes ):
103
- _last = self .add_node (node )
104
- if not i and _name :
105
- if _name in self .named :
106
- raise KeyError ("Duplicate name {!r} in graph." .format (_name ))
107
- self .named [_name ] = _last
108
- if _first is None :
109
- _first = _last
110
- self .outputs_of (_input , create = True ).add (_last )
111
- _input = _last
112
-
113
- if _output is not None :
114
- self .outputs_of (_input , create = True ).add (_output )
115
-
116
- if hasattr (self , "_topologcally_sorted_indexes_cache" ):
117
- del self ._topologcally_sorted_indexes_cache
118
-
119
- return GraphRange (self , _first , _last )
120
- return GraphRange (self , None , None )
151
+ _input = self .index_of (_input )
152
+ _output = self .index_of (_output )
153
+ _first = None
154
+ _last = None
155
+
156
+ # Sanity checks.
157
+ if not len (nodes ):
158
+ if _input is None or _output is None :
159
+ raise ValueError (
160
+ "Using add_chain(...) without nodes is only possible if you provide both _input and _output values."
161
+ )
162
+
163
+ if _name is not None :
164
+ raise RuntimeError ("Using add_chain(...) without nodes does not allow to use the _name parameter." )
165
+
166
+ for i , node in enumerate (nodes ):
167
+ _last = self .add_node (node , _name = _name if not i else None )
168
+
169
+ if _first is None :
170
+ _first = _last
171
+
172
+ self .outputs_of (_input , create = True ).add (_last )
173
+
174
+ _input = _last
175
+
176
+ if _output is not None :
177
+ self .outputs_of (_input , create = True ).add (_output )
178
+
179
+ if hasattr (self , "_topologcally_sorted_indexes_cache" ):
180
+ del self ._topologcally_sorted_indexes_cache
181
+
182
+ return GraphRange (self , _first , _last )
121
183
122
184
def copy (self ):
123
185
g = Graph ()
@@ -191,26 +253,6 @@ def _repr_html_(self):
191
253
except (ExecutableNotFound , FileNotFoundError ) as exc :
192
254
return "<strong>{}</strong>: {}" .format (type (exc ).__name__ , str (exc ))
193
255
194
- def _resolve_index (self , mixed ):
195
- """
196
- Find the index based on various strategies for a node, probably an input or output of chain. Supported
197
- inputs are indexes, node values or names.
198
-
199
- """
200
- if mixed is None :
201
- return None
202
-
203
- if type (mixed ) is int or mixed in self .edges :
204
- return mixed
205
-
206
- if isinstance (mixed , str ) and mixed in self .named :
207
- return self .named [mixed ]
208
-
209
- if mixed in self .nodes :
210
- return self .nodes .index (mixed )
211
-
212
- raise ValueError ("Cannot find node matching {!r}." .format (mixed ))
213
-
214
256
215
257
def _get_graphviz_node_id (graph , i ):
216
258
escaped_index = str (i )
0 commit comments