2525from diffsynth_engine .utils import logging
2626from diffsynth_engine .utils .fp8_linear import enable_fp8_linear
2727from diffsynth_engine .utils .download import fetch_model
28+ from diffsynth_engine .utils .platform import empty_cache
2829
2930logger = logging .get_logger (__name__ )
3031
@@ -546,6 +547,7 @@ def predict_noise(
546547 current_step = current_step ,
547548 total_step = total_step ,
548549 )
550+ self .load_models_to_device (["dit" ])
549551 noise_pred = self .dit (
550552 hidden_states = latents ,
551553 timestep = timestep ,
@@ -570,15 +572,14 @@ def prepare_latents(
570572 ):
571573 # Prepare scheduler
572574 if input_image is not None :
575+ self .load_models_to_device (["vae_encoder" ])
573576 total_steps = num_inference_steps
574577 sigmas , timesteps = self .noise_scheduler .schedule (
575578 total_steps , mu = mu , sigma_min = 1 / total_steps , sigma_max = 1.0
576579 )
577580 t_start = max (total_steps - int (num_inference_steps * denoising_strength ), 1 )
578581 sigma_start , sigmas = sigmas [t_start - 1 ], sigmas [t_start - 1 :]
579582 timesteps = timesteps [t_start - 1 :]
580-
581- self .load_models_to_device (["vae_encoder" ])
582583 noise = latents
583584 image = self .preprocess_image (input_image ).to (device = self .device , dtype = self .dtype )
584585 latents = self .encode_image (image )
@@ -593,6 +594,7 @@ def prepare_latents(
593594 return init_latents , latents , sigmas , timesteps
594595
595596 def prepare_masked_latent (self , image : Image .Image , mask : Image .Image | None , height : int , width : int ):
597+ self .load_models_to_device (["vae_encoder" ])
596598 if mask is None :
597599 image = image .resize ((width , height ))
598600 image = self .preprocess_image (image ).to (device = self .device , dtype = self .dtype )
@@ -637,6 +639,8 @@ def predict_multicontrolnet(
637639 total_step : int ,
638640 ):
639641 double_block_output_results , single_block_output_results = None , None
642+ if len (controlnet_params ) > 0 :
643+ self .load_models_to_device ([])
640644 for param in controlnet_params :
641645 current_scale = param .scale
642646 if not (
@@ -645,6 +649,9 @@ def predict_multicontrolnet(
645649 # if current_step is not in the control range
646650 # skip this controlnet
647651 continue
652+ if self .offload_mode == "sequential_cpu_offload" or self .offload_mode == "cpu_offload" :
653+ empty_cache ()
654+ param .model .to (self .device )
648655 double_block_output , single_block_output = param .model (
649656 latents ,
650657 param .image ,
@@ -656,6 +663,9 @@ def predict_multicontrolnet(
656663 image_ids ,
657664 text_ids ,
658665 )
666+ if self .offload_mode == "sequential_cpu_offload" or self .offload_mode == "cpu_offload" :
667+ empty_cache ()
668+ param .model .to ("cpu" )
659669 double_block_output_results = accumulate (double_block_output_results , double_block_output )
660670 single_block_output_results = accumulate (single_block_output_results , single_block_output )
661671 return double_block_output_results , single_block_output_results
@@ -741,7 +751,7 @@ def __call__(
741751 )
742752
743753 # Denoise
744- self .load_models_to_device (["dit" ])
754+ self .load_models_to_device ([])
745755 for i , timestep in enumerate (tqdm (timesteps )):
746756 timestep = timestep .unsqueeze (0 ).to (dtype = self .dtype )
747757 noise_pred = self .predict_noise_with_cfg (
0 commit comments