
Commit 898ae75

tenderness-gitakaitsuki-ii authored and committed
fix cast fp8
1 parent 4e09363 · commit 898ae75

File tree

3 files changed: +60 -4 lines changed

diffsynth_engine/configs/pipeline.py

Lines changed: 11 additions & 0 deletions
@@ -16,6 +16,7 @@ class BaseConfig:
     vae_tile_stride: int | Tuple[int, int] = 256
     device: str = "cuda"
     offload_mode: Optional[str] = None
+    offload_to_disk: bool = False


 @dataclass
@@ -62,11 +63,13 @@ def basic_config(
         model_path: str | os.PathLike | List[str | os.PathLike],
         device: str = "cuda",
         offload_mode: Optional[str] = None,
+        offload_to_disk: bool = False,
     ) -> "SDPipelineConfig":
         return cls(
             model_path=model_path,
             device=device,
             offload_mode=offload_mode,
+            offload_to_disk=offload_to_disk,
         )


@@ -87,11 +90,13 @@ def basic_config(
         model_path: str | os.PathLike | List[str | os.PathLike],
         device: str = "cuda",
         offload_mode: Optional[str] = None,
+        offload_to_disk: bool = False,
     ) -> "SDXLPipelineConfig":
         return cls(
             model_path=model_path,
             device=device,
             offload_mode=offload_mode,
+            offload_to_disk=offload_to_disk,
         )


@@ -116,13 +121,15 @@ def basic_config(
         device: str = "cuda",
         parallelism: int = 1,
         offload_mode: Optional[str] = None,
+        offload_to_disk: bool = False,
     ) -> "FluxPipelineConfig":
         return cls(
             model_path=model_path,
             device=device,
             parallelism=parallelism,
             use_fsdp=True,
             offload_mode=offload_mode,
+            offload_to_disk=offload_to_disk,
         )

     def __post_init__(self):
@@ -160,6 +167,7 @@ def basic_config(
         device: str = "cuda",
         parallelism: int = 1,
         offload_mode: Optional[str] = None,
+        offload_to_disk: bool = False,
     ) -> "WanPipelineConfig":
         return cls(
             model_path=model_path,
@@ -169,6 +177,7 @@ def basic_config(
             use_cfg_parallel=True,
             use_fsdp=True,
             offload_mode=offload_mode,
+            offload_to_disk=offload_to_disk,
         )

     def __post_init__(self):
@@ -196,6 +205,7 @@ def basic_config(
         device: str = "cuda",
         parallelism: int = 1,
         offload_mode: Optional[str] = None,
+        offload_to_disk: bool = False,
     ) -> "QwenImagePipelineConfig":
         return cls(
             model_path=model_path,
@@ -206,6 +216,7 @@ def basic_config(
             use_cfg_parallel=True,
             use_fsdp=True,
             offload_mode=offload_mode,
+            offload_to_disk=offload_to_disk,
         )

     def __post_init__(self):
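
Every basic_config factory above threads the new flag through unchanged, so disk offload is chosen at config-construction time. A minimal usage sketch against the new signature; the model path here is a placeholder, not a value from this commit:

    from diffsynth_engine.configs.pipeline import FluxPipelineConfig

    config = FluxPipelineConfig.basic_config(
        model_path="/path/to/flux.safetensors",  # placeholder path
        device="cuda",
        offload_mode="cpu_offload",
        offload_to_disk=True,  # new flag added in this commit
    )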

diffsynth_engine/pipelines/base.py

Lines changed: 33 additions & 3 deletions
@@ -41,6 +41,7 @@ def __init__(
         self.offload_mode = None
         self.model_names = []
         self._offload_param_dict = {}
+        self.offload_to_disk = False

     @classmethod
     def from_pretrained(cls, model_path_or_config: str | BaseConfig) -> "BasePipeline":
@@ -228,17 +229,20 @@ def eval(self):
                 model.eval()
         return self

-    def enable_cpu_offload(self, offload_mode: str):
-        valid_offload_mode = ("cpu_offload", "sequential_cpu_offload")
+    def enable_cpu_offload(self, offload_mode: str | None, offload_to_disk: bool = False):
+        valid_offload_mode = ("cpu_offload", "sequential_cpu_offload", "disable", None)
         if offload_mode not in valid_offload_mode:
             raise ValueError(f"offload_mode must be one of {valid_offload_mode}, but got {offload_mode}")
         if self.device == "cpu" or self.device == "mps":
             logger.warning("must set an non cpu device for pipeline before calling enable_cpu_offload")
             return
-        if offload_mode == "cpu_offload":
+        if offload_mode is None or offload_mode == "disable":
+            self._disable_offload()
+        elif offload_mode == "cpu_offload":
             self._enable_model_cpu_offload()
         elif offload_mode == "sequential_cpu_offload":
             self._enable_sequential_cpu_offload()
+        self.offload_to_disk = offload_to_disk

     def _enable_model_cpu_offload(self):
         for model_name in self.model_names:
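
With this change enable_cpu_offload doubles as the off switch: passing None or "disable" routes to the new _disable_offload, which moves every model back onto the pipeline device. A hedged sketch of both calls, assuming a pipeline instance pipe already constructed on a CUDA device:

    # Turn offloading on, with the new disk option:
    pipe.enable_cpu_offload("sequential_cpu_offload", offload_to_disk=True)

    # Later, revert; None behaves the same as "disable". Note this call
    # also resets offload_to_disk to its default of False.
    pipe.enable_cpu_offload("disable")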
@@ -254,12 +258,21 @@ def _enable_sequential_cpu_offload(self):
                 enable_sequential_cpu_offload(model, self.device)
         self.offload_mode = "sequential_cpu_offload"

+    def _disable_offload(self):
+        self.offload_mode = None
+        self._offload_param_dict = {}
+        for model_name in self.model_names:
+            model = getattr(self, model_name)
+            if model is not None:
+                model.to(self.device)
+
     def enable_fp8_autocast(
         self, model_names: List[str], compute_dtype: torch.dtype = torch.bfloat16, use_fp8_linear: bool = False
     ):
         for model_name in model_names:
             model = getattr(self, model_name)
             if model is not None:
+                model.to(dtype=torch.float8_e4m3fn)
                 enable_fp8_autocast(model, compute_dtype, use_fp8_linear)
         self.fp8_autocast_enabled = True

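The one-line model.to(dtype=torch.float8_e4m3fn) is the cast the commit title refers to: weights are stored in fp8, while enable_fp8_autocast handles upcasting to compute_dtype at call time. A standalone sketch of that storage-versus-compute split in plain PyTorch (not this repo's helpers), assuming a PyTorch build with float8 dtypes (2.1 or newer):

    import torch

    lin = torch.nn.Linear(64, 64)
    lin.to(dtype=torch.float8_e4m3fn)  # fp8 storage: half the memory of bf16 weights

    x = torch.randn(8, 64, dtype=torch.bfloat16)
    # Kernels generally cannot matmul fp8 tensors directly, so upcast per call:
    w = lin.weight.to(torch.bfloat16)
    b = lin.bias.to(torch.bfloat16)
    y = torch.nn.functional.linear(x, w, b)  # compute runs in bf16
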
@@ -282,10 +295,27 @@ def load_models_to_device(self, load_model_names: List[str] | None = None):
         # load the needed models to device
         for model_name in load_model_names:
             model = getattr(self, model_name)
+            if model is None:
+                raise ValueError(
+                    f"model {model_name} is not loaded, maybe this model has been destroyed by model_lifecycle_finish function with offload_to_disk=True"
+                )
             if model is not None and (p := next(model.parameters(), None)) is not None and p.device.type != self.device:
                 model.to(self.device)
         # fresh the cuda cache
         empty_cache()

+    def model_lifecycle_finish(self, model_names: List[str] | None = None):
+        if not self.offload_to_disk or self.offload_mode is None:
+            return
+        for model_name in model_names:
+            model = getattr(self, model_name)
+            del model
+            if model_name in self._offload_param_dict:
+                del self._offload_param_dict[model_name]
+            setattr(self, model_name, None)
+            print(f"model {model_name} has been deleted from memory")
+            logger.info(f"model {model_name} has been deleted from memory")
+        empty_cache()
+
     def compile(self):
         raise NotImplementedError(f"{self.__class__.__name__} does not support compile")
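
model_lifecycle_finish is deliberately a no-op unless both offload_to_disk and an offload mode are active; when it does run, the model attribute is gone for good, which is exactly what the new ValueError in load_models_to_device guards against. A behavior sketch, again assuming a hypothetical pipeline instance pipe:

    pipe.enable_cpu_offload("cpu_offload", offload_to_disk=True)
    pipe.model_lifecycle_finish(["encoder"])   # pipe.encoder is now None
    pipe.load_models_to_device(["encoder"])    # raises the new ValueError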

diffsynth_engine/pipelines/qwen_image.py

Lines changed: 16 additions & 1 deletion
@@ -208,7 +208,19 @@ def from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePipe
         pipe.eval()

         if config.offload_mode is not None:
-            pipe.enable_cpu_offload(config.offload_mode)
+            pipe.enable_cpu_offload(config.offload_mode, config.offload_to_disk)
+
+        if config.model_dtype == torch.float8_e4m3fn:
+            pipe.dtype = torch.bfloat16  # compute dtype
+            pipe.enable_fp8_autocast(
+                model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
+            )
+
+        if config.encoder_dtype == torch.float8_e4m3fn:
+            pipe.dtype = torch.bfloat16  # compute dtype
+            pipe.enable_fp8_autocast(
+                model_names=["encoder"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
+            )

         if config.parallelism > 1:
             pipe = ParallelWrapper(
@@ -393,6 +405,7 @@ def __call__(
             negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(negative_prompt, 1, 4096)
         else:
             negative_prompt_embeds, negative_prompt_embeds_mask = None, None
+        self.model_lifecycle_finish(["encoder"])

         hide_progress = dist.is_initialized() and dist.get_rank() != 0
         for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
@@ -412,6 +425,7 @@
             # UI
             if progress_callback is not None:
                 progress_callback(i, len(timesteps), "DENOISING")
+        self.model_lifecycle_finish(["dit"])
         # Decode image
         self.load_models_to_device(["vae"])
         latents = rearrange(latents, "B C H W -> B C 1 H W")
@@ -423,5 +437,6 @@
         )
         image = self.vae_output_to_image(vae_output)
         # Offload all models
+        self.model_lifecycle_finish(["vae"])
         self.load_models_to_device([])
         return image
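
Taken together, the three model_lifecycle_finish calls free each stage as soon as its output exists: the encoder after prompt encoding, the DiT after the denoising loop, the VAE after decoding, so peak memory tracks one stage at a time. A sketch of the resulting end-to-end usage; the pipeline class name, prompt, and path are assumptions rather than values shown in this diff:

    from diffsynth_engine.configs.pipeline import QwenImagePipelineConfig

    config = QwenImagePipelineConfig.basic_config(
        model_path="/path/to/qwen-image",  # placeholder path
        offload_mode="cpu_offload",
        offload_to_disk=True,
    )
    # Assuming the pipeline class in qwen_image.py is QwenImagePipeline:
    # pipe = QwenImagePipeline.from_pretrained(config)
    # image = pipe(prompt="...")
    # After one call, encoder/dit/vae have been set to None, so a second
    # call would hit the new ValueError in load_models_to_device.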
