"""

import gc
-import math
import sys
import threading
-import time
from contextlib import contextmanager, suppress
from logging import Logger
from threading import BoundedSemaphore
@@ -31,7 +29,7 @@
import torch

from invokeai.backend.model_manager import AnyModel, SubModelType
-from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger

@@ -42,11 +40,6 @@
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0

-# amount of GPU memory to hold in reserve for use by generations (GB)
-# Empirically this value seems to improve performance without starving other
-# processes.
-DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25
-
# actual size of a gig
GIG = 1073741824

@@ -60,12 +53,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
    def __init__(
        self,
        max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
-        max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
        storage_device: torch.device = torch.device("cpu"),
        execution_devices: Optional[Set[torch.device]] = None,
        precision: torch.dtype = torch.float16,
        sequential_offload: bool = False,
-        lazy_offloading: bool = True,
        sha_chunksize: int = 16777216,
        log_memory_usage: bool = False,
        logger: Optional[Logger] = None,
@@ -76,18 +67,14 @@ def __init__(
        :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
        :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
        :param precision: Precision for loaded models [torch.float16]
-        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
        :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
        :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
            operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
            snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
            behaviour.
        """
-        # allow lazy offloading only when vram cache enabled
-        self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
        self._precision: torch.dtype = precision
        self._max_cache_size: float = max_cache_size
-        self._max_vram_cache_size: float = max_vram_cache_size
        self._storage_device: torch.device = storage_device
        self._ram_lock = threading.Lock()
        self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
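
With `max_vram_cache_size` and `lazy_offloading` removed, constructing the cache now involves only the RAM-cache size and the device/precision knobs. A minimal usage sketch; the import path is an assumption based on the repository layout and is not shown in this diff:

import torch

# Import path assumed; only names visible in this diff are used below.
from invokeai.backend.model_manager.load.model_cache.model_cache_default import (
    DEFAULT_MAX_CACHE_SIZE,
    ModelCache,
)

cache = ModelCache(
    max_cache_size=DEFAULT_MAX_CACHE_SIZE,  # 6.0 GB RAM cache
    storage_device=torch.device("cpu"),     # where inactive models are kept
    precision=torch.float16,
    log_memory_usage=False,                 # snapshots are costly; leave off unless debugging the cache
)
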
@@ -111,11 +98,6 @@ def logger(self) -> Logger:
        """Return the logger used by the cache."""
        return self._logger

-    @property
-    def lazy_offloading(self) -> bool:
-        """Return true if the cache is configured to lazily offload models in VRAM."""
-        return self._lazy_offloading
-
    @property
    def storage_device(self) -> torch.device:
        """Return the storage device (e.g. "CPU" for RAM)."""
@@ -233,10 +215,9 @@ def put(
        if key in self._cached_models:
            return
        self.make_room(size)
-        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
-        cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
-        self._cached_models[key] = cache_record
-        self._cache_stack.append(key)
+        cache_record = CacheRecord(key, model=model, size=size)
+        self._cached_models[key] = cache_record
+        self._cache_stack.append(key)

    def get(
        self,
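
The new `CacheRecord(key, model=model, size=size)` call implies the record no longer carries a device or a cached state dict. For orientation, a hypothetical slimmed-down record built around the attributes this diff touches (`key`, `model`, `size`, `loaded`, `_locks`, `locked`); this is a sketch, not the project's actual class:

from dataclasses import dataclass
from typing import Any


@dataclass
class CacheRecord:
    """Hypothetical simplified record mirroring the attributes referenced in this diff."""

    key: str
    model: Any            # AnyModel in the real code
    size: int             # estimated size in bytes, consumed by make_room()
    loaded: bool = False  # True while the model sits on an execution device
    _locks: int = 0       # eviction guard; a locked record is never removed

    def lock(self) -> None:
        self._locks += 1

    def unlock(self) -> None:
        self._locks = max(self._locks - 1, 0)

    @property
    def locked(self) -> bool:
        return self._locks > 0


# Matches the constructor call introduced above.
record = CacheRecord("some-model-key", model=object(), size=2 * 1024**3)
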
@@ -296,107 +277,6 @@ def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType]
        else:
            return model_key

-    def offload_unlocked_models(self, size_required: int) -> None:
-        """Move any unused models from VRAM."""
-        reserved = self._max_vram_cache_size * GIG
-        vram_in_use = torch.cuda.memory_allocated() + size_required
-        self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
-        for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
-            if vram_in_use <= reserved:
-                break
-            if not cache_entry.loaded:
-                continue
-            if not cache_entry.locked:
-                self.move_model_to_device(cache_entry, self.storage_device)
-                cache_entry.loaded = False
-                vram_in_use = torch.cuda.memory_allocated() + size_required
-                self.logger.debug(
-                    f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
-                )
-
-        TorchDevice.empty_cache()
-
-    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
-        """Move model into the indicated device.
-
-        :param cache_entry: The CacheRecord for the model
-        :param target_device: The torch.device to move the model into
-
-        May raise a torch.cuda.OutOfMemoryError
-        """
-        # These attributes are not in the base ModelMixin class but in various derived classes.
-        # Some models don't have these attributes, in which case they run in RAM/CPU.
-        self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
-        if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
-            return
-
-        source_device = cache_entry.device
-
-        # Note: We compare device types only so that 'cuda' == 'cuda:0'.
-        # This would need to be revised to support multi-GPU.
-        if torch.device(source_device).type == torch.device(target_device).type:
-            return
-
-        # This roundabout method for moving the model around is done to avoid
-        # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
-        # When moving to VRAM, we copy (not move) each element of the state dict from
-        # RAM to a new state dict in VRAM, and then inject it into the model.
-        # This operation is slightly faster than running `to()` on the whole model.
-        #
-        # When the model needs to be removed from VRAM we simply delete the copy
-        # of the state dict in VRAM, and reinject the state dict that is cached
-        # in RAM into the model. So this operation is very fast.
-        start_model_to_time = time.time()
-        snapshot_before = self._capture_memory_snapshot()
-
-        try:
-            if cache_entry.state_dict is not None:
-                assert hasattr(cache_entry.model, "load_state_dict")
-                if target_device == self.storage_device:
-                    cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
-                else:
-                    new_dict: Dict[str, torch.Tensor] = {}
-                    for k, v in cache_entry.state_dict.items():
-                        new_dict[k] = v.to(torch.device(target_device), copy=True)
-                    cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device)
-            cache_entry.device = target_device
-        except Exception as e:  # blow away cache entry
-            self._delete_cache_entry(cache_entry)
-            raise e
-
-        snapshot_after = self._capture_memory_snapshot()
-        end_model_to_time = time.time()
-        self.logger.debug(
-            f"Moved model '{cache_entry.key}' from {source_device} to"
-            f" {target_device} in {(end_model_to_time - start_model_to_time):.2f}s."
-            f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
-            f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
-        )
-
-        if (
-            snapshot_before is not None
-            and snapshot_after is not None
-            and snapshot_before.vram is not None
-            and snapshot_after.vram is not None
-        ):
-            vram_change = abs(snapshot_before.vram - snapshot_after.vram)
-
-            # If the estimated model size does not match the change in VRAM, log a warning.
-            if not math.isclose(
-                vram_change,
-                cache_entry.size,
-                rel_tol=0.1,
-                abs_tol=10 * MB,
-            ):
-                self.logger.debug(
-                    f"Moving model '{cache_entry.key}' from {source_device} to"
-                    f" {target_device} caused an unexpected change in VRAM usage. The model's"
-                    " estimated size may be incorrect. Estimated model size:"
-                    f" {(cache_entry.size/GIG):.3f} GB.\n"
-                    f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
-                )
-
    def print_cuda_stats(self) -> None:
        """Log CUDA diagnostics."""
        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
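
The long comment in the removed `move_model_to_device` describes a copy-don't-move scheme: the CPU state dict stays cached for the model's whole lifetime, a temporary copy is materialized in VRAM with `Tensor.to(..., copy=True)` and injected via `load_state_dict(..., assign=True)`, and offloading simply re-injects the cached CPU tensors. A minimal standalone sketch of that idea, not InvokeAI's code:

import torch

# Stand-in for a cached model; the CPU state dict is kept cached and never mutated.
model = torch.nn.Linear(8, 8)
cpu_state_dict = model.state_dict()

if torch.cuda.is_available():
    # "Load": copy each tensor to VRAM and inject the copies, leaving the CPU dict untouched.
    cuda_state_dict = {k: v.to("cuda", copy=True) for k, v in cpu_state_dict.items()}
    model.load_state_dict(cuda_state_dict, assign=True)
    model.to("cuda")  # catches any buffers not covered by the state dict

    # "Offload": re-inject the cached CPU tensors; the VRAM copies become garbage immediately.
    model.load_state_dict(cpu_state_dict, assign=True)
    model.to("cpu")

Offloading this way needs no device-to-host copy, which is why the removed comment calls it "very fast".
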
@@ -440,12 +320,43 @@ def make_room(self, size: int) -> None:
        while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
            model_key = self._cache_stack[pos]
            cache_entry = self._cached_models[model_key]
+
+            refs = sys.getrefcount(cache_entry.model)
+
+            # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
+            # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
+            # https://docs.python.org/3/library/gc.html#gc.get_referrers
+
+            # Manually clear local variable references of just-finished function calls.
+            # For some reason, Python won't collect them otherwise, even with an immediate gc.collect().
+            if refs > 2:
+                while True:
+                    cleared = False
+                    for referrer in gc.get_referrers(cache_entry.model):
+                        if type(referrer).__name__ == "frame":
+                            # RuntimeError: cannot clear an executing frame
+                            with suppress(RuntimeError):
+                                referrer.clear()
+                                cleared = True
+                                # break
+
+                    # Repeat if the referrers changed (due to the frame clears), else exit the loop.
+                    if cleared:
+                        gc.collect()
+                    else:
+                        break
+
            device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
            self.logger.debug(
-                f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
+                f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
+                f" refs: {refs}"
            )

-            if not cache_entry.locked:
+            # Expected refs:
+            # 1 from cache_entry
+            # 1 from getrefcount function
+            # 1 from onnx runtime object
+            if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
                self.logger.debug(
                    f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f}GB (-{(cache_entry.size/GIG):.2f}GB)"
                )
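
The refcount check above assumes that, absent stray references, only the cache entry (plus `getrefcount`'s own argument, and one ONNX runtime object for ONNX models) holds the model. A standalone illustration of the failure mode the HACK targets (a lingering frame keeping a hidden reference) and the frame-clearing cure; this is illustrative only, not InvokeAI code:

import gc
import sys
from contextlib import suppress


class FakeModel:
    pass


cache = {"key": FakeModel()}


def run_inference(model: FakeModel) -> None:
    raise RuntimeError("simulated failure mid-generation")


saved_exc = None
try:
    run_inference(cache["key"])
except RuntimeError as exc:
    saved_exc = exc  # keeping the exception keeps its traceback, and therefore the call frames, alive

obj = cache["key"]
print(sys.getrefcount(obj))  # one higher than expected: run_inference's dead frame still holds `model`

# Frames show up directly among the referrers; clearing them drops the hidden local reference.
for referrer in gc.get_referrers(obj):
    if type(referrer).__name__ == "frame":
        with suppress(RuntimeError):  # "cannot clear an executing frame"
            referrer.clear()
gc.collect()
print(sys.getrefcount(obj))  # back to the baseline: the cache entry, `obj`, and getrefcount's argument

The cache loop wraps this in `while True` and re-runs `gc.collect()` because clearing one frame can expose or release further referrers.
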