11 changes: 11 additions & 0 deletions diffsynth_engine/configs/pipeline.py
@@ -16,6 +16,7 @@ class BaseConfig:
vae_tile_stride: int | Tuple[int, int] = 256
device: str = "cuda"
offload_mode: Optional[str] = None
offload_to_disk: bool = False
Contributor:

I think this should get a different name. Having looked at the implementation, this is a capability orthogonal to offload, and the current name is quite confusing. The usual understanding of offload_to_disk is that some data is actually stored on disk and can later be restored through some method, but the current implementation has no such method. Besides, our goal is to reduce memory usage during a one-off run, so we shouldn't need that capability anyway; a name tied to the model lifecycle would be easier to understand.

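To make the naming concern concrete, here is a minimal sketch, assuming only PyTorch, of what a two-way offload_to_disk would conventionally provide: weights are written to disk and can be brought back later. The restore path is exactly what this PR's flag omits. The class and method names below are illustrative and not part of the codebase.

```python
import os
import tempfile

import torch


class DiskOffloadedModule:
    """Hypothetical two-way disk offload: weights can be saved and restored."""

    def __init__(self, model: torch.nn.Module):
        self.model = model
        self._ckpt_path: str | None = None

    def offload_to_disk(self) -> None:
        # Persist the weights to a temporary file, then drop the in-memory copy.
        fd, self._ckpt_path = tempfile.mkstemp(suffix=".pt")
        os.close(fd)
        torch.save(self.model.state_dict(), self._ckpt_path)
        self.model.to("meta")  # parameter storage is released

    def restore_from_disk(self) -> None:
        # This restore path is what the PR's flag does not have:
        # model_lifecycle_finish() deletes the model with no way back.
        state_dict = torch.load(self._ckpt_path, map_location="cpu")
        self.model.to_empty(device="cpu")
        self.model.load_state_dict(state_dict)
```
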
@dataclass
@@ -62,11 +63,13 @@ def basic_config(
model_path: str | os.PathLike | List[str | os.PathLike],
device: str = "cuda",
offload_mode: Optional[str] = None,
offload_to_disk: bool = False,
) -> "SDPipelineConfig":
return cls(
model_path=model_path,
device=device,
offload_mode=offload_mode,
offload_to_disk=offload_to_disk,
)


@@ -87,11 +90,13 @@ def basic_config(
model_path: str | os.PathLike | List[str | os.PathLike],
device: str = "cuda",
offload_mode: Optional[str] = None,
offload_to_disk: bool = False,
) -> "SDXLPipelineConfig":
return cls(
model_path=model_path,
device=device,
offload_mode=offload_mode,
offload_to_disk=offload_to_disk,
)


@@ -116,13 +121,15 @@ def basic_config(
device: str = "cuda",
parallelism: int = 1,
offload_mode: Optional[str] = None,
offload_to_disk: bool = False,
) -> "FluxPipelineConfig":
return cls(
model_path=model_path,
device=device,
parallelism=parallelism,
use_fsdp=True,
offload_mode=offload_mode,
offload_to_disk=offload_to_disk,
)

def __post_init__(self):
@@ -160,6 +167,7 @@ def basic_config(
device: str = "cuda",
parallelism: int = 1,
offload_mode: Optional[str] = None,
offload_to_disk: bool = False,
) -> "WanPipelineConfig":
return cls(
model_path=model_path,
@@ -169,6 +177,7 @@ def basic_config(
use_cfg_parallel=True,
use_fsdp=True,
offload_mode=offload_mode,
offload_to_disk=offload_to_disk,
)

def __post_init__(self):
@@ -196,6 +205,7 @@ def basic_config(
device: str = "cuda",
parallelism: int = 1,
offload_mode: Optional[str] = None,
offload_to_disk: bool = False,
) -> "QwenImagePipelineConfig":
return cls(
model_path=model_path,
@@ -206,6 +216,7 @@ def basic_config(
use_cfg_parallel=True,
use_fsdp=True,
offload_mode=offload_mode,
offload_to_disk=offload_to_disk,
)

def __post_init__(self):
39 changes: 35 additions & 4 deletions diffsynth_engine/pipelines/base.py
@@ -41,6 +41,7 @@ def __init__(
self.offload_mode = None
self.model_names = []
self._offload_param_dict = {}
self.offload_to_disk = False

@classmethod
def from_pretrained(cls, model_path_or_config: str | BaseConfig) -> "BasePipeline":
@@ -228,19 +229,23 @@ def eval(self):
model.eval()
return self

def enable_cpu_offload(self, offload_mode: str):
valid_offload_mode = ("cpu_offload", "sequential_cpu_offload")
def enable_cpu_offload(self, offload_mode: str | None, offload_to_disk: bool = False):
Contributor:

Also, since this is an orthogonal capability, it would be better to split it into a separate function rather than adding it as a parameter of this one.

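One possible shape of that suggestion, as a minimal sketch; PipelineSketch, enable_release_after_use, and release_model_after_use are hypothetical names, not part of this PR. The idea is to keep enable_cpu_offload() with its original signature and expose the release-after-use behaviour through a separate method.

```python
class PipelineSketch:
    """Illustrative only: the two switches kept separate, as suggested above."""

    def __init__(self, device: str = "cuda"):
        self.device = device
        self.offload_mode: str | None = None
        # Orthogonal to offload_mode: drop each model once its stage of the run is done.
        self.release_model_after_use = False

    def enable_cpu_offload(self, offload_mode: str) -> None:
        # Unchanged from the pre-PR signature.
        valid_offload_mode = ("cpu_offload", "sequential_cpu_offload")
        if offload_mode not in valid_offload_mode:
            raise ValueError(f"offload_mode must be one of {valid_offload_mode}, but got {offload_mode}")
        self.offload_mode = offload_mode

    def enable_release_after_use(self) -> None:
        self.release_model_after_use = True
```
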
valid_offload_mode = ("cpu_offload", "sequential_cpu_offload", "disable", None)
if offload_mode not in valid_offload_mode:
raise ValueError(f"offload_mode must be one of {valid_offload_mode}, but got {offload_mode}")
if self.device == "cpu" or self.device == "mps":
logger.warning("must set an non cpu device for pipeline before calling enable_cpu_offload")
return
if offload_mode == "cpu_offload":
if offload_mode is None or offload_mode == "disable":
Contributor:

These are all set at initialization time, so the "disable" option doesn't seem very useful.

self._disable_offload()
elif offload_mode == "cpu_offload":
self._enable_model_cpu_offload()
elif offload_mode == "sequential_cpu_offload":
self._enable_sequential_cpu_offload()
self.offload_to_disk = offload_to_disk

def _enable_model_cpu_offload(self):
for model_name in self.model_names:
model = getattr(self, model_name)
if model is not None:
@@ -253,13 +258,23 @@ def _enable_sequential_cpu_offload(self):
if model is not None:
enable_sequential_cpu_offload(model, self.device)
self.offload_mode = "sequential_cpu_offload"

def _disable_offload(self):
self.offload_mode = None
self._offload_param_dict = {}
for model_name in self.model_names:
model = getattr(self, model_name)
if model is not None:
model.to(self.device)


def enable_fp8_autocast(
self, model_names: List[str], compute_dtype: torch.dtype = torch.bfloat16, use_fp8_linear: bool = False
):
for model_name in model_names:
model = getattr(self, model_name)
if model is not None:
model.to(device=self.device, dtype=torch.float8_e4m3fn)
enable_fp8_autocast(model, compute_dtype, use_fp8_linear)
self.fp8_autocast_enabled = True

@@ -282,10 +297,26 @@ def load_models_to_device(self, load_model_names: List[str] | None = None):
# load the needed models to device
for model_name in load_model_names:
model = getattr(self, model_name)
if model is None:
raise ValueError(f"model {model_name} is not loaded, maybe this model has been destroyed by model_lifecycle_finish function with offload_to_disk=True")
if model is not None and (p := next(model.parameters(), None)) is not None and p.device.type != self.device:
model.to(self.device)
# fresh the cuda cache
empty_cache()

def model_lifecycle_finish(self, model_names: List[str] | None = None):
if not self.offload_to_disk or self.offload_mode is None:
return
for model_name in model_names:
model = getattr(self, model_name)
del model
if model_name in self._offload_param_dict:
del self._offload_param_dict[model_name]
setattr(self, model_name, None)
print(f"model {model_name} has been deleted from memory")
logger.info(f"model {model_name} has been deleted from memory")
empty_cache()


def compile(self):
raise NotImplementedError(f"{self.__class__.__name__} does not support compile")
17 changes: 16 additions & 1 deletion diffsynth_engine/pipelines/qwen_image.py
@@ -208,7 +208,19 @@ def from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePipe
pipe.eval()

if config.offload_mode is not None:
pipe.enable_cpu_offload(config.offload_mode)
pipe.enable_cpu_offload(config.offload_mode, config.offload_to_disk)

if config.model_dtype == torch.float8_e4m3fn:
pipe.dtype = torch.bfloat16 # compute dtype
pipe.enable_fp8_autocast(
model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
)

if config.encoder_dtype == torch.float8_e4m3fn:
pipe.dtype = torch.bfloat16 # compute dtype
pipe.enable_fp8_autocast(
model_names=["encoder"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
)

if config.parallelism > 1:
pipe = ParallelWrapper(
@@ -393,6 +405,7 @@ def __call__(
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(negative_prompt, 1, 4096)
else:
negative_prompt_embeds, negative_prompt_embeds_mask = None, None
self.model_lifecycle_finish(["encoder"])

hide_progress = dist.is_initialized() and dist.get_rank() != 0
for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
@@ -412,6 +425,7 @@
# UI
if progress_callback is not None:
progress_callback(i, len(timesteps), "DENOISING")
self.model_lifecycle_finish(["dit"])
# Decode image
self.load_models_to_device(["vae"])
latents = rearrange(latents, "B C H W -> B C 1 H W")
@@ -423,5 +437,6 @@
)
image = self.vae_output_to_image(vae_output)
# Offload all models
self.model_lifecycle_finish(["vae"])
self.load_models_to_device([])
return image