Skip to content

unittest for wan parallel #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions tests/test_models/wan/test_wan_vae_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import torch
import torch.multiprocessing as mp
import unittest
import numpy as np

from diffsynth_engine.utils.loader import load_file
from diffsynth_engine.utils.parallel import ParallelModel
from diffsynth_engine.models.wan.wan_vae import WanVideoVAE
from diffsynth_engine import fetch_model
from tests.common.test_case import VideoTestCase


class TestWanVAEParallel(VideoTestCase):
    """Parallel encode/decode tests for the Wan video VAE.

    Runs the VAE with sequence parallelism (Ulysses degree 4) and checks the
    results against pre-computed single-process reference tensors.
    """

    @classmethod
    def setUpClass(cls):
        # force=True: another test module in the same process (e.g. the Wan TP
        # test) may already have set the start method; without force a second
        # call to set_start_method raises RuntimeError.
        mp.set_start_method("spawn", force=True)
        cls._vae_model_path = fetch_model("muse/wan2.1-vae", path="vae.safetensors")
        loaded_state_dict = load_file(cls._vae_model_path)
        vae = WanVideoVAE.from_state_dict(loaded_state_dict, parallelism=4)
        cls.vae = ParallelModel(vae, cfg_degree=1, sp_ulysses_degree=4, sp_ring_degree=1, tp_degree=1)
        cls._input_video = cls.get_input_video("astronaut_320_320.mp4")

    @classmethod
    def tearDownClass(cls):
        # Drop the ParallelModel reference so its worker processes/resources
        # are released before the next test module runs.
        del cls.vae

    def test_encode_parallel(self):
        """Parallel VAE encoding must match the reference 'encoded' tensor."""
        expected_tensor = self.get_expect_tensor("wan/wan_vae.safetensors")
        expected = expected_tensor["encoded"]
        # Map pixel values to [-1, 1] and reorder to CHW, then add a batch dim
        # (assumes frames are HWC uint8 in [0, 255] — TODO confirm with
        # VideoTestCase.get_input_video).
        video_frames = [
            torch.tensor(np.array(frame, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
            for frame in self._input_video.frames
        ]
        # Stacking along dim=2 yields a (1, C, num_frames, H, W) video tensor.
        video_tensor = torch.stack(video_frames, dim=2)
        with torch.no_grad():
            result = self.vae.encode(video_tensor, device="cuda", tiled=True).cpu()
        self.assertTensorEqual(result, expected)

    def test_decode_parallel(self):
        """Parallel VAE decoding of the reference latent must match the reference 'decoded' tensor."""
        expected_tensor = self.get_expect_tensor("wan/wan_vae.safetensors")
        latent_tensor, expected = expected_tensor["encoded"], expected_tensor["decoded"]
        with torch.no_grad():
            result = self.vae.decode(latent_tensor, device="cuda", tiled=True)[0].cpu()
        self.assertTensorEqual(result, expected)


# Allow running this test module directly: `python test_wan_vae_parallel.py`.
if __name__ == "__main__":
    unittest.main()
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch.multiprocessing as mp
import unittest

from tests.common.test_case import VideoTestCase
Expand All @@ -8,13 +9,18 @@
class TestWanVideoTP(VideoTestCase):
@classmethod
def setUpClass(cls):
mp.set_start_method("spawn")
config = WanModelConfig(
model_path=fetch_model("MusePublic/wan2.1-1.3b", path="dit.safetensors"),
t5_path=fetch_model("muse/wan2.1-umt5", path="umt5.safetensors"),
vae_path=fetch_model("muse/wan2.1-vae", path="vae.safetensors"),
)
cls.pipe = WanVideoPipeline.from_pretrained(config, parallelism=4, use_cfg_parallel=True)

@classmethod
def tearDownClass(cls):
del cls.pipe

def test_txt2video(self):
video = self.pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
Expand All @@ -23,8 +29,7 @@ def test_txt2video(self):
width=480,
height=480,
)
self.save_video(video, "wan_tp_t2v.mp4")
del self.pipe
self.save_video(video, "wan_t2v.mp4")


if __name__ == "__main__":
Expand Down