Skip to content

Commit 0e53c13

Browse files
committed
main: re-implement --cache-none as no cache at all
The execution list now tracks the dependency aware caching more correctly that the DependancyAwareCache. Change it to a cache that does nothing.
1 parent adb6450 commit 0e53c13

File tree

3 files changed

+29
-164
lines changed

3 files changed

+29
-164
lines changed

comfy_execution/caching.py

Lines changed: 20 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,26 @@ async def ensure_subcache_for(self, node_id, children_ids):
265265
assert cache is not None
266266
return await cache._ensure_subcache(node_id, children_ids)
267267

268+
class NullCache:
269+
270+
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
271+
pass
272+
273+
def all_node_ids(self):
274+
return []
275+
276+
def clean_unused(self):
277+
pass
278+
279+
def get(self, node_id):
280+
return None
281+
282+
def set(self, node_id, value):
283+
pass
284+
285+
async def ensure_subcache_for(self, node_id, children_ids):
286+
return self
287+
268288
class LRUCache(BasicCache):
269289
def __init__(self, key_class, max_size=100):
270290
super().__init__(key_class)
@@ -316,157 +336,3 @@ async def ensure_subcache_for(self, node_id, children_ids):
316336
self._mark_used(child_id)
317337
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
318338
return self
319-
320-
321-
class DependencyAwareCache(BasicCache):
322-
"""
323-
A cache implementation that tracks dependencies between nodes and manages
324-
their execution and caching accordingly. It extends the BasicCache class.
325-
Nodes are removed from this cache once all of their descendants have been
326-
executed.
327-
"""
328-
329-
def __init__(self, key_class):
330-
"""
331-
Initialize the DependencyAwareCache.
332-
333-
Args:
334-
key_class: The class used for generating cache keys.
335-
"""
336-
super().__init__(key_class)
337-
self.descendants = {} # Maps node_id -> set of descendant node_ids
338-
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
339-
self.executed_nodes = set() # Tracks nodes that have been executed
340-
341-
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
342-
"""
343-
Clear the entire cache and rebuild the dependency graph.
344-
345-
Args:
346-
dynprompt: The dynamic prompt object containing node information.
347-
node_ids: List of node IDs to initialize the cache for.
348-
is_changed_cache: Flag indicating if the cache has changed.
349-
"""
350-
# Clear all existing cache data
351-
self.cache.clear()
352-
self.subcaches.clear()
353-
self.descendants.clear()
354-
self.ancestors.clear()
355-
self.executed_nodes.clear()
356-
357-
# Call the parent method to initialize the cache with the new prompt
358-
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
359-
360-
# Rebuild the dependency graph
361-
self._build_dependency_graph(dynprompt, node_ids)
362-
363-
def _build_dependency_graph(self, dynprompt, node_ids):
364-
"""
365-
Build the dependency graph for all nodes.
366-
367-
Args:
368-
dynprompt: The dynamic prompt object containing node information.
369-
node_ids: List of node IDs to build the graph for.
370-
"""
371-
self.descendants.clear()
372-
self.ancestors.clear()
373-
for node_id in node_ids:
374-
self.descendants[node_id] = set()
375-
self.ancestors[node_id] = set()
376-
377-
for node_id in node_ids:
378-
inputs = dynprompt.get_node(node_id)["inputs"]
379-
for input_data in inputs.values():
380-
if is_link(input_data): # Check if the input is a link to another node
381-
ancestor_id = input_data[0]
382-
self.descendants[ancestor_id].add(node_id)
383-
self.ancestors[node_id].add(ancestor_id)
384-
385-
def set(self, node_id, value):
386-
"""
387-
Mark a node as executed and store its value in the cache.
388-
389-
Args:
390-
node_id: The ID of the node to store.
391-
value: The value to store for the node.
392-
"""
393-
self._set_immediate(node_id, value)
394-
self.executed_nodes.add(node_id)
395-
self._cleanup_ancestors(node_id)
396-
397-
def get(self, node_id):
398-
"""
399-
Retrieve the cached value for a node.
400-
401-
Args:
402-
node_id: The ID of the node to retrieve.
403-
404-
Returns:
405-
The cached value for the node.
406-
"""
407-
return self._get_immediate(node_id)
408-
409-
async def ensure_subcache_for(self, node_id, children_ids):
410-
"""
411-
Ensure a subcache exists for a node and update dependencies.
412-
413-
Args:
414-
node_id: The ID of the parent node.
415-
children_ids: List of child node IDs to associate with the parent node.
416-
417-
Returns:
418-
The subcache object for the node.
419-
"""
420-
subcache = await super()._ensure_subcache(node_id, children_ids)
421-
for child_id in children_ids:
422-
self.descendants[node_id].add(child_id)
423-
self.ancestors[child_id].add(node_id)
424-
return subcache
425-
426-
def _cleanup_ancestors(self, node_id):
427-
"""
428-
Check if ancestors of a node can be removed from the cache.
429-
430-
Args:
431-
node_id: The ID of the node whose ancestors are to be checked.
432-
"""
433-
for ancestor_id in self.ancestors.get(node_id, []):
434-
if ancestor_id in self.executed_nodes:
435-
# Remove ancestor if all its descendants have been executed
436-
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
437-
self._remove_node(ancestor_id)
438-
439-
def _remove_node(self, node_id):
440-
"""
441-
Remove a node from the cache.
442-
443-
Args:
444-
node_id: The ID of the node to remove.
445-
"""
446-
cache_key = self.cache_key_set.get_data_key(node_id)
447-
if cache_key in self.cache:
448-
del self.cache[cache_key]
449-
subcache_key = self.cache_key_set.get_subcache_key(node_id)
450-
if subcache_key in self.subcaches:
451-
del self.subcaches[subcache_key]
452-
453-
def clean_unused(self):
454-
"""
455-
Clean up unused nodes. This is a no-op for this cache implementation.
456-
"""
457-
pass
458-
459-
def recursive_debug_dump(self):
460-
"""
461-
Dump the cache and dependency graph for debugging.
462-
463-
Returns:
464-
A list containing the cache state and dependency graph.
465-
"""
466-
result = super().recursive_debug_dump()
467-
result.append({
468-
"descendants": self.descendants,
469-
"ancestors": self.ancestors,
470-
"executed_nodes": list(self.executed_nodes),
471-
})
472-
return result

