
Commit 198ea71
Merge pull request Eco-Sphere#1 from TmacAaron/dev
Add Support For Flux
2 parents: 4a3d256 + d4c1684


68 files changed, +991 -265 lines

bench/bench.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def init_flux_pipe(args: argparse.Namespace) -> FluxPipeline:
         DBCacheConfig,
         TaylorSeerCalibratorConfig,
     )
-    from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
+    from cache_dit.caching.patch_functors import FluxPatchFunctor
 
     cache_dit.enable_cache(
         # BlockAdapter & forward pattern
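Note: the recurring source change in this commit is a package rename: imports move from `cache_dit.cache_factory` to `cache_dit.caching`. A minimal before/after sketch of the affected import, taken from the hunk above (everything else is unchanged):

```python
# Before this commit:
# from cache_dit.cache_factory.patch_functors import FluxPatchFunctor

# After this commit (new package name):
from cache_dit.caching.patch_functors import FluxPatchFunctor
```

The same one-line substitution recurs across the benchmarks, docs, and examples below.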

bench/perf.py

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ def main():
         DBCacheConfig,
         TaylorSeerCalibratorConfig,
     )
-    from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
+    from cache_dit.caching.patch_functors import FluxPatchFunctor
 
     if args.cache_config is None:
 
docs/DBCache.md

Lines changed: 3 additions & 3 deletions
@@ -16,10 +16,10 @@
 |Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
 |:---:|:---:|:---:|:---:|:---:|:---:|
 |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
-|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=140px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=140px>|
 |**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.12)**|**F8B16 (0.20)**|**F8B20 (0.20)**|
 |27.85s|6.04s|5.88s|5.77s|6.01s|6.20s|
-|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_NONE_R0.08.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F1B0_R0.08.png width=105px> |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B8_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B12_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B16_R0.2.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B20_R0.2.png width=105px>|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_NONE_R0.08.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F1B0_R0.08.png width=140px> |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B8_R0.12.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B12_R0.12.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B16_R0.2.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B20_R0.2.png width=140px>|
 
 <div align="center">
 <p align="center">
@@ -79,7 +79,7 @@ cache_dit.enable_cache(
 |Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
 |:---:|:---:|:---:|:---:|:---:|:---:|
 |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
-|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=140px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=140px>|
 
 ## ⚡️Hybrid Cache CFG
 

docs/User_Guide.md

Lines changed: 16 additions & 4 deletions
@@ -314,13 +314,13 @@ For any PATTERN not in {0...5}, we introduced the simple abstract concept of **P
 
 ![](https://github.com/vipshop/cache-dit/raw/main/assets/patch-functor.png)
 
-Some Patch functors have already been provided in cache-dit: [📚HiDreamPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_hidream.py), [📚ChromaPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_chroma.py), etc. After implementing Patch Functor, users need to set the `patch_functor` property of **BlockAdapter**.
+Some Patch functors have already been provided in cache-dit: [📚HiDreamPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/patch_functors/functor_hidream.py), [📚ChromaPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/patch_functors/functor_chroma.py), etc. After implementing Patch Functor, users need to set the `patch_functor` property of **BlockAdapter**.
 
 ```python
 @BlockAdapterRegistry.register("HiDream")
 def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import HiDreamImageTransformer2DModel
-    from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor
+    from cache_dit.caching.patch_functors import HiDreamPatchFunctor
 
     assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
     return BlockAdapter(
@@ -431,6 +431,8 @@ You can set `details` param as `True` to show more details of cache stats. (mark
 - **Fn**: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
 - **Bn**: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
 
+![](https://github.com/vipshop/cache-dit/raw/main/assets/dbcache-fnbn-v1.png)
+
 ```python
 import cache_dit
 from diffusers import FluxPipeline
@@ -469,6 +471,17 @@ cache_dit.enable_cache(
 |:---:|:---:|:---:|:---:|:---:|:---:|
 |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
 |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=140px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=140px>|
+|**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.12)**|**F8B16 (0.20)**|**F8B20 (0.20)**|
+|27.85s|6.04s|5.88s|5.77s|6.01s|6.20s|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_NONE_R0.08.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F1B0_R0.08.png width=140px> |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B8_R0.12.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B12_R0.12.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B16_R0.2.png width=140px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B20_R0.2.png width=140px>|
+
+<div align="center">
+<p align="center">
+DBCache, <b> L20x4 </b>, Steps: 20, case to show the texture recovery ability of DBCache
+</p>
+</div>
+
+
 These case studies demonstrate that even with relatively high thresholds (such as 0.12, 0.15, 0.2, etc.) under the DBCache **F12B12** or **F8B16** configuration, the detailed texture of the kitten's fur, colored cloth, and the clarity of text can still be preserved. This suggests that users can leverage DBCache to effectively balance performance and precision in their workflows!
 
 ## ⚡️DBPrune: Dynamic Block Prune
 
@@ -780,7 +793,6 @@ This function seamlessly integrates with both standard diffusion pipelines and c
 - **pipe_or_adapter**(`DiffusionPipeline`, `BlockAdapter` or `Transformer`, *required*):
     The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
     For example: `cache_dit.enable_cache(FluxPipeline(...))`.
-    Please check https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md for the usage of BlockAdapter.
 
 - **cache_config**(`DBCacheConfig`, *required*, defaults to DBCacheConfig()):
     Basic DBCache config for cache context, defaults to DBCacheConfig(). The configurable parameters are listed below:
@@ -845,4 +857,4 @@ This function seamlessly integrates with both standard diffusion pipelines and c
     it can include `cp_plan` and `attention_backend` arguments for `Context Parallelism`.
 
 - **kwargs** (`dict`, *optional*, defaults to {}):
-    Other cache context keyword arguments. Please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py for more details.
+    Other cache context keyword arguments. Please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/cache_contexts/cache_context.py for more details.
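For readers following the User_Guide changes above: a minimal sketch of enabling DBCache with the table's F8B8 configuration at threshold 0.12. The parameter names here (`Fn_compute_blocks`, `Bn_compute_blocks`, `residual_diff_threshold`) and the top-level `DBCacheConfig` export are assumptions based on the cache-dit README and may differ slightly in this revision:

```python
import torch
import cache_dit
from cache_dit import DBCacheConfig  # assumed top-level export
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

cache_dit.enable_cache(
    pipe,
    cache_config=DBCacheConfig(
        Fn_compute_blocks=8,           # Fn: always compute the first 8 blocks
        Bn_compute_blocks=8,           # Bn: always compute the last 8 blocks
        residual_diff_threshold=0.12,  # reuse cached residuals when the L1 diff is small
    ),
)

image = pipe("A cat holding a sign that says hello world").images[0]
```

Larger Fn/Bn values and lower thresholds trade speed for fidelity, as the case-study tables above illustrate.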

docs/community_optimization.md

Lines changed: 2 additions & 2 deletions
@@ -122,13 +122,13 @@ For any pattern not included in CacheDiT, use the Patch Functor to convert the p
 
 ![](https://github.com/vipshop/cache-dit/raw/main/assets/patch-functor.png)
 
-Some Patch Functors are already provided in CacheDiT, [HiDreamPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_hidream.py), [ChromaPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_chroma.py), etc.
+Some Patch Functors are already provided in CacheDiT, [HiDreamPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/patch_functors/functor_hidream.py), [ChromaPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/patch_functors/functor_chroma.py), etc.
 
 ```python
 @BlockAdapterRegistry.register("HiDream")
 def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import HiDreamImageTransformer2DModel
-    from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor
+    from cache_dit.caching.patch_functors import HiDreamPatchFunctor
 
     assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
     return BlockAdapter(
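The snippet above is truncated at `return BlockAdapter(`. As a complement, a hedged sketch of wiring a provided functor into a user-defined adapter; the `BlockAdapter` field names and the `ForwardPattern` choice used here are illustrative assumptions modeled on the User_Guide, not the actual registry entry:

```python
from cache_dit import BlockAdapter, ForwardPattern  # assumed exports
from cache_dit.caching.patch_functors import ChromaPatchFunctor

def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
    # Illustrative only: the patch functor rewrites the transformer's
    # forward into one of the supported patterns before caching applies.
    return BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=pipe.transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_1,  # assumed pattern choice
        patch_functor=ChromaPatchFunctor(),
        **kwargs,
    )
```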

examples/parallelism/run.sh

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+export HCCL_OP_EXPANSION_MODE="AIV"
+export TASK_QUEUE_ENABLE=2
+export CPU_AFFINITY_CONF=2
+
+export LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libjemalloc.so.2:$LD_PRELOAD
+
+
+FLUX_DIR=/home/weights/FLUX.1-dev/ torchrun --nproc_per_node=1 run_flux_cp_npu.py --attn "_native_npu" --height 1024 --width 1024
+
+# WAN_2_2_DIR=/home/weights/Wan2.1-T2V-14B-Diffusers/ torchrun --nproc_per_node=8 run_wan_cp_npu.py --attn "_native_npu" --height 1024 --width 1024 --steps 10 --parallel ulysses --vae-dp
Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+import os
+import sys
+
+sys.path.append("..")
+
+import time
+import torch
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+
+from diffusers import (
+    FluxPipeline,
+    FluxTransformer2DModel,
+    PipelineQuantizationConfig,
+)
+from utils import (
+    get_args,
+    strify,
+    cachify,
+    maybe_init_distributed,
+    maybe_destroy_distributed,
+)
+import cache_dit
+from cache_dit.npu_optim import npu_optimize
+
+
+npu_optimize([
+    "npu_fast_gelu",
+    "npu_rms_norm",
+    "npu_layer_norm_eval",
+    "npu_rotary_mul",
+])
+
+args = get_args()
+print(args)
+
+rank, device = maybe_init_distributed(args)
+
+pipe: FluxPipeline = FluxPipeline.from_pretrained(
+    os.environ.get(
+        "FLUX_DIR",
+        "black-forest-labs/FLUX.1-dev",
+    ),
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+if args.cache or args.parallel_type is not None:
+    cachify(args, pipe)
+
+assert isinstance(pipe.transformer, FluxTransformer2DModel)
+
+pipe.set_progress_bar_config(disable=rank != 0)
+
+
+def run_pipe(pipe: FluxPipeline):
+    image = pipe(
+        "A cat holding a sign that says hello world",
+        height=1024 if args.height is None else args.height,
+        width=1024 if args.width is None else args.width,
+        num_inference_steps=28 if args.steps is None else args.steps,
+        generator=torch.Generator("cpu").manual_seed(0),
+    ).images[0]
+    return image
+
+
+if args.compile:
+    cache_dit.set_compile_configs()
+    pipe.transformer = torch.compile(pipe.transformer)
+
+# warmup
+_ = run_pipe(pipe)
+
+start = time.time()
+image = run_pipe(pipe)
+end = time.time()
+
+if rank == 0:
+    cache_dit.summary(pipe)
+
+    time_cost = end - start
+    save_path = f"flux.{strify(args, pipe)}.png"
+    print(f"Time cost: {time_cost:.2f}s")
+    print(f"Saving image to {save_path}")
+    image.save(save_path)
+
+maybe_destroy_distributed()
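One non-obvious point in the new Flux NPU example above (its filename is collapsed in this view; it is presumably the run_flux_cp_npu.py that run.sh invokes): the pipeline is moved with `.to("cuda")` even though it targets Ascend. Importing `torch_npu.contrib.transfer_to_npu` monkey-patches torch's CUDA entry points to their NPU equivalents, so CUDA-targeted code runs unmodified. A minimal sketch, assuming an Ascend environment with torch_npu installed:

```python
import torch
import torch_npu  # registers the NPU backend
from torch_npu.contrib import transfer_to_npu  # remaps "cuda" calls to NPU

x = torch.randn(4, 4).to("cuda")  # transparently placed on the NPU
print(x.device)                   # reports an npu device, not cuda
```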
Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
+import os
+import sys
+
+sys.path.append("..")
+
+import time
+import torch
+from diffusers import (
+    LTXConditionPipeline,
+    LTXLatentUpsamplePipeline,
+    AutoencoderKLLTXVideo,
+)
+from diffusers.quantizers import PipelineQuantizationConfig
+from diffusers.utils import export_to_video
+from utils import (
+    cachify,
+    get_args,
+    maybe_destroy_distributed,
+    maybe_init_distributed,
+    strify,
+)
+import cache_dit
+
+# NOTE: Please use `--attn flash` for LTXVideo with context parallelism,
+# otherwise, it may raise attention mask not supported error.
+
+args = get_args()
+print(args)
+
+rank, device = maybe_init_distributed(args)
+
+pipe = LTXConditionPipeline.from_pretrained(
+    os.environ.get("LTX_VIDEO_DIR", "Lightricks/LTX-Video-0.9.7-dev"),
+    torch_dtype=torch.bfloat16,
+    quantization_config=PipelineQuantizationConfig(
+        quant_backend="bitsandbytes_4bit",
+        quant_kwargs={
+            "load_in_4bit": True,
+            "bnb_4bit_quant_type": "nf4",
+            "bnb_4bit_compute_dtype": torch.bfloat16,
+        },
+        components_to_quantize=["text_encoder", "transformer"],
+    ),
+)
+
+pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
+    os.environ.get(
+        "LTX_UPSCALER_DIR", "Lightricks/ltxv-spatial-upscaler-0.9.7"
+    ),
+    vae=pipe.vae,
+    torch_dtype=torch.bfloat16,
+)
+
+pipe.to(device)
+pipe_upsample.to(device)
+assert isinstance(pipe.vae, AutoencoderKLLTXVideo)
+assert isinstance(pipe_upsample.vae, AutoencoderKLLTXVideo)
+
+pipe.set_progress_bar_config(disable=rank != 0)
+pipe_upsample.set_progress_bar_config(disable=rank != 0)
+
+if args.cache or args.parallel_type is not None:
+    cachify(args, pipe)
+
+
+def round_to_nearest_resolution_acceptable_by_vae(height, width):
+    height = height - (height % pipe.vae_spatial_compression_ratio)
+    width = width - (width % pipe.vae_spatial_compression_ratio)
+    return height, width
+
+
+prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
+negative_prompt = (
+    "worst quality, inconsistent motion, blurry, jittery, distorted"
+)
+expected_height, expected_width = 512, 704
+downscale_factor = 2 / 3
+num_frames = 49
+
+# Part 1. Generate video at smaller resolution
+downscaled_height, downscaled_width = int(
+    expected_height * downscale_factor
+), int(expected_width * downscale_factor)
+downscaled_height, downscaled_width = (
+    round_to_nearest_resolution_acceptable_by_vae(
+        downscaled_height, downscaled_width
+    )
+)
+
+
+def run_pipe(warmup: bool = False):
+
+    latents = pipe(
+        conditions=None,
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        width=downscaled_width,
+        height=downscaled_height,
+        num_frames=num_frames,
+        num_inference_steps=30 if not warmup else 4,
+        generator=torch.Generator("cpu").manual_seed(0),
+        output_type="latent",
+    ).frames
+
+    # Part 2. Upscale generated video using latent upsampler with fewer inference steps
+    # The available latent upsampler upscales the height/width by 2x
+    upscaled_height, upscaled_width = (
+        downscaled_height * 2,
+        downscaled_width * 2,
+    )
+    upscaled_latents = pipe_upsample(
+        latents=latents, output_type="latent"
+    ).frames
+
+    if warmup:
+        return None
+
+    # Part 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)
+    video = pipe(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        width=upscaled_width,
+        height=upscaled_height,
+        num_frames=num_frames,
+        denoise_strength=0.4,  # Effectively, 4 inference steps out of 10
+        num_inference_steps=10,
+        latents=upscaled_latents,
+        decode_timestep=0.05,
+        image_cond_noise_scale=0.025,
+        generator=torch.Generator("cpu").manual_seed(0),
+        output_type="pil",
+    ).frames[0]
+    return video
+
+
+# warmup
+_ = run_pipe(warmup=True)
+
+start = time.time()
+video = run_pipe()
+end = time.time()
+stats = cache_dit.summary(pipe)
+
+if rank == 0:
+    # Part 4. Downscale the video to the expected resolution
+    video = [frame.resize((expected_width, expected_height)) for frame in video]
+
+    time_cost = end - start
+    save_path = f"ltx-video.{strify(args, stats)}.mp4"
+    print(f"Time cost: {time_cost:.2f}s")
+    print(f"Saving video to {save_path}")
+    export_to_video(video, save_path, fps=8)
+
+maybe_destroy_distributed()
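The resolution handling in the new LTX-Video example follows three stages: generate at two-thirds resolution rounded down to the VAE grid, upsample the latents 2x, then resize the decoded frames back to the target. A worked sketch of that arithmetic, assuming a spatial compression ratio of 32 (an assumption; the script reads `pipe.vae_spatial_compression_ratio` rather than hard-coding it):

```python
ratio = 32  # assumed LTX-Video VAE spatial compression ratio
expected_height, expected_width = 512, 704
downscale_factor = 2 / 3

# Stage 1: downscale and align to the VAE grid.
h = int(expected_height * downscale_factor)  # 341
w = int(expected_width * downscale_factor)   # 469
h, w = h - h % ratio, w - w % ratio          # 320, 448

# Stage 2: the latent upsampler doubles the spatial size.
up_h, up_w = h * 2, w * 2                    # 640, 896

# Stage 3: frames are finally resized to (704, 512) via PIL's Image.resize.
print(h, w, up_h, up_w)
```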

examples/parallelism/run_qwen_image_cp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from diffusers import (
99
QwenImagePipeline,
1010
QwenImageTransformer2DModel,
11+
AutoencoderKLQwenImage,
1112
)
1213

1314
from utils import (
@@ -55,8 +56,9 @@
5556
else:
5657
pipe.to(device)
5758

58-
# assert isinstance(pipe.vae, AutoencoderKLQwenImage)
59-
# pipe.vae.enable_tiling()
59+
if GiB() <= 48 and not enable_quatization:
60+
assert isinstance(pipe.vae, AutoencoderKLQwenImage)
61+
pipe.vae.enable_tiling()
6062

6163
# Apply cache and context parallelism here
6264
if args.cache or args.parallel_type is not None:
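The new gating in run_qwen_image_cp.py enables VAE tiling only on smaller devices. `GiB()` and `enable_quatization` are defined elsewhere in the example and are not part of this diff; a plausible sketch of such a memory helper, offered as an assumption:

```python
import torch

def GiB() -> float:
    """Hypothetical helper: total memory of the current accelerator, in GiB."""
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    return props.total_memory / (1 << 30)
```

On a device at or below 48 GiB without quantization, tiled VAE decoding keeps peak memory bounded at some cost in speed.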
