Skip to content

Commit fd3a489

Browse files
add offload disk
1 parent 4e09363 commit fd3a489

File tree

3 files changed

+49
-5
lines changed

3 files changed

+49
-5
lines changed

diffsynth_engine/configs/pipeline.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class BaseConfig:
1616
vae_tile_stride: int | Tuple[int, int] = 256
1717
device: str = "cuda"
1818
offload_mode: Optional[str] = None
19+
offload_to_disk: bool = False
1920

2021

2122
@dataclass
@@ -62,11 +63,13 @@ def basic_config(
6263
model_path: str | os.PathLike | List[str | os.PathLike],
6364
device: str = "cuda",
6465
offload_mode: Optional[str] = None,
66+
offload_to_disk: bool = False,
6567
) -> "SDPipelineConfig":
6668
return cls(
6769
model_path=model_path,
6870
device=device,
6971
offload_mode=offload_mode,
72+
offload_to_disk=offload_to_disk,
7073
)
7174

7275

@@ -87,11 +90,13 @@ def basic_config(
8790
model_path: str | os.PathLike | List[str | os.PathLike],
8891
device: str = "cuda",
8992
offload_mode: Optional[str] = None,
93+
offload_to_disk: bool = False,
9094
) -> "SDXLPipelineConfig":
9195
return cls(
9296
model_path=model_path,
9397
device=device,
9498
offload_mode=offload_mode,
99+
offload_to_disk=offload_to_disk,
95100
)
96101

97102

@@ -116,13 +121,15 @@ def basic_config(
116121
device: str = "cuda",
117122
parallelism: int = 1,
118123
offload_mode: Optional[str] = None,
124+
offload_to_disk: bool = False,
119125
) -> "FluxPipelineConfig":
120126
return cls(
121127
model_path=model_path,
122128
device=device,
123129
parallelism=parallelism,
124130
use_fsdp=True,
125131
offload_mode=offload_mode,
132+
offload_to_disk=offload_to_disk,
126133
)
127134

