Commit 9a822bc

Jazzhaiku/stats (#8006)
## Summary

- Modify stats reset to be on a per-session basis, rather than a "full reset", to allow for parallel session execution
- Add "aider" to gitignore

## Related Issues / Discussions

## QA Instructions

## Merge Plan

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2 parents 7722f47 + 5f12b91 commit 9a822bc
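
Below is a minimal sketch, not code from this commit, of the behavior described in the Summary: stats are keyed by graph execution state id, so resetting one finished session no longer clears the stats of sessions still running in parallel. The session ids and counters are invented for illustration.

```python
# Toy illustration of per-session reset semantics; ids and payloads are made up.
stats: dict[str, dict[str, int]] = {
    "session-a": {"nodes_executed": 12},  # just finished
    "session-b": {"nodes_executed": 3},   # still running in parallel
}

def reset_stats(graph_execution_state_id: str) -> None:
    # Previously: stats.clear() -- the "full reset" that also dropped session-b.
    # Now: drop only the finished session; an unknown id is a no-op.
    stats.pop(graph_execution_state_id, None)

reset_stats("session-a")
assert "session-a" not in stats
assert stats["session-b"]["nodes_executed"] == 3  # parallel session untouched
```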

File tree

5 files changed: +94 -6 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -188,3 +188,4 @@ installer/install.sh
 installer/update.bat
 installer/update.sh
 installer/InvokeAI-Installer/
+.aider*

invokeai/app/services/invocation_stats/invocation_stats_base.py

Lines changed: 1 addition & 1 deletion

@@ -60,7 +60,7 @@ def collect_stats(
         pass
 
     @abstractmethod
-    def reset_stats(self):
+    def reset_stats(self, graph_execution_state_id: str) -> None:
         """Reset all stored statistics."""
         pass

invokeai/app/services/invocation_stats/invocation_stats_default.py

Lines changed: 3 additions & 3 deletions

@@ -73,9 +73,9 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st
         )
         self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
 
-    def reset_stats(self):
-        self._stats = {}
-        self._cache_stats = {}
+    def reset_stats(self, graph_execution_state_id: str) -> None:
+        self._stats.pop(graph_execution_state_id, None)
+        self._cache_stats.pop(graph_execution_state_id, None)
 
     def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
         graph_stats_summary = self._get_graph_summary(graph_execution_state_id)

invokeai/app/services/session_processor/session_processor_default.py

Lines changed: 1 addition & 1 deletion

@@ -210,7 +210,7 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
         # we don't care about that - suppress the error.
         with suppress(GESStatsNotFoundError):
             self._services.performance_statistics.log_stats(queue_item.session.id)
-            self._services.performance_statistics.reset_stats()
+            self._services.performance_statistics.reset_stats(queue_item.session.id)
 
         for callback in self._on_after_run_session_callbacks:
             callback(queue_item=queue_item)

invokeai/backend/model_manager/load/model_cache/model_cache.py

Lines changed: 88 additions & 1 deletion

@@ -2,9 +2,10 @@
 import logging
 import threading
 import time
+from dataclasses import dataclass
 from functools import wraps
 from logging import Logger
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Protocol
 
 import psutil
 import torch

@@ -54,6 +55,39 @@ def wrapper(self, *args, **kwargs):
     return wrapper
 
 
+@dataclass
+class CacheEntrySnapshot:
+    cache_key: str
+    total_bytes: int
+    current_vram_bytes: int
+
+
+class CacheMissCallback(Protocol):
+    def __call__(
+        self,
+        model_key: str,
+        cache_snapshot: dict[str, CacheEntrySnapshot],
+    ) -> None: ...
+
+
+class CacheHitCallback(Protocol):
+    def __call__(
+        self,
+        model_key: str,
+        cache_snapshot: dict[str, CacheEntrySnapshot],
+    ) -> None: ...
+
+
+class CacheModelsClearedCallback(Protocol):
+    def __call__(
+        self,
+        models_cleared: int,
+        bytes_requested: int,
+        bytes_freed: int,
+        cache_snapshot: dict[str, CacheEntrySnapshot],
+    ) -> None: ...
+
+
 class ModelCache:
     """A cache for managing models in memory.

@@ -144,6 +178,34 @@ def __init__(
         # - Requests to empty the cache from a separate thread
         self._lock = threading.RLock()
 
+        self._on_cache_hit_callbacks: set[CacheHitCallback] = set()
+        self._on_cache_miss_callbacks: set[CacheMissCallback] = set()
+        self._on_cache_models_cleared_callbacks: set[CacheModelsClearedCallback] = set()
+
+    def on_cache_hit(self, cb: CacheHitCallback) -> Callable[[], None]:
+        self._on_cache_hit_callbacks.add(cb)
+
+        def unsubscribe() -> None:
+            self._on_cache_hit_callbacks.discard(cb)
+
+        return unsubscribe
+
+    def on_cache_miss(self, cb: CacheHitCallback) -> Callable[[], None]:
+        self._on_cache_miss_callbacks.add(cb)
+
+        def unsubscribe() -> None:
+            self._on_cache_miss_callbacks.discard(cb)
+
+        return unsubscribe
+
+    def on_cache_models_cleared(self, cb: CacheModelsClearedCallback) -> Callable[[], None]:
+        self._on_cache_models_cleared_callbacks.add(cb)
+
+        def unsubscribe() -> None:
+            self._on_cache_models_cleared_callbacks.discard(cb)
+
+        return unsubscribe
+
     @property
     @synchronized
     def stats(self) -> Optional[CacheStats]:

@@ -195,6 +257,20 @@ def put(self, key: str, model: AnyModel) -> None:
             f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size / MB:.2f}MB)"
         )
 
+    @synchronized
+    def _get_cache_snapshot(self) -> dict[str, CacheEntrySnapshot]:
+        overview: dict[str, CacheEntrySnapshot] = {}
+        for cache_key, cache_entry in self._cached_models.items():
+            total_bytes = cache_entry.cached_model.total_bytes()
+            current_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
+            overview[cache_key] = CacheEntrySnapshot(
+                cache_key=cache_key,
+                total_bytes=total_bytes,
+                current_vram_bytes=current_vram_bytes,
+            )
+
+        return overview
+
     @synchronized
     def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
         """Retrieve a model from the cache.

@@ -208,6 +284,8 @@ def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
             if self.stats:
                 self.stats.hits += 1
         else:
+            for cb in self._on_cache_miss_callbacks:
+                cb(model_key=key, cache_snapshot=self._get_cache_snapshot())
             if self.stats:
                 self.stats.misses += 1
             self._logger.debug(f"Cache miss: {key}")

@@ -229,6 +307,8 @@ def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
         self._cache_stack.append(key)
 
         self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
+        for cb in self._on_cache_hit_callbacks:
+            cb(model_key=key, cache_snapshot=self._get_cache_snapshot())
         return cache_entry
 
     @synchronized

@@ -649,6 +729,13 @@ def make_room(self, bytes_needed: int) -> None:
         # immediately when their reference count hits 0.
         if self.stats:
             self.stats.cleared = models_cleared
+        for cb in self._on_cache_models_cleared_callbacks:
+            cb(
+                models_cleared=models_cleared,
+                bytes_requested=bytes_needed,
+                bytes_freed=ram_bytes_freed,
+                cache_snapshot=self._get_cache_snapshot(),
+            )
         gc.collect()
 
         TorchDevice.empty_cache()
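
A hedged usage sketch for the new ModelCache hooks follows. It relies only on names introduced in this diff (on_cache_miss, on_cache_models_cleared, CacheEntrySnapshot, and the unsubscribe closures they return); how the ModelCache instance (`cache` below) is obtained depends on application wiring and is left as a placeholder.

```python
from invokeai.backend.model_manager.load.model_cache.model_cache import (
    CacheEntrySnapshot,
    ModelCache,
)

def log_miss(model_key: str, cache_snapshot: dict[str, CacheEntrySnapshot]) -> None:
    # Matches the CacheMissCallback protocol; the cache invokes it with keyword arguments.
    resident_vram = sum(entry.current_vram_bytes for entry in cache_snapshot.values())
    print(f"cache miss for {model_key}; {resident_vram} bytes of cached models in VRAM")

def log_cleared(models_cleared: int, bytes_requested: int, bytes_freed: int,
                cache_snapshot: dict[str, CacheEntrySnapshot]) -> None:
    # Matches the CacheModelsClearedCallback protocol, called from make_room().
    print(f"make_room dropped {models_cleared} models ({bytes_freed} of {bytes_requested} bytes)")

cache: ModelCache = ...  # placeholder: an existing cache owned by the model manager

unsubscribe_miss = cache.on_cache_miss(log_miss)
unsubscribe_cleared = cache.on_cache_models_cleared(log_cleared)
# ...later, stop observing:
unsubscribe_miss()
unsubscribe_cleared()
```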

0 commit comments