Skip to content

Commit b7992f8

Browse files
Revert "execution: fold in dependency aware caching / Fix --cache-none with l…" (#10422)
This reverts commit b1467da.
1 parent 2c2aa40 commit b7992f8

File tree

5 files changed

+190
-101
lines changed

5 files changed

+190
-101
lines changed

comfy_execution/caching.py

Lines changed: 154 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -265,26 +265,6 @@ async def ensure_subcache_for(self, node_id, children_ids):
265265
assert cache is not None
266266
return await cache._ensure_subcache(node_id, children_ids)
267267

268-
class NullCache:
269-
270-
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
271-
pass
272-
273-
def all_node_ids(self):
274-
return []
275-
276-
def clean_unused(self):
277-
pass
278-
279-
def get(self, node_id):
280-
return None
281-
282-
def set(self, node_id, value):
283-
pass
284-
285-
async def ensure_subcache_for(self, node_id, children_ids):
286-
return self
287-
288268
class LRUCache(BasicCache):
289269
def __init__(self, key_class, max_size=100):
290270
super().__init__(key_class)
@@ -336,3 +316,157 @@ async def ensure_subcache_for(self, node_id, children_ids):
336316
self._mark_used(child_id)
337317
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
338318
return self
319+
320+
321+
class DependencyAwareCache(BasicCache):
322+
"""
323+
A cache implementation that tracks dependencies between nodes and manages
324+
their execution and caching accordingly. It extends the BasicCache class.
325+
Nodes are removed from this cache once all of their descendants have been
326+
executed.
327+
"""
328+
329+
def __init__(self, key_class):
330+
"""
331+
Initialize the DependencyAwareCache.
332+
333+
Args:
334+
key_class: The class used for generating cache keys.
335+
"""
336+
super().__init__(key_class)
337+
self.descendants = {} # Maps node_id -> set of descendant node_ids
338+
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
339+
self.executed_nodes = set() # Tracks nodes that have been executed
340+
341+
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
342+
"""
343+
Clear the entire cache and rebuild the dependency graph.
344+
345+
Args:
346+
dynprompt: The dynamic prompt object containing node information.
347+
node_ids: List of node IDs to initialize the cache for.
348+
is_changed_cache: Flag indicating if the cache has changed.
349+
"""
350+
# Clear all existing cache data
351+
self.cache.clear()
352+
self.subcaches.clear()
353+
self.descendants.clear()
354+
self.ancestors.clear()
355+
self.executed_nodes.clear()
356+
357+
# Call the parent method to initialize the cache with the new prompt
358+
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
359+
360+
# Rebuild the dependency graph
361+
self._build_dependency_graph(dynprompt, node_ids)
362+
363+
def _build_dependency_graph(self, dynprompt, node_ids):
364+
"""
365+
Build the dependency graph for all nodes.
366+
367+
Args:
368+
dynprompt: The dynamic prompt object containing node information.
369+
node_ids: List of node IDs to build the graph for.
370+
"""
371+
self.descendants.clear()
372+
self.ancestors.clear()
373+
for node_id in node_ids:
374+
self.descendants[node_id] = set()
375+
self.ancestors[node_id] = set()
376+
377+
for node_id in node_ids:
378+
inputs = dynprompt.get_node(node_id)["inputs"]
379+
for input_data in inputs.values():
380+
if is_link(input_data): # Check if the input is a link to another node
381+
ancestor_id = input_data[0]
382+
self.descendants[ancestor_id].add(node_id)
383+
self.ancestors[node_id].add(ancestor_id)
384+
385+
def set(self, node_id, value):
386+
"""
387+
Mark a node as executed and store its value in the cache.
388+
389+
Args:
390+
node_id: The ID of the node to store.
391+
value: The value to store for the node.
392+
"""
393+
self._set_immediate(node_id, value)
394+
self.executed_nodes.add(node_id)
395+
self._cleanup_ancestors(node_id)
396+
397+
def get(self, node_id):
398+
"""
399+
Retrieve the cached value for a node.
400+
401+
Args:
402+
node_id: The ID of the node to retrieve.
403+
404+
Returns:
405+
The cached value for the node.
406+
"""
407+
return self._get_immediate(node_id)
408+
409+
async def ensure_subcache_for(self, node_id, children_ids):
410+
"""
411+
Ensure a subcache exists for a node and update dependencies.
412+
413+
Args:
414+
node_id: The ID of the parent node.
415+
children_ids: List of child node IDs to associate with the parent node.
416+
417+
Returns:
418+
The subcache object for the node.
419+
"""
420+
subcache = await super()._ensure_subcache(node_id, children_ids)
421+
for child_id in children_ids:
422+
self.descendants[node_id].add(child_id)
423+
self.ancestors[child_id].add(node_id)
424+
return subcache
425+
426+
def _cleanup_ancestors(self, node_id):
427+
"""
428+
Check if ancestors of a node can be removed from the cache.
429+
430+
Args:
431+
node_id: The ID of the node whose ancestors are to be checked.
432+
"""
433+
for ancestor_id in self.ancestors.get(node_id, []):
434+
if ancestor_id in self.executed_nodes:
435+
# Remove ancestor if all its descendants have been executed
436+
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
437+
self._remove_node(ancestor_id)
438+
439+
def _remove_node(self, node_id):
440+
"""
441+
Remove a node from the cache.
442+
443+
Args:
444+
node_id: The ID of the node to remove.
445+
"""
446+
cache_key = self.cache_key_set.get_data_key(node_id)
447+
if cache_key in self.cache:
448+
del self.cache[cache_key]
449+
subcache_key = self.cache_key_set.get_subcache_key(node_id)
450+
if subcache_key in self.subcaches:
451+
del self.subcaches[subcache_key]
452+
453+
def clean_unused(self):
454+
"""
455+
Clean up unused nodes. This is a no-op for this cache implementation.
456+
"""
457+
pass
458+
459+
def recursive_debug_dump(self):
460+
"""
461+
Dump the cache and dependency graph for debugging.
462+
463+
Returns:
464+
A list containing the cache state and dependency graph.
465+
"""
466+
result = super().recursive_debug_dump()
467+
result.append({
468+
"descendants": self.descendants,
469+
"ancestors": self.ancestors,
470+
"executed_nodes": list(self.executed_nodes),
471+
})
472+
return result

comfy_execution/graph.py

Lines changed: 2 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,8 @@ def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
153153
continue
154154
_, _, input_info = self.get_input_info(unique_id, input_name)
155155
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
156-
if (include_lazy or not is_lazy):
157-
if not self.is_cached(from_node_id):
158-
node_ids.append(from_node_id)
156+
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
157+
node_ids.append(from_node_id)
159158
links.append((from_node_id, from_socket, unique_id))
160159

161160
for link in links:
@@ -195,34 +194,10 @@ def __init__(self, dynprompt, output_cache):
195194
super().__init__(dynprompt)
196195
self.output_cache = output_cache
197196
self.staged_node_id = None
198-
self.execution_cache = {}
199-
self.execution_cache_listeners = {}
200197

201198
def is_cached(self, node_id):
202199
return self.output_cache.get(node_id) is not None
203200

204-
def cache_link(self, from_node_id, to_node_id):
205-
if not to_node_id in self.execution_cache:
206-
self.execution_cache[to_node_id] = {}
207-
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
208-
if not from_node_id in self.execution_cache_listeners:
209-
self.execution_cache_listeners[from_node_id] = set()
210-
self.execution_cache_listeners[from_node_id].add(to_node_id)
211-
212-
def get_output_cache(self, from_node_id, to_node_id):
213-
if not to_node_id in self.execution_cache:
214-
return None
215-
return self.execution_cache[to_node_id].get(from_node_id)
216-
217-
def cache_update(self, node_id, value):
218-
if node_id in self.execution_cache_listeners:
219-
for to_node_id in self.execution_cache_listeners[node_id]:
220-
self.execution_cache[to_node_id][node_id] = value
221-
222-
def add_strong_link(self, from_node_id, from_socket, to_node_id):
223-
super().add_strong_link(from_node_id, from_socket, to_node_id)
224-
self.cache_link(from_node_id, to_node_id)
225-
226201
async def stage_node_execution(self):
227202
assert self.staged_node_id is None
228203
if self.is_empty():
@@ -302,8 +277,6 @@ def unstage_node_execution(self):
302277
def complete_node_execution(self):
303278
node_id = self.staged_node_id
304279
self.pop_node(node_id)
305-
self.execution_cache.pop(node_id, None)
306-
self.execution_cache_listeners.pop(node_id, None)
307280
self.staged_node_id = None
308281

309282
def get_nodes_in_cycle(self):

execution.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
BasicCache,
1919
CacheKeySetID,
2020
CacheKeySetInputSignature,
21-
NullCache,
21+
DependencyAwareCache,
2222
HierarchicalCache,
2323
LRUCache,
2424
)
@@ -91,13 +91,13 @@ async def get(self, node_id):
9191
class CacheType(Enum):
9292
CLASSIC = 0
9393
LRU = 1
94-
NONE = 2
94+
DEPENDENCY_AWARE = 2
9595