128135
def __post_init__(self):
@@ -160,6 +167,7 @@ def basic_config(
160167
device: str = "cuda",
161168
parallelism: int = 1,
162169
offload_mode: Optional[str] = None,
170+
offload_to_disk: bool = False,
163171
) -> "WanPipelineConfig":
164172
return cls(
165173
model_path=model_path,
@@ -169,6 +177,7 @@ def basic_config(
169177
use_cfg_parallel=True,
170178
use_fsdp=True,
171179
offload_mode=offload_mode,
180+
offload_to_disk=offload_to_disk,
172181
)
173182

174183
def __post_init__(self):
@@ -196,6 +205,7 @@ def basic_config(
196205
device: str = "cuda",
197206
parallelism: int = 1,
198207
offload_mode: Optional[str] = None,
208+
offload_to_disk: bool = False,
199209
) -> "QwenImagePipelineConfig":
200210
return cls(
201211
model_path=model_path,
@@ -206,6 +216,7 @@ def basic_config(
206216
use_cfg_parallel=True,
207217
use_fsdp=True,
208218
offload_mode=offload_mode,
219+
offload_to_disk=offload_to_disk,
209220
)
210221

211222
def __post_init__(self):

diffsynth_engine/pipelines/base.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(
4141
self.offload_mode = None
4242
self.model_names = []
4343
self._offload_param_dict = {}
44+
self.offload_to_disk = False
4445

4546
@classmethod
4647
def from_pretrained(cls, model_path_or_config: str | BaseConfig) -> "BasePipeline":
@@ -228,19 +229,23 @@ def eval(self):
228229
model.eval()
229230
return self
230231

231-
def enable_cpu_offload(self, offload_mode: str):
232-
valid_offload_mode = ("cpu_offload", "sequential_cpu_offload")
232+
def enable_cpu_offload(self, offload_mode: str | None, offload_to_disk:bool = False):
233+
valid_offload_mode = ("cpu_offload", "sequential_cpu_offload", "disable", None)
233234
if offload_mode not in valid_offload_mode:
234235
raise ValueError(f"offload_mode must be one of {valid_offload_mode}, but got {offload_mode}")
235236
if self.device == "cpu" or self.device == "mps":
236237
logger.warning("must set an non cpu device for pipeline before calling enable_cpu_offload")
237238
return
238-
if offload_mode == "cpu_offload":
239+
if offload_mode is None or offload_mode == "disable":
240+
self._disable_offload()
241+
elif offload_mode == "cpu_offload":
239242
self._enable_model_cpu_offload()
240243
elif offload_mode == "sequential_cpu_offload":
241244
self._enable_sequential_cpu_offload()
245+
self.offload_to_disk = offload_to_disk
242246

243-
def _enable_model_cpu_offload(self):
247+
248+
def _enable_model_cpu_offload(self):
244249
for model_name in self.model_names:
245250
model = getattr(self, model_name)
246251
if model is not None:
@@ -253,6 +258,15 @@ def _enable_sequential_cpu_offload(self):
253258
if model is not None:
254259
enable_sequential_cpu_offload(model, self.device)
255260
self.offload_mode = "sequential_cpu_offload"
261+
262+
def _disable_offload(self):
263+
self.offload_mode = None
264+
self._offload_param_dict = {}
265+
for model_name in self.model_names:
266+
model = getattr(self, model_name)
267+
if model is not None:
268+
model.to(self.device)
269+
256270

257271
def enable_fp8_autocast(
258272
self, model_names: List[str], compute_dtype: torch.dtype = torch.bfloat16, use_fp8_linear: bool = False
@@ -282,10 +296,26 @@ def load_models_to_device(self, load_model_names: List[str] | None = None):
282296
# load the needed models to device
283297
for model_name in load_model_names:
284298
model = getattr(self, model_name)
299+
if model is None:
300+
raise ValueError(f"model {model_name} is not loaded, maybe this model has been destroyed by model_lifecycle_finish function with offload_to_disk=True")
285301
if model is not None and (p := next(model.parameters(), None)) is not None and p.device.type != self.device:
286302
model.to(self.device)
287303
# fresh the cuda cache
288304
empty_cache()
289305

306+
def model_lifecycle_finish(self, model_names: List[str] | None = None):
307+
if not self.offload_to_disk or self.offload_mode is None:
308+
return
309+
for model_name in model_names:
310+
model = getattr(self, model_name)
311+
del model
312+
if model_name in self._offload_param_dict:
313+
del self._offload_param_dict[model_name]
314+
setattr(self, model_name, None)
315+
print(f"model {model_name} has been deleted from memory")
316+
logger.info(f"model {model_name} has been deleted from memory")
317+
empty_cache()
318+
319+
290320
def compile(self):
291321
raise NotImplementedError(f"{self.__class__.__name__} does not support compile")

diffsynth_engine/pipelines/qwen_image.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePipe
208208
pipe.eval()
209209

210210
if config.offload_mode is not None:
211-
pipe.enable_cpu_offload(config.offload_mode)
211+
pipe.enable_cpu_offload(config.offload_mode, config.offload_to_disk)
212212

213213
if config.parallelism > 1:
214214
pipe = ParallelWrapper(
@@ -393,6 +393,7 @@ def __call__(
393393
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(negative_prompt, 1, 4096)
394394
else:
395395
negative_prompt_embeds, negative_prompt_embeds_mask = None, None
396+
self.model_lifecycle_finish(["encoder"])
396397

397398
hide_progress = dist.is_initialized() and dist.get_rank() != 0
398399
for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
@@ -412,6 +413,7 @@ def __call__(
412413
# UI
413414
if progress_callback is not None:
414415
progress_callback(i, len(timesteps), "DENOISING")
416+
self.model_lifecycle_finish(["dit"])
415417
# Decode image
416418
self.load_models_to_device(["vae"])
417419
latents = rearrange(latents, "B C H W -> B C 1 H W")
@@ -423,5 +425,6 @@ def __call__(
423425
)
424426
image = self.vae_output_to_image(vae_output)
425427
# Offload all models
428+
self.model_lifecycle_finish(["vae"])
426429
self.load_models_to_device([])
427430
return image

0 commit comments

Comments
 (0)