Commit c5eec1d

Author: Lincoln Stein
Commit message: resolve conflicts with cherry-pick
1 parent 6932f27, commit c5eec1d

File tree: 4 files changed, +40 / -122 lines

invokeai/backend/model_manager/load/model_cache/model_cache_base.py

Lines changed: 1 addition & 20 deletions

@@ -52,11 +52,10 @@ class CacheRecord(Generic[T]):
     Elements of the cache:
 
         key: Unique key for each model, same as used in the models database.
-        model: Model in memory.
+        model: Read-only copy of the model *without weights* residing in the "meta device"
         state_dict: A read-only copy of the model's state dict in RAM. It will be
                     used as a template for creating a copy in the VRAM.
        size: Size of the model
-        loaded: True if the model's state dict is currently in VRAM
 
     Before a model is executed, the state_dict template is copied into VRAM,
     and then injected into the model. When the model is finished, the VRAM
@@ -72,25 +71,7 @@ class CacheRecord(Generic[T]):
     key: str
     size: int
     model: T
-    device: torch.device
     state_dict: Optional[Dict[str, torch.Tensor]]
-    size: int
-    loaded: bool = False
-    _locks: int = 0
-
-    def lock(self) -> None:
-        """Lock this record."""
-        self._locks += 1
-
-    def unlock(self) -> None:
-        """Unlock this record."""
-        self._locks -= 1
-        assert self._locks >= 0
-
-    @property
-    def locked(self) -> bool:
-        """Return true if record is locked."""
-        return self._locks > 0
 
 
 @dataclass
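Editor's note: the reworked CacheRecord keeps the module itself as a weightless template on the "meta" device and holds the master copy of the weights as a state dict in RAM. A minimal stand-alone sketch of that split, using plain torch and a toy nn.Linear rather than InvokeAI code:

import torch
from torch import nn

# Toy module standing in for a cached model. The cache keeps two pieces:
# a RAM copy of the state dict (the weights) and a weightless "meta" template
# that records only the module's structure.
model = nn.Linear(4, 4)
state_dict = {k: v.clone() for k, v in model.state_dict().items()}  # master weights stay in RAM
template = model.to(device="meta")                                  # structure only, no weight storage

assert all(p.is_meta for p in template.parameters())       # template holds no real data
assert all(not v.is_meta for v in state_dict.values())     # RAM copy still has the weights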

invokeai/backend/model_manager/load/model_cache/model_cache_default.py

Lines changed: 29 additions & 90 deletions
@@ -36,6 +36,7 @@
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
 
+from ..optimizations import skip_torch_weight_init
 from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
 from .model_locker import ModelLocker
 
@@ -221,8 +222,12 @@ def put(
         size = calc_model_size_by_data(model)
         self.make_room(size)
 
-        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
-        cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
+        if isinstance(model, torch.nn.Module):
+            state_dict = model.state_dict()  # keep a master copy of the state dict
+            model = model.to(device="meta")  # and keep a template in the meta device
+        else:
+            state_dict = None
+        cache_record = CacheRecord(key=key, model=model, state_dict=state_dict, size=size)
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
 
@@ -284,48 +289,20 @@ def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType]
         else:
             return model_key
 
-    def offload_unlocked_models(self, size_required: int) -> None:
-        """Move any unused models from VRAM."""
-        device = self.get_execution_device()
-        reserved = self._max_vram_cache_size * GIG
-        vram_in_use = torch.cuda.memory_allocated(device) + size_required
-        self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
-        for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
-            if vram_in_use <= reserved:
-                break
-            if not cache_entry.loaded:
-                continue
-            if cache_entry.device is not device:
-                continue
-            if not cache_entry.locked:
-                self.move_model_to_device(cache_entry, self.storage_device)
-                cache_entry.loaded = False
-                vram_in_use = torch.cuda.memory_allocated() + size_required
-                self.logger.debug(
-                    f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
-                )
-
-        TorchDevice.empty_cache()
-
-    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
-        """Move model into the indicated device.
+    def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
+        """Move a copy of the model into the indicated device and return it.
 
         :param cache_entry: The CacheRecord for the model
         :param target_device: The torch.device to move the model into
 
         May raise a torch.cuda.OutOfMemoryError
         """
-        self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
-        source_device = cache_entry.device
-
-        # Note: We compare device types only so that 'cuda' == 'cuda:0'.
-        # This would need to be revised to support multi-GPU.
-        if torch.device(source_device).type == torch.device(target_device).type:
-            return
+        self.logger.info(f"Called to move {cache_entry.key} to {target_device}")
 
-        # Some models don't have a `to` method, in which case they run in RAM/CPU.
-        if not hasattr(cache_entry.model, "to"):
-            return
+        # Some models don't have a state dictionary, in which case the
+        # stored model will still reside in CPU
+        if cache_entry.state_dict is None:
+            return cache_entry.model
 
         # This roundabout method for moving the model around is done to avoid
         # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
@@ -338,27 +315,25 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device
         # in RAM into the model. So this operation is very fast.
         start_model_to_time = time.time()
         snapshot_before = self._capture_memory_snapshot()
-
         try:
-            if cache_entry.state_dict is not None:
-                assert hasattr(cache_entry.model, "load_state_dict")
-                if target_device == self.storage_device:
-                    cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
+            assert isinstance(cache_entry.model, torch.nn.Module)
+            template = cache_entry.model
+            cls = template.__class__
+            with skip_torch_weight_init():
+                if hasattr(cls, "from_config"):
+                    working_model = template.__class__.from_config(template.config)  # diffusers style
                 else:
-                    new_dict: Dict[str, torch.Tensor] = {}
-                    for k, v in cache_entry.state_dict.items():
-                        new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
-                    cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device, non_blocking=True)
-            cache_entry.device = target_device
+                    working_model = template.__class__(config=template.config)  # transformers style (sigh)
+            working_model.to(device=target_device, dtype=self._precision)
+            working_model.load_state_dict(cache_entry.state_dict)
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)
             raise e
 
         snapshot_after = self._capture_memory_snapshot()
         end_model_to_time = time.time()
-        self.logger.debug(
-            f"Moved model '{cache_entry.key}' from {source_device} to"
+        self.logger.info(
+            f"Moved model '{cache_entry.key}' to"
             f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
             f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
             f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
@@ -380,34 +355,21 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device
             abs_tol=10 * MB,
         ):
             self.logger.debug(
-                f"Moving model '{cache_entry.key}' from {source_device} to"
+                f"Moving model '{cache_entry.key}' from to"
                 f" {target_device} caused an unexpected change in VRAM usage. The model's"
                 " estimated size may be incorrect. Estimated model size:"
                 f" {(cache_entry.size/GIG):.3f} GB.\n"
                 f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
             )
+        return working_model
 
     def print_cuda_stats(self) -> None:
         """Log CUDA diagnostics."""
         vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
         ram = "%4.2fG" % (self.cache_size() / GIG)
 
-        in_ram_models = 0
-        in_vram_models = 0
-        locked_in_vram_models = 0
-        for cache_record in self._cached_models.values():
-            if hasattr(cache_record.model, "device"):
-                if cache_record.model.device == self.storage_device:
-                    in_ram_models += 1
-                else:
-                    in_vram_models += 1
-                if cache_record.locked:
-                    locked_in_vram_models += 1
-
-        self.logger.debug(
-            f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
-            f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
-        )
+        in_ram_models = len(self._cached_models)
+        self.logger.debug(f"Current VRAM/RAM usage for {in_ram_models} models: {vram}/{ram}")
 
     def make_room(self, size: int) -> None:
         """Make enough room in the cache to accommodate a new model of indicated size."""
@@ -433,29 +395,6 @@ def make_room(self, size: int) -> None:
 
             refs = sys.getrefcount(cache_entry.model)
 
-            # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
-            # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
-            # https://docs.python.org/3/library/gc.html#gc.get_referrers
-
-            # manualy clear local variable references of just finished function calls
-            # for some reason python don't want to collect it even by gc.collect() immidiately
-            if refs > 2:
-                while True:
-                    cleared = False
-                    for referrer in gc.get_referrers(cache_entry.model):
-                        if type(referrer).__name__ == "frame":
-                            # RuntimeError: cannot clear an executing frame
-                            with suppress(RuntimeError):
-                                referrer.clear()
-                                cleared = True
-                                # break
-
-                    # repeat if referrers changes(due to frame clear), else exit loop
-                    if cleared:
-                        gc.collect()
-                    else:
-                        break
-
             device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
             self.logger.debug(
                 f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"

invokeai/backend/model_manager/load/model_cache/model_locker.py

Lines changed: 6 additions & 9 deletions
@@ -37,25 +37,22 @@ def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
 
     def lock(self) -> AnyModel:
         """Move the model into the execution device (GPU) and lock it."""
-        self._cache_entry.lock()
         try:
             device = self._cache.get_execution_device()
-            self._cache.offload_unlocked_models(self._cache_entry.size)
-            self._cache.move_model_to_device(self._cache_entry, device)
-            self._cache_entry.loaded = True
-            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {device}")
+            model_on_device = self._cache.model_to_device(self._cache_entry, device)
+            self._cache.logger.debug(f"Moved {self._cache_entry.key} to {device}")
             self._cache.print_cuda_stats()
         except torch.cuda.OutOfMemoryError:
             self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
-            self._cache_entry.unlock()
             raise
         except Exception:
-            self._cache_entry.unlock()
             raise
 
-        return self.model
+        return model_on_device
 
+    # It is no longer necessary to move the model out of VRAM
+    # because it will be removed when it goes out of scope
+    # in the caller's context
     def unlock(self) -> None:
         """Call upon exit from context."""
-        self._cache_entry.unlock()
         self._cache.print_cuda_stats()
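Editor's note: with this change lock() hands back a per-use copy of the model on the execution device, and unlock() no longer offloads anything because the copy is released when it goes out of scope. A hypothetical caller-side pattern; run_inference, locker, and inputs are illustrative names, not InvokeAI API:

def run_inference(locker, inputs):
    model_on_device = locker.lock()   # fresh working copy on the execution device
    try:
        return model_on_device(inputs)
    finally:
        locker.unlock()               # now only logs stats; the copy is freed when it leaves scope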

invokeai/backend/model_patcher.py

Lines changed: 4 additions & 3 deletions
@@ -129,9 +129,7 @@ def apply_lora(
                        dtype = module.weight.dtype
 
                        if module_key not in original_weights:
-                            if model_state_dict is not None:  # we were provided with the CPU copy of the state dict
-                                original_weights[module_key] = model_state_dict[module_key + ".weight"]
-                            else:
+                            if model_state_dict is None:  # no CPU copy of the state dict was provided
                                original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
 
                        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
@@ -158,6 +156,9 @@
            yield  # wait for context manager exit
 
        finally:
+            # LS check: for now, we are not reusing models in VRAM but re-copying them each time they are needed.
+            # Therefore it should not be necessary to copy the original model weights back.
+            # This needs to be fixed before resurrecting the VRAM cache.
            assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
            with torch.no_grad():
                for module_key, weight in original_weights.items():
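Editor's note: the finally-block comment above explains why restoring the original weights becomes optional here: LoRA patches are applied in place to a throwaway on-device copy while the master state dict in RAM is never touched, so the next lock() starts from clean weights. A small plain-torch illustration of that reasoning, not InvokeAI code:

import torch
from torch import nn

master = nn.Linear(8, 8)
master_sd = {k: v.clone() for k, v in master.state_dict().items()}  # RAM master copy

working = nn.Linear(8, 8)
working.load_state_dict(master_sd)                 # fresh working copy, as lock() would produce

with torch.no_grad():
    working.weight += 0.1 * torch.randn_like(working.weight)   # in-place "LoRA-style" patch on the copy

assert not torch.equal(working.weight, master_sd["weight"])    # the copy diverged
assert torch.equal(master.weight, master_sd["weight"])         # the master weights are untouched
del working                                                    # dropping the copy is the whole "restore"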