9696

9797
class CacheSet:
9898
def __init__(self, cache_type=None, cache_size=None):
99-
if cache_type == CacheType.NONE:
100-
self.init_null_cache()
99+
if cache_type == CacheType.DEPENDENCY_AWARE:
100+
self.init_dependency_aware_cache()
101101
logging.info("Disabling intermediate node cache.")
102102
elif cache_type == CacheType.LRU:
103103
if cache_size is None:
@@ -120,12 +120,11 @@ def init_lru_cache(self, cache_size):
120120
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
121121
self.objects = HierarchicalCache(CacheKeySetID)
122122

123-
def init_null_cache(self):
124-
self.outputs = NullCache()
125-
#The UI cache is expected to be iterable at the end of each workflow
126-
#so it must cache at least a full workflow. Use Hierarchical
127-
self.ui = HierarchicalCache(CacheKeySetInputSignature)
128-
self.objects = NullCache()
123+
# only hold cached items while the descendants have not executed
124+
def init_dependency_aware_cache(self):
125+
self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
126+
self.ui = DependencyAwareCache(CacheKeySetInputSignature)
127+
self.objects = DependencyAwareCache(CacheKeySetID)
129128

130129
def recursive_debug_dump(self):
131130
result = {
@@ -136,7 +135,7 @@ def recursive_debug_dump(self):
136135

137136
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
138137

139-
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
138+
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
140139
is_v3 = issubclass(class_def, _ComfyNodeInternal)
141140
if is_v3:
142141
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
@@ -154,10 +153,10 @@ def mark_missing():
154153
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
155154
input_unique_id = input_data[0]
156155
output_index = input_data[1]
157-
if execution_list is None:
156+
if outputs is None:
158157
mark_missing()
159158
continue # This might be a lazily-evaluated input
160-
cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
159+
cached_output = outputs.get(input_unique_id)
161160
if cached_output is None:
162161
mark_missing()
163162
continue
@@ -406,7 +405,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
406405
cached_output = caches.ui.get(unique_id) or {}
407406
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
408407
get_progress_state().finish_progress(unique_id)
409-
execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
410408
return (ExecutionResult.SUCCESS, None, None)
411409

412410
input_data_all = None
@@ -436,7 +434,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
436434
for r in result:
437435
if is_link(r):
438436
source_node, source_output = r[0], r[1]
439-
node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
437+
node_output = caches.outputs.get(source_node)[source_output]
440438
for o in node_output:
441439
resolved_output.append(o)
442440

@@ -448,7 +446,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
448446
has_subgraph = False
449447
else:
450448
get_progress_state().start_progress(unique_id)
451-
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
449+
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
452450
if server.client_id is not None:
453451
server.last_node_id = display_node_id
454452
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@@ -551,15 +549,11 @@ async def await_completion():
551549
subcache.clean_unused()
552550
for node_id in new_output_ids:
553551
execution_list.add_node(node_id)
554-
execution_list.cache_link(node_id, unique_id)
555552
for link in new_output_links:
556553
execution_list.add_strong_link(link[0], link[1], unique_id)
557554
pending_subgraph_results[unique_id] = cached_outputs
558555
return (ExecutionResult.PENDING, None, None)
559-
560556
caches.outputs.set(unique_id, output_data)
561-
execution_list.cache_update(unique_id, output_data)
562-
563557
except comfy.model_management.InterruptProcessingException as iex:
564558
logging.info("Processing interrupted")
565559

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def prompt_worker(q, server_instance):
173173
if args.cache_lru > 0:
174174
cache_type = execution.CacheType.LRU
175175
elif args.cache_none:
176-
cache_type = execution.CacheType.NONE
176+
cache_type = execution.CacheType.DEPENDENCY_AWARE
177177

178178
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
179179
last_gc_collect = 0

0 commit comments

Comments
 (0)