5 changes: 3 additions & 2 deletions diffsynth_engine/models/wan/wan_dit.py
@@ -334,9 +334,10 @@ def forward(
clip_feature: Optional[torch.Tensor] = None, # clip_vision_encoder(img)
y: Optional[torch.Tensor] = None, # vae_encoder(img)
):
+ use_cfg = x.shape[0] > 1
with (
gguf_inference(),
- cfg_parallel((x, context, timestep, clip_feature, y)),
+ cfg_parallel((x, context, timestep, clip_feature, y), use_cfg=use_cfg),
):
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
@@ -365,7 +366,7 @@ def forward(
x = self.head(x, t)
(x,) = sequence_parallel_unshard((x,), seq_dims=(1,), seq_lens=(f * h * w,))
x = self.unpatchify(x, (f, h, w))
- (x,) = cfg_parallel_unshard((x,))
+ (x,) = cfg_parallel_unshard((x,), use_cfg=use_cfg)
return x

@classmethod
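The new `use_cfg` flag is derived from the batch dimension rather than passed in by the pipeline: under classifier-free guidance the conditional and unconditional branches arrive stacked along the batch axis, so `x.shape[0] > 1` signals that CFG is active, and `cfg_parallel` can presumably skip its sharding for single-sample calls. A minimal caller-side sketch of that batching convention (`build_dit_input` is a hypothetical helper, not part of the repository):

import torch

def build_dit_input(latents: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # With CFG, the same latents are run against conditional and unconditional
    # context in one batched call, so the DiT sees a batch size of 2.
    if cfg_scale > 1.0:
        return torch.cat([latents, latents], dim=0)  # x.shape[0] == 2 -> use_cfg is True
    return latents  # x.shape[0] == 1 -> use_cfg is False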
29 changes: 14 additions & 15 deletions diffsynth_engine/models/wan/wan_vae.py
@@ -515,7 +515,7 @@ def convert(self, state_dict):
class WanVideoVAE(PreTrainedModel):
converter = WanVideoVAEStateDictConverter()

- def __init__(self, z_dim=16, parallelism: int = 1, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
+ def __init__(self, z_dim=16, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
super().__init__()

mean = [
@@ -561,12 +561,11 @@ def __init__(self, z_dim=16, parallelism: int = 1, device: str = "cuda:0", dtype
# init model
self.model = VideoVAE(z_dim=z_dim).eval().requires_grad_(False)
self.upsampling_factor = 8
- self.parallelism = parallelism

@classmethod
- def from_state_dict(cls, state_dict, parallelism=1, device="cuda:0", dtype=torch.float32) -> "WanVideoVAE":
+ def from_state_dict(cls, state_dict, device="cuda:0", dtype=torch.float32) -> "WanVideoVAE":
with no_init_weights():
- model = torch.nn.utils.skip_init(cls, parallelism=parallelism, device=device, dtype=dtype)
+ model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
model.load_state_dict(state_dict, assign=True)
model.to(device=device, dtype=dtype, non_blocking=True)
return model
@@ -607,7 +606,7 @@ def tiled_decode(self, hidden_states, device, tile_size, tile_stride, progress_c
h_, w_ = h + size_h, w + size_w
tasks.append((h, h_, w, w_))

- data_device = device if self.parallelism > 1 else "cpu"
+ data_device = device if dist.is_initialized() else "cpu"
computation_device = device

out_T = T * 4 - 3
@@ -622,9 +621,9 @@
device=data_device,
)

- hide_progress_bar = self.parallelism > 1 and dist.get_rank() != 0
- for i, (h, h_, w, w_) in enumerate(tqdm(tasks, desc="VAE DECODING", disable=hide_progress_bar)):
- if self.parallelism > 1 and (i % dist.get_world_size() != dist.get_rank()):
+ hide_progress = dist.is_initialized() and dist.get_rank() != 0
+ for i, (h, h_, w, w_) in enumerate(tqdm(tasks, desc="VAE DECODING", disable=hide_progress)):
+ if dist.is_initialized() and (i % dist.get_world_size() != dist.get_rank()):
continue
hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
@@ -654,11 +653,11 @@ def tiled_decode(self, hidden_states, device, tile_size, tile_stride, progress_c
target_h : target_h + hidden_states_batch.shape[3],
target_w : target_w + hidden_states_batch.shape[4],
] += mask
- if progress_callback is not None and not hide_progress_bar:
+ if progress_callback is not None and not hide_progress:
progress_callback(i + 1, len(tasks), "VAE DECODING")
- if progress_callback is not None and not hide_progress_bar:
+ if progress_callback is not None and not hide_progress:
progress_callback(len(tasks), len(tasks), "VAE DECODING")
- if self.parallelism > 1:
+ if dist.is_initialized():
dist.all_reduce(values)
dist.all_reduce(weight)
values = values / weight
@@ -681,7 +680,7 @@ def tiled_encode(self, video, device, tile_size, tile_stride, progress_callback=
h_, w_ = h + size_h, w + size_w
tasks.append((h, h_, w, w_))

- data_device = device if self.parallelism > 1 else "cpu"
+ data_device = device if dist.is_initialized() else "cpu"
computation_device = device

out_T = (T + 3) // 4
@@ -696,9 +695,9 @@
device=data_device,
)

- hide_progress_bar = self.parallelism > 1 and dist.get_rank() != 0
+ hide_progress_bar = dist.is_initialized() and dist.get_rank() != 0
for i, (h, h_, w, w_) in enumerate(tqdm(tasks, desc="VAE ENCODING", disable=hide_progress_bar)):
- if self.parallelism > 1 and (i % dist.get_world_size() != dist.get_rank()):
+ if dist.is_initialized() and (i % dist.get_world_size() != dist.get_rank()):
continue
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
@@ -732,7 +731,7 @@ def tiled_encode(self, video, device, tile_size, tile_stride, progress_callback=
progress_callback(i + 1, len(tasks), "VAE ENCODING")
if progress_callback is not None and not hide_progress_bar:
progress_callback(len(tasks), len(tasks), "VAE ENCODING")
- if self.parallelism > 1:
+ if dist.is_initialized():
dist.all_reduce(values)
dist.all_reduce(weight)
values = values / weight
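With the `parallelism` argument removed, tiled decoding and encoding detect multi-GPU execution by asking `torch.distributed` whether a process group exists: tiles are still assigned to ranks round-robin and the partial canvases are summed with `all_reduce`, while single-process runs skip the collectives and keep accumulation on the CPU. A condensed sketch of that gating pattern, where `process_tile` stands in for the per-tile decode/encode and blend step (a hypothetical helper, not the repository's actual function):

import torch
import torch.distributed as dist

def run_tiles(tasks, values: torch.Tensor, weight: torch.Tensor, process_tile):
    # dist.is_initialized() is False in a single-process run, so every tile is
    # processed locally and no collective calls are issued.
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    for i, task in enumerate(tasks):
        if i % world_size != rank:
            continue  # round-robin: each rank handles every world_size-th tile
        process_tile(task, values, weight)  # accumulate into values and weight in place
    if dist.is_initialized():
        dist.all_reduce(values)  # default op is SUM: merge partial results from all ranks
        dist.all_reduce(weight)
    return values / weight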
2 changes: 0 additions & 2 deletions examples/flux_parallel.py
@@ -1,9 +1,7 @@
- import torch.multiprocessing as mp
from diffsynth_engine import fetch_model, FluxImagePipeline


if __name__ == "__main__":
- mp.set_start_method("spawn")
model_path = fetch_model("muse/flux-with-vae", path="flux1-dev-with-vae.safetensors")
pipe = FluxImagePipeline.from_pretrained(model_path, parallelism=4, offload_mode="cpu_offload")
image = pipe(prompt="a cat", seed=42)
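Removing `mp.set_start_method("spawn")` from the examples suggests that worker-process setup now happens inside the pipeline when `parallelism > 1`; after this change the parallel Flux example reduces to roughly the following (same calls as in the diff above, shown only as the consolidated result):

from diffsynth_engine import fetch_model, FluxImagePipeline

if __name__ == "__main__":
    model_path = fetch_model("muse/flux-with-vae", path="flux1-dev-with-vae.safetensors")
    pipe = FluxImagePipeline.from_pretrained(model_path, parallelism=4, offload_mode="cpu_offload")
    image = pipe(prompt="a cat", seed=42)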
4 changes: 1 addition & 3 deletions examples/wan_flf_to_video.py
@@ -1,4 +1,3 @@
- import torch.multiprocessing as mp
from PIL import Image

from diffsynth_engine.pipelines import WanVideoPipeline, WanModelConfig
@@ -7,9 +6,8 @@


if __name__ == "__main__":
- mp.set_start_method("spawn")
config = WanModelConfig(
model_path=fetch_model("muse/wan2.1-flf2v-14b-720p-bf16", path="dit.safetensors"),
model_path=fetch_model("MusePublic/wan2.1-flf2v-14b-720p-bf16", path="dit.safetensors"),
t5_path=fetch_model("muse/wan2.1-umt5", path="umt5.safetensors"),
vae_path=fetch_model("muse/wan2.1-vae", path="vae.safetensors"),
image_encoder_path=fetch_model(
4 changes: 1 addition & 3 deletions examples/wan_image_to_video.py
@@ -1,4 +1,3 @@
- import torch.multiprocessing as mp
from PIL import Image

from diffsynth_engine.pipelines import WanVideoPipeline, WanModelConfig
@@ -7,9 +6,8 @@


if __name__ == "__main__":
- mp.set_start_method("spawn")
config = WanModelConfig(
model_path=fetch_model("muse/wan2.1-i2v-14b-480p-bf16", path="dit.safetensors"),
model_path=fetch_model("MusePublic/wan2.1-i2v-14b-480p-bf16", path="dit.safetensors"),
t5_path=fetch_model("muse/wan2.1-umt5", path="umt5.safetensors"),
vae_path=fetch_model("muse/wan2.1-vae", path="vae.safetensors"),
image_encoder_path=fetch_model(
3 changes: 0 additions & 3 deletions examples/wan_lora.py
@@ -1,12 +1,9 @@
- import torch.multiprocessing as mp

from diffsynth_engine.pipelines import WanVideoPipeline, WanModelConfig
from diffsynth_engine.utils.download import fetch_model
from diffsynth_engine.utils.video import save_video


if __name__ == "__main__":
- mp.set_start_method("spawn")
config = WanModelConfig(
model_path=fetch_model("MusePublic/wan2.1-1.3b", path="dit.safetensors"),
t5_path=fetch_model("muse/wan2.1-umt5", path="umt5.safetensors"),
5 changes: 1 addition & 4 deletions examples/wan_text_to_video.py
@@ -1,14 +1,11 @@
- import torch.multiprocessing as mp

from diffsynth_engine.pipelines import WanVideoPipeline, WanModelConfig
from diffsynth_engine.utils.download import fetch_model
from diffsynth_engine.utils.video import save_video


if __name__ == "__main__":
- mp.set_start_method("spawn")
config = WanModelConfig(
model_path=fetch_model("muse/wan2.1-t2v-14b-bf16", path="dit.safetensors"),
model_path=fetch_model("MusePublic/wan2.1-14b-t2v", path="dit.safetensors"),
t5_path=fetch_model("muse/wan2.1-umt5", path="umt5.safetensors"),
vae_path=fetch_model("muse/wan2.1-vae", path="vae.safetensors"),
use_fsdp=True,
2 changes: 0 additions & 2 deletions tests/test_models/wan/test_wan_vae_parallel.py
@@ -1,5 +1,4 @@
import torch
- import torch.multiprocessing as mp
import unittest
import numpy as np

@@ -13,7 +12,6 @@
class TestWanVAEParallel(VideoTestCase):
@classmethod
def setUpClass(cls):
- mp.set_start_method("spawn")
cls._vae_model_path = fetch_model("muse/wan2.1-vae", path="vae.safetensors")
loaded_state_dict = load_file(cls._vae_model_path)
vae = WanVideoVAE.from_state_dict(loaded_state_dict, parallelism=4)
2 changes: 0 additions & 2 deletions tests/test_pipelines/test_wan_video_parallel.py
@@ -1,4 +1,3 @@
- import torch.multiprocessing as mp
import unittest

from tests.common.test_case import VideoTestCase
@@ -9,7 +8,6 @@
class TestWanVideoTP(VideoTestCase):
@classmethod
def setUpClass(cls):
- mp.set_start_method("spawn")
config = WanModelConfig(
model_path=fetch_model("MusePublic/wan2.1-1.3b", path="dit.safetensors"),
t5_path=fetch_model("muse/wan2.1-umt5", path="umt5.safetensors"),