@@ -265,26 +265,6 @@ async def ensure_subcache_for(self, node_id, children_ids):
265265 assert cache is not None
266266 return await cache ._ensure_subcache (node_id , children_ids )
267267
268- class NullCache :
269-
270- async def set_prompt (self , dynprompt , node_ids , is_changed_cache ):
271- pass
272-
273- def all_node_ids (self ):
274- return []
275-
276- def clean_unused (self ):
277- pass
278-
279- def get (self , node_id ):
280- return None
281-
282- def set (self , node_id , value ):
283- pass
284-
285- async def ensure_subcache_for (self , node_id , children_ids ):
286- return self
287-
288268class LRUCache (BasicCache ):
289269 def __init__ (self , key_class , max_size = 100 ):
290270 super ().__init__ (key_class )
@@ -336,3 +316,157 @@ async def ensure_subcache_for(self, node_id, children_ids):
336316 self ._mark_used (child_id )
337317 self .children [cache_key ].append (self .cache_key_set .get_data_key (child_id ))
338318 return self
319+
320+
321+ class DependencyAwareCache (BasicCache ):
322+ """
323+ A cache implementation that tracks dependencies between nodes and manages
324+ their execution and caching accordingly. It extends the BasicCache class.
325+ Nodes are removed from this cache once all of their descendants have been
326+ executed.
327+ """
328+
329+ def __init__ (self , key_class ):
330+ """
331+ Initialize the DependencyAwareCache.
332+
333+ Args:
334+ key_class: The class used for generating cache keys.
335+ """
336+ super ().__init__ (key_class )
337+ self .descendants = {} # Maps node_id -> set of descendant node_ids
338+ self .ancestors = {} # Maps node_id -> set of ancestor node_ids
339+ self .executed_nodes = set () # Tracks nodes that have been executed
340+
341+ async def set_prompt (self , dynprompt , node_ids , is_changed_cache ):
342+ """
343+ Clear the entire cache and rebuild the dependency graph.
344+
345+ Args:
346+ dynprompt: The dynamic prompt object containing node information.
347+ node_ids: List of node IDs to initialize the cache for.
348+ is_changed_cache: Flag indicating if the cache has changed.
349+ """
350+ # Clear all existing cache data
351+ self .cache .clear ()
352+ self .subcaches .clear ()
353+ self .descendants .clear ()
354+ self .ancestors .clear ()
355+ self .executed_nodes .clear ()
356+
357+ # Call the parent method to initialize the cache with the new prompt
358+ await super ().set_prompt (dynprompt , node_ids , is_changed_cache )
359+
360+ # Rebuild the dependency graph
361+ self ._build_dependency_graph (dynprompt , node_ids )
362+
363+ def _build_dependency_graph (self , dynprompt , node_ids ):
364+ """
365+ Build the dependency graph for all nodes.
366+
367+ Args:
368+ dynprompt: The dynamic prompt object containing node information.
369+ node_ids: List of node IDs to build the graph for.
370+ """
371+ self .descendants .clear ()
372+ self .ancestors .clear ()
373+ for node_id in node_ids :
374+ self .descendants [node_id ] = set ()
375+ self .ancestors [node_id ] = set ()
376+
377+ for node_id in node_ids :
378+ inputs = dynprompt .get_node (node_id )["inputs" ]
379+ for input_data in inputs .values ():
380+ if is_link (input_data ): # Check if the input is a link to another node
381+ ancestor_id = input_data [0 ]
382+ self .descendants [ancestor_id ].add (node_id )
383+ self .ancestors [node_id ].add (ancestor_id )
384+
385+ def set (self , node_id , value ):
386+ """
387+ Mark a node as executed and store its value in the cache.
388+
389+ Args:
390+ node_id: The ID of the node to store.
391+ value: The value to store for the node.
392+ """
393+ self ._set_immediate (node_id , value )
394+ self .executed_nodes .add (node_id )
395+ self ._cleanup_ancestors (node_id )
396+
397+ def get (self , node_id ):
398+ """
399+ Retrieve the cached value for a node.
400+
401+ Args:
402+ node_id: The ID of the node to retrieve.
403+
404+ Returns:
405+ The cached value for the node.
406+ """
407+ return self ._get_immediate (node_id )
408+
409+ async def ensure_subcache_for (self , node_id , children_ids ):
410+ """
411+ Ensure a subcache exists for a node and update dependencies.
412+
413+ Args:
414+ node_id: The ID of the parent node.
415+ children_ids: List of child node IDs to associate with the parent node.
416+
417+ Returns:
418+ The subcache object for the node.
419+ """
420+ subcache = await super ()._ensure_subcache (node_id , children_ids )
421+ for child_id in children_ids :
422+ self .descendants [node_id ].add (child_id )
423+ self .ancestors [child_id ].add (node_id )
424+ return subcache
425+
426+ def _cleanup_ancestors (self , node_id ):
427+ """
428+ Check if ancestors of a node can be removed from the cache.
429+
430+ Args:
431+ node_id: The ID of the node whose ancestors are to be checked.
432+ """
433+ for ancestor_id in self .ancestors .get (node_id , []):
434+ if ancestor_id in self .executed_nodes :
435+ # Remove ancestor if all its descendants have been executed
436+ if all (descendant in self .executed_nodes for descendant in self .descendants [ancestor_id ]):
437+ self ._remove_node (ancestor_id )
438+
439+ def _remove_node (self , node_id ):
440+ """
441+ Remove a node from the cache.
442+
443+ Args:
444+ node_id: The ID of the node to remove.
445+ """
446+ cache_key = self .cache_key_set .get_data_key (node_id )
447+ if cache_key in self .cache :
448+ del self .cache [cache_key ]
449+ subcache_key = self .cache_key_set .get_subcache_key (node_id )
450+ if subcache_key in self .subcaches :
451+ del self .subcaches [subcache_key ]
452+
453+ def clean_unused (self ):
454+ """
455+ Clean up unused nodes. This is a no-op for this cache implementation.
456+ """
457+ pass
458+
459+ def recursive_debug_dump (self ):
460+ """
461+ Dump the cache and dependency graph for debugging.
462+
463+ Returns:
464+ A list containing the cache state and dependency graph.
465+ """
466+ result = super ().recursive_debug_dump ()
467+ result .append ({
468+ "descendants" : self .descendants ,
469+ "ancestors" : self .ancestors ,
470+ "executed_nodes" : list (self .executed_nodes ),
471+ })
472+ return result
0 commit comments