execution.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
BasicCache,
1919
CacheKeySetID,
2020
CacheKeySetInputSignature,
21-
DependencyAwareCache,
21+
NullCache,
2222
HierarchicalCache,
2323
LRUCache,
2424
)
@@ -91,13 +91,13 @@ async def get(self, node_id):
9191
class CacheType(Enum):
9292
CLASSIC = 0
9393
LRU = 1
94-
DEPENDENCY_AWARE = 2
94+
NONE = 2
9595

9696

9797
class CacheSet:
9898
def __init__(self, cache_type=None, cache_size=None):
99-
if cache_type == CacheType.DEPENDENCY_AWARE:
100-
self.init_dependency_aware_cache()
99+
if cache_type == CacheType.NONE:
100+
self.init_null_cache()
101101
logging.info("Disabling intermediate node cache.")
102102
elif cache_type == CacheType.LRU:
103103
if cache_size is None:
@@ -120,11 +120,10 @@ def init_lru_cache(self, cache_size):
120120
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
121121
self.objects = HierarchicalCache(CacheKeySetID)
122122

123-
# only hold cached items while the decendents have not executed
124-
def init_dependency_aware_cache(self):
125-
self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
126-
self.ui = DependencyAwareCache(CacheKeySetInputSignature)
127-
self.objects = DependencyAwareCache(CacheKeySetID)
123+
def init_null_cache(self):
124+
self.outputs = NullCache()
125+
self.ui = NullCache()
126+
self.objects = NullCache()
128127

129128
def recursive_debug_dump(self):
130129
result = {

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def prompt_worker(q, server_instance):
173173
if args.cache_lru > 0:
174174
cache_type = execution.CacheType.LRU
175175
elif args.cache_none:
176-
cache_type = execution.CacheType.DEPENDENCY_AWARE
176+
cache_type = execution.CacheType.NONE
177177

178178
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
179179
last_gc_collect = 0

0 commit comments

Comments
 (0)