import os
import sys

# Make the shared example utilities in the parent directory importable.
sys.path.append("..")

import time

import torch
from diffusers import (
    LTXConditionPipeline,
    LTXLatentUpsamplePipeline,
    AutoencoderKLLTXVideo,
)
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import export_to_video

from utils import (
    cachify,
    get_args,
    maybe_destroy_distributed,
    maybe_init_distributed,
    strify,
)
import cache_dit
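# cache-dit provides training-free, block-level caching for diffusion
# transformers; it is applied below via `cachify` and its cache-hit
# statistics are read back with `cache_dit.summary`.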

# NOTE: Use `--attn flash` for LTX-Video with context parallelism;
# otherwise it may raise an "attention mask not supported" error.

args = get_args()
print(args)

rank, device = maybe_init_distributed(args)

pipe = LTXConditionPipeline.from_pretrained(
    os.environ.get("LTX_VIDEO_DIR", "Lightricks/LTX-Video-0.9.7-dev"),
    torch_dtype=torch.bfloat16,
    # Quantize the heaviest components to 4-bit NF4; compute stays in bf16.
    quantization_config=PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs={
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": torch.bfloat16,
        },
        components_to_quantize=["text_encoder", "transformer"],
    ),
)
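# NF4 stores weights in 4 bits, so the quantized text encoder and
# transformer need roughly a quarter of their bf16 weight memory.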

# The spatial upscaler reuses the base pipeline's VAE, so only one copy of
# the VAE weights is loaded.
pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
    os.environ.get(
        "LTX_UPSCALER_DIR", "Lightricks/ltxv-spatial-upscaler-0.9.7"
    ),
    vae=pipe.vae,
    torch_dtype=torch.bfloat16,
)

pipe.to(device)
pipe_upsample.to(device)
assert isinstance(pipe.vae, AutoencoderKLLTXVideo)
assert isinstance(pipe_upsample.vae, AutoencoderKLLTXVideo)

# Only rank 0 shows progress bars when running distributed.
pipe.set_progress_bar_config(disable=rank != 0)
pipe_upsample.set_progress_bar_config(disable=rank != 0)

if args.cache or args.parallel_type is not None:
    # Apply cache-dit's caching and/or context parallelism to the pipeline,
    # as selected by the CLI flags.
    cachify(args, pipe)


def round_to_nearest_resolution_acceptable_by_vae(height, width):
    # Round DOWN to the nearest multiple of the VAE's spatial compression
    # ratio so the latent grid has integer spatial dimensions.
    height = height - (height % pipe.vae_spatial_compression_ratio)
    width = width - (width % pipe.vae_spatial_compression_ratio)
    return height, width


prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
negative_prompt = (
    "worst quality, inconsistent motion, blurry, jittery, distorted"
)
expected_height, expected_width = 512, 704
downscale_factor = 2 / 3
num_frames = 49
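# Note: LTX works with frame counts of the form 8*k + 1 (its VAE compresses
# time by 8x); here 49 = 8 * 6 + 1.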

# Part 1. Generate the video at a smaller resolution.
downscaled_height = int(expected_height * downscale_factor)
downscaled_width = int(expected_width * downscale_factor)
downscaled_height, downscaled_width = (
    round_to_nearest_resolution_acceptable_by_vae(
        downscaled_height, downscaled_width
    )
)
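# With the defaults above and LTX's spatial compression ratio of 32,
# 512x704 scales to 341x469 and rounds down to 320x448.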


def run_pipe(warmup: bool = False):
    # Generate latents at the downscaled resolution. Warmup runs use only
    # 4 steps to prime compilation and caches cheaply.
    latents = pipe(
        conditions=None,
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=downscaled_width,
        height=downscaled_height,
        num_frames=num_frames,
        num_inference_steps=4 if warmup else 30,
        generator=torch.Generator("cpu").manual_seed(0),
        output_type="latent",
    ).frames

    # Part 2. Upscale the generated latents with the latent upsampler.
    # The available latent upsampler upscales height/width by 2x.
    upscaled_height, upscaled_width = (
        downscaled_height * 2,
        downscaled_width * 2,
    )
    upscaled_latents = pipe_upsample(
        latents=latents, output_type="latent"
    ).frames

    # Warmup runs stop here; Part 3 only matters for the timed run.
    if warmup:
        return None

    # Part 3. Denoise the upscaled latents with a few steps to improve
    # texture (optional, but recommended).
    video = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=upscaled_width,
        height=upscaled_height,
        num_frames=num_frames,
        denoise_strength=0.4,  # effectively 4 of the 10 inference steps
        num_inference_steps=10,
        latents=upscaled_latents,
        decode_timestep=0.05,
        image_cond_noise_scale=0.025,
        generator=torch.Generator("cpu").manual_seed(0),
        output_type="pil",
    ).frames[0]
    return video


# Warmup: trigger compilation, cache allocation, and any lazy distributed
# setup before timing the real run.
_ = run_pipe(warmup=True)

start = time.time()
video = run_pipe()
end = time.time()
# Collect cache-dit's cache-hit statistics for this pipeline.
stats = cache_dit.summary(pipe)

if rank == 0:
    # Part 4. Downscale the video to the expected resolution.
    video = [
        frame.resize((expected_width, expected_height)) for frame in video
    ]

    time_cost = end - start
    save_path = f"ltx-video.{strify(args, stats)}.mp4"
    print(f"Time cost: {time_cost:.2f}s")
    print(f"Saving video to {save_path}")
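    # 49 frames at 8 fps yields roughly a 6-second clip.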
    export_to_video(video, save_path, fps=8)

maybe_destroy_distributed()
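
# A typical invocation (flag names assumed from the local utils' CLI;
# adjust to match `get_args`):
#   torchrun --nproc_per_node=2 <this_script>.py --attn flash --cache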