 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
 
+from ..optimizations import skip_torch_weight_init
 from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
 from .model_locker import ModelLocker
 
@@ -221,8 +222,12 @@ def put(
         size = calc_model_size_by_data(model)
         self.make_room(size)
 
-        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
-        cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
+        if isinstance(model, torch.nn.Module):
+            state_dict = model.state_dict()  # keep a master copy of the state dict
+            model = model.to(device="meta")  # and keep a template in the meta device
+        else:
+            state_dict = None
+        cache_record = CacheRecord(key=key, model=model, state_dict=state_dict, size=size)
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
 
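
Note on the `put()` hunk above: the cache now keeps the state dict as the master copy of the weights in RAM and stores the module itself only as a shape-only template on the meta device; `model_to_device()` below rebuilds a working copy from these two pieces. A minimal, hypothetical sketch of that split, using a toy `nn.Linear` rather than anything from InvokeAI:

    import torch

    net = torch.nn.Linear(4, 4)       # stands in for the model being cached
    state_dict = net.state_dict()     # master copy of the weights; these tensors stay on CPU
    template = net.to(device="meta")  # same module object, parameters become shape-only meta tensors

    assert template.weight.device.type == "meta"      # the template holds no weight data
    assert state_dict["weight"].device.type == "cpu"  # the real data lives only in the state dict
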
@@ -284,48 +289,20 @@ def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType]
         else:
             return model_key
 
-    def offload_unlocked_models(self, size_required: int) -> None:
-        """Move any unused models from VRAM."""
-        device = self.get_execution_device()
-        reserved = self._max_vram_cache_size * GIG
-        vram_in_use = torch.cuda.memory_allocated(device) + size_required
-        self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
-        for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
-            if vram_in_use <= reserved:
-                break
-            if not cache_entry.loaded:
-                continue
-            if cache_entry.device is not device:
-                continue
-            if not cache_entry.locked:
-                self.move_model_to_device(cache_entry, self.storage_device)
-                cache_entry.loaded = False
-                vram_in_use = torch.cuda.memory_allocated() + size_required
-                self.logger.debug(
-                    f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
-                )
-
-        TorchDevice.empty_cache()
-
-    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
-        """Move model into the indicated device.
+    def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
+        """Move a copy of the model into the indicated device and return it.
 
         :param cache_entry: The CacheRecord for the model
         :param target_device: The torch.device to move the model into
 
         May raise a torch.cuda.OutOfMemoryError
         """
-        self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
-        source_device = cache_entry.device
-
-        # Note: We compare device types only so that 'cuda' == 'cuda:0'.
-        # This would need to be revised to support multi-GPU.
-        if torch.device(source_device).type == torch.device(target_device).type:
-            return
+        self.logger.info(f"Called to move {cache_entry.key} to {target_device}")
 
-        # Some models don't have a `to` method, in which case they run in RAM/CPU.
-        if not hasattr(cache_entry.model, "to"):
-            return
+        # Some models don't have a state dictionary, in which case the
+        # stored model will still reside in CPU
+        if cache_entry.state_dict is None:
+            return cache_entry.model
 
         # This roundabout method for moving the model around is done to avoid
         # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
@@ -338,27 +315,25 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device
         # in RAM into the model. So this operation is very fast.
         start_model_to_time = time.time()
         snapshot_before = self._capture_memory_snapshot()
-
         try:
-            if cache_entry.state_dict is not None:
-                assert hasattr(cache_entry.model, "load_state_dict")
-                if target_device == self.storage_device:
-                    cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
+            assert isinstance(cache_entry.model, torch.nn.Module)
+            template = cache_entry.model
+            cls = template.__class__
+            with skip_torch_weight_init():
+                if hasattr(cls, "from_config"):
+                    working_model = template.__class__.from_config(template.config)  # diffusers style
                 else:
-                    new_dict: Dict[str, torch.Tensor] = {}
-                    for k, v in cache_entry.state_dict.items():
-                        new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
-                    cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device, non_blocking=True)
-            cache_entry.device = target_device
+                    working_model = template.__class__(config=template.config)  # transformers style (sigh)
+            working_model.to(device=target_device, dtype=self._precision)
+            working_model.load_state_dict(cache_entry.state_dict)
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)
             raise e
 
         snapshot_after = self._capture_memory_snapshot()
         end_model_to_time = time.time()
-        self.logger.debug(
-            f"Moved model '{cache_entry.key}' from {source_device} to"
+        self.logger.info(
+            f"Moved model '{cache_entry.key}' to"
             f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
             f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
             f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
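
Sketch of the rebuild step in the hunk above: a fresh instance of the template's class is constructed with weight initialization skipped, moved to the target device and dtype, then filled from the cached CPU state dict. A hypothetical stock-PyTorch analogue, using a toy `nn.Linear` and `torch.nn.utils.skip_init` as a stand-in for `skip_torch_weight_init()` (whose internals are not shown in this diff):

    import torch

    state_dict = torch.nn.Linear(8, 8).state_dict()  # stands in for the cached CPU master copy

    # Allocate the architecture without paying for random initialization,
    # then move it to the target device/dtype and copy the cached weights in.
    working = torch.nn.utils.skip_init(torch.nn.Linear, 8, 8)
    working = working.to(device="cuda" if torch.cuda.is_available() else "cpu", dtype=torch.float16)
    working.load_state_dict(state_dict)  # copy_() casts the float32 CPU weights into the new instance
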
@@ -380,34 +355,21 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device
                 abs_tol=10 * MB,
             ):
                 self.logger.debug(
-                    f"Moving model '{cache_entry.key}' from {source_device} to"
+                    f"Moving model '{cache_entry.key}' from to"
                     f" {target_device} caused an unexpected change in VRAM usage. The model's"
                     " estimated size may be incorrect. Estimated model size:"
                     f" {(cache_entry.size/GIG):.3f} GB.\n"
                     f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
                 )
+        return working_model
 
     def print_cuda_stats(self) -> None:
         """Log CUDA diagnostics."""
         vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
         ram = "%4.2fG" % (self.cache_size() / GIG)
 
-        in_ram_models = 0
-        in_vram_models = 0
-        locked_in_vram_models = 0
-        for cache_record in self._cached_models.values():
-            if hasattr(cache_record.model, "device"):
-                if cache_record.model.device == self.storage_device:
-                    in_ram_models += 1
-                else:
-                    in_vram_models += 1
-                if cache_record.locked:
-                    locked_in_vram_models += 1
-
-        self.logger.debug(
-            f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
-            f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
-        )
+        in_ram_models = len(self._cached_models)
+        self.logger.debug(f"Current VRAM/RAM usage for {in_ram_models} models: {vram}/{ram}")
 
     def make_room(self, size: int) -> None:
         """Make enough room in the cache to accommodate a new model of indicated size."""
@@ -433,29 +395,6 @@ def make_room(self, size: int) -> None:
 
             refs = sys.getrefcount(cache_entry.model)
 
-            # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
-            # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
-            # https://docs.python.org/3/library/gc.html#gc.get_referrers
-
-            # manualy clear local variable references of just finished function calls
-            # for some reason python don't want to collect it even by gc.collect() immidiately
-            if refs > 2:
-                while True:
-                    cleared = False
-                    for referrer in gc.get_referrers(cache_entry.model):
-                        if type(referrer).__name__ == "frame":
-                            # RuntimeError: cannot clear an executing frame
-                            with suppress(RuntimeError):
-                                referrer.clear()
-                                cleared = True
-                                # break
-
-                    # repeat if referrers changes(due to frame clear), else exit loop
-                    if cleared:
-                        gc.collect()
-                    else:
-                        break
-
             device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
             self.logger.debug(
                 f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"