23 changes: 11 additions & 12 deletions diffsynth_engine/pipelines/base.py
@@ -229,7 +229,7 @@ def eval(self):
                 model.eval()
         return self
 
-    def enable_cpu_offload(self, offload_mode: str | None, offload_to_disk:bool = False):
+    def enable_cpu_offload(self, offload_mode: str | None, offload_to_disk: bool = False):
         valid_offload_mode = ("cpu_offload", "sequential_cpu_offload", "disable", None)
         if offload_mode not in valid_offload_mode:
             raise ValueError(f"offload_mode must be one of {valid_offload_mode}, but got {offload_mode}")
@@ -244,8 +244,7 @@ def enable_cpu_offload(self, offload_mode: str | None, offload_to_disk:bool = False):
             self._enable_sequential_cpu_offload()
         self.offload_to_disk = offload_to_disk
 
-
-    def _enable_model_cpu_offload(self):
+    def _enable_model_cpu_offload(self):
         for model_name in self.model_names:
             model = getattr(self, model_name)
             if model is not None:
@@ -258,23 +257,22 @@ def _enable_sequential_cpu_offload(self):
             if model is not None:
                 enable_sequential_cpu_offload(model, self.device)
         self.offload_mode = "sequential_cpu_offload"
 
     def _disable_offload(self):
-        self.offload_mode = None
-        self._offload_param_dict = {}
+        self.offload_mode = None
+        self._offload_param_dict = {}
         for model_name in self.model_names:
             model = getattr(self, model_name)
             if model is not None:
                 model.to(self.device)
 
-
     def enable_fp8_autocast(
         self, model_names: List[str], compute_dtype: torch.dtype = torch.bfloat16, use_fp8_linear: bool = False
     ):
         for model_name in model_names:
             model = getattr(self, model_name)
             if model is not None:
-                model.to(device=self.device, dtype=torch.float8_e4m3fn)
+                model.to(dtype=torch.float8_e4m3fn)
                 enable_fp8_autocast(model, compute_dtype, use_fp8_linear)
         self.fp8_autocast_enabled = True

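Reviewer note: the behavioral fix in this hunk is that the fp8 cast is now dtype-only, so a model the offloader parked on CPU is no longer pulled onto self.device as a side effect of enable_fp8_autocast. A minimal sketch of the dtype-only cast (requires a PyTorch build with float8 support; the Linear module is a stand-in):

import torch

m = torch.nn.Linear(16, 16)       # parameters start on CPU in float32
m.to(dtype=torch.float8_e4m3fn)   # cast in place; device is left alone

p = next(m.parameters())
print(p.dtype)        # torch.float8_e4m3fn
print(p.device.type)  # still "cpu" -- an offloaded model stays offloaded
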
@@ -298,15 +296,17 @@ def load_models_to_device(self, load_model_names: List[str] | None = None):
         for model_name in load_model_names:
             model = getattr(self, model_name)
             if model is None:
-                raise ValueError(f"model {model_name} is not loaded, maybe this model has been destroyed by model_lifecycle_finish function with offload_to_disk=True")
+                raise ValueError(
+                    f"model {model_name} is not loaded, maybe this model has been destroyed by model_lifecycle_finish function with offload_to_disk=True"
+                )
             if model is not None and (p := next(model.parameters(), None)) is not None and p.device.type != self.device:
                 model.to(self.device)
         # fresh the cuda cache
         empty_cache()
 
     def model_lifecycle_finish(self, model_names: List[str] | None = None):
         if not self.offload_to_disk or self.offload_mode is None:
-           return
+            return
         for model_name in model_names:
             model = getattr(self, model_name)
             del model
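
Reviewer note: the walrus-guarded check in load_models_to_device is what keeps parameter-less modules from crashing the device comparison. A self-contained sketch of that guard (the device string and the two modules are invented for illustration):

import torch

device = "cpu"  # stand-in for self.device

for model in (torch.nn.Linear(4, 4), torch.nn.Identity()):
    # next(..., None) yields the first parameter, or None for modules
    # like Identity that have none; the walrus keeps it in one expression
    if (p := next(model.parameters(), None)) is not None and p.device.type != device:
        model.to(device)
        print(f"moved {type(model).__name__}")
    else:
        print(f"left {type(model).__name__} in place")
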
@@ -316,7 +316,6 @@ def model_lifecycle_finish(self, model_names: List[str] | None = None):
             print(f"model {model_name} has been deleted from memory")
             logger.info(f"model {model_name} has been deleted from memory")
         empty_cache()
-
-
+
     def compile(self):
         raise NotImplementedError(f"{self.__class__.__name__} does not support compile")
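
Reviewer note: taken together, the offload_to_disk lifecycle amounts to dropping every live reference to a model and then reclaiming memory. The collapsed lines of model_lifecycle_finish presumably also reset the attribute to None and run garbage collection, since load_models_to_device raises a ValueError for such a model afterwards; the sketch below assumes that, with a plain object standing in for the pipeline.

import gc
import torch

class Holder:  # stand-in for the pipeline object
    pass

holder = Holder()
holder.dit = torch.nn.Linear(1024, 1024)

# "finish" the model: drop both references, then reclaim memory
model = getattr(holder, "dit")
del model                      # local reference
setattr(holder, "dit", None)   # pipeline attribute (assumed behavior)
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()   # assumption: empty_cache() wraps this on CUDA

# a later load_models_to_device(["dit"]) would now hit the ValueError path,
# because getattr(holder, "dit") is None
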