108 changes: 108 additions & 0 deletions examples/parallelism/run_wan_cp_npu.py
@@ -0,0 +1,108 @@
import os
import sys

sys.path.append("..")

import time

import torch
import torch_npu
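# Importing transfer_to_npu (below) transparently redirects CUDA-style calls
# (e.g. device="cuda") to Ascend NPU devices.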
from torch_npu.contrib import transfer_to_npu

from diffusers import WanPipeline, WanTransformer3DModel
from diffusers.utils import export_to_video
from utils import (
    cachify,
    get_args,
    maybe_destroy_distributed,
    maybe_init_distributed,
    strify,
)

import cache_dit


def run_pipe(args, pipe, warmup: bool = False):
prompt = "A cat walks on the grass, realistic"
negative_prompt = "Bright tones, overexposed, static, blurred details, "
"subtitles, style, works, paintings, images, static, overall gray, "
"worst quality, low quality, JPEG compression residue, ugly, incomplete, "
"extra fingers, poorly drawn hands, poorly drawn faces, deformed, "
"disfigured, misshapen limbs, fused fingers, still picture, messy "
"background, three legs, many people in the background, walking backwards"

seed = 1234
generator = torch.Generator(device="cpu").manual_seed(seed)

num_inference_steps = args.steps if not warmup else 3
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=args.height,
width=args.width,
num_frames=49,
guidance_scale=5.0,
generator=generator,
num_inference_steps=num_inference_steps,
).frames[0]
return output


def main():
    args = get_args()
    print(args)

    rank, device = maybe_init_distributed(args)

    model_id = os.environ.get(
        "WAN_2_2_DIR",
        # "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
        "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    )

    pipe = WanPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
    )

    if args.cache or args.parallel_type is not None:
        cachify(args, pipe)

    if args.cpu_offload:
        pipe.enable_model_cpu_offload(device=device)
    else:
        pipe.to(device)

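    # Tiled VAE decode trades speed for lower peak memory. With the default
    # 480x832 resolution, the settings below give 720x1248 minimum tiles
    # decoded in 240x416 strides (height/2*3, width/2*3, height/2, width/2).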
    if args.vae_tiling:
        pipe.vae.enable_tiling(
            tile_sample_min_height=int(args.height / 2 * 3),
            tile_sample_min_width=int(args.width / 2 * 3),
            tile_sample_stride_height=int(args.height / 2),
            tile_sample_stride_width=int(args.width / 2),
        )

    assert isinstance(pipe.transformer, WanTransformer3DModel)

    pipe.set_progress_bar_config(disable=rank != 0)

    # Warmup pass (3 steps) before the timed run
    _ = run_pipe(args, pipe, warmup=True)

    start = time.time()
    video = run_pipe(args, pipe)
    end = time.time()

    if rank == 0:
        cache_dit.summary(pipe)

        time_cost = end - start
        save_path = f"wan.{strify(args, pipe)}.mp4"
        print(f"Time cost: {time_cost:.2f}s")
        print(f"Saving video to {save_path}")
        export_to_video(video, save_path, fps=16)

    maybe_destroy_distributed()


if __name__ == "__main__":
    main()
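A typical launch for this example is a multi-process torchrun run from examples/parallelism/. The flag names below are taken from examples/utils.py in this diff; torchrun, the process count, and the context-parallel flag itself are assumptions, so check get_args() for the exact options:

torchrun --nproc_per_node=2 run_wan_cp_npu.py --cache --vae-tiling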
9 changes: 6 additions & 3 deletions examples/utils.py
@@ -30,7 +30,7 @@ def get_args(
parser.add_argument("--cache", action="store_true", default=False)
parser.add_argument("--compile", action="store_true", default=False)
parser.add_argument("--fuse-lora", action="store_true", default=False)
parser.add_argument("--steps", type=int, default=None)
parser.add_argument("--steps", type=int, default=35)
parser.add_argument("--Fn", type=int, default=8)
parser.add_argument("--Bn", type=int, default=0)
parser.add_argument("--rdt", type=float, default=0.08)
@@ -41,8 +41,8 @@
    )
parser.add_argument("--taylorseer", action="store_true", default=False)
parser.add_argument("--taylorseer-order", "-order", type=int, default=1)
parser.add_argument("--height", type=int, default=None)
parser.add_argument("--width", type=int, default=None)
parser.add_argument("--height", type=int, default=480)
parser.add_argument("--width", type=int, default=832)
parser.add_argument("--quantize", "-q", action="store_true", default=False)
# float8, float8_weight_only, int8, int8_weight_only, int4, int4_weight_only
parser.add_argument(
@@ -80,9 +80,12 @@
            # Based on this fix: https://github.com/huggingface/diffusers/pull/12563
            "native",  # native pytorch attention: sdpa
            "_native_cudnn",
            "_native_npu",
        ],
    )
parser.add_argument("--perf", action="store_true", default=False)
parser.add_argument("--vae-tiling", action="store_true", default=False)
parser.add_argument("--cpu-offload", action="store_true", default=False)
return parser.parse_args() if parse else parser


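For reference, a minimal sketch of how the new "_native_npu" choice could be wired up, assuming a diffusers build that exposes set_attention_backend() on the transformer (the backend names above follow diffusers' attention dispatcher, and the PR linked in the comment adds the NPU path); the actual hookup in these examples may live elsewhere, e.g. inside cachify():

import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16
)
# Assumed API and backend name: route the transformer's attention through the
# NPU-native SDPA path instead of the default "native" backend.
pipe.transformer.set_attention_backend("_native_npu")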