
Commit d4c1684

add npu optimize and apply wan vae dp

1 parent dd498c5
File tree

11 files changed: +312 -4 lines changed

examples/parallelism/run.sh

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
export HCCL_OP_EXPANSION_MODE="AIV"
export TASK_QUEUE_ENABLE=2
export CPU_AFFINITY_CONF=2

export LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libjemalloc.so.2:$LD_PRELOAD


FLUX_DIR=/home/weights/FLUX.1-dev/ torchrun --nproc_per_node=1 run_flux_cp_npu.py --attn "_native_npu" --height 1024 --width 1024

# WAN_2_2_DIR=/home/weights/Wan2.1-T2V-14B-Diffusers/ torchrun --nproc_per_node=8 run_wan_cp_npu.py --attn "_native_npu" --height 1024 --width 1024 --steps 10 --parallel ulysses --vae-dp
examples/parallelism/run_flux_cp_npu.py

Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
import os
import sys

sys.path.append("..")

import time
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

from diffusers import (
    FluxPipeline,
    FluxTransformer2DModel,
    PipelineQuantizationConfig,
)
from utils import (
    get_args,
    strify,
    cachify,
    maybe_init_distributed,
    maybe_destroy_distributed,
)
import cache_dit
from cache_dit.npu_optim import npu_optimize


npu_optimize([
    "npu_fast_gelu",
    "npu_rms_norm",
    "npu_layer_norm_eval",
    "npu_rotary_mul",
])

args = get_args()
print(args)

rank, device = maybe_init_distributed(args)

pipe: FluxPipeline = FluxPipeline.from_pretrained(
    os.environ.get(
        "FLUX_DIR",
        "black-forest-labs/FLUX.1-dev",
    ),
    torch_dtype=torch.bfloat16,
).to("cuda")  # transfer_to_npu redirects CUDA calls to the NPU

if args.cache or args.parallel_type is not None:
    cachify(args, pipe)

assert isinstance(pipe.transformer, FluxTransformer2DModel)

pipe.set_progress_bar_config(disable=rank != 0)


def run_pipe(pipe: FluxPipeline):
    image = pipe(
        "A cat holding a sign that says hello world",
        height=1024 if args.height is None else args.height,
        width=1024 if args.width is None else args.width,
        num_inference_steps=28 if args.steps is None else args.steps,
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]
    return image


if args.compile:
    cache_dit.set_compile_configs()
    pipe.transformer = torch.compile(pipe.transformer)

# warmup
_ = run_pipe(pipe)

start = time.time()
image = run_pipe(pipe)
end = time.time()

if rank == 0:
    cache_dit.summary(pipe)

    time_cost = end - start
    save_path = f"flux.{strify(args, pipe)}.png"
    print(f"Time cost: {time_cost:.2f}s")
    print(f"Saving image to {save_path}")
    image.save(save_path)

maybe_destroy_distributed()

examples/parallelism/run_wan_cp_npu.py

Lines changed: 14 additions & 4 deletions

@@ -20,6 +20,7 @@
 )

 import cache_dit
+from cache_dit.npu_optim import npu_optimize


 def run_pipe(args, pipe, warmup: bool = False):
@@ -73,12 +74,15 @@ def main():
     else:
         pipe.to(device)

+    if args.vae_dp:
+        pipe.vae.enable_dp(world_size=8, hw_splits=(2, 4))  # , overlap_ratio=0.01, overlap_pixels=64)
+
     if args.vae_tiling:
         pipe.vae.enable_tiling(
-            tile_sample_min_height=int(args.height / 2 * 3),
-            tile_sample_min_width=int(args.width / 2 * 3),
-            tile_sample_stride_height=int(args.height / 2),
-            tile_sample_stride_width=int(args.width / 2),
+            # tile_sample_min_height=int(args.height / 2 * 3),
+            # tile_sample_min_width=int(args.width / 2 * 3),
+            # tile_sample_stride_height=int(args.height / 2),
+            # tile_sample_stride_width=int(args.width / 2),
         )

     assert isinstance(pipe.transformer, WanTransformer3DModel)
@@ -105,4 +109,10 @@ def main():


 if __name__ == "__main__":
+    npu_optimize([
+        "npu_fast_gelu",
+        "npu_rms_norm",
+        "npu_layer_norm_eval",
+        "npu_rotary_mul",
+    ])
     main()
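The new --vae-dp flag splits VAE decoding across all eight ranks: hw_splits=(2, 4) suggests the latent's spatial plane is divided into a 2x4 grid so that each rank decodes one tile, with the commented-out overlap_ratio / overlap_pixels arguments hinting at optional halo regions to hide tile seams. enable_dp itself is not part of this diff; the snippet below is only a hypothetical illustration of that tiling idea (helper name and logic invented for the example), not cache-dit's implementation.

import torch

def split_for_rank(latents: torch.Tensor, rank: int, hw_splits=(2, 4)) -> torch.Tensor:
    # Hypothetical helper, NOT the cache-dit API: picks the spatial tile that
    # rank `rank` would decode when the H/W plane is cut into a rows x cols grid.
    rows, cols = hw_splits
    r, c = divmod(rank, cols)                 # rank 0..rows*cols-1 -> (row, col)
    h = latents.shape[-2] // rows
    w = latents.shape[-1] // cols
    return latents[..., r * h : (r + 1) * h, c * w : (c + 1) * w]

# Each of the 8 ranks would decode its own tile; the decoded tiles are then gathered
# and stitched back into full frames (overlap_* would add halo pixels to hide seams).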

examples/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def get_args(
8585
)
8686
parser.add_argument("--perf", action="store_true", default=False)
8787
parser.add_argument("--vae-tiling", action="store_true", default=False)
88+
parser.add_argument("--vae-dp", action="store_true", default=False)
8889
parser.add_argument("--cpu-offload", action="store_true", default=False)
8990
return parser.parse_args() if parse else parser
9091

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
from .utils import npu_optimize
Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
from .npu_fast_gelu import replace_npu_fast_gelu
from .npu_rms_norm import replace_npu_rms_norm
from .npu_layer_norm_eval import replace_npu_layer_norm_eval
from .npu_rotary_mul import replace_npu_rotary_mul


NPU_OPTIM_MAP = {
    "npu_fast_gelu": replace_npu_fast_gelu,
    "npu_rms_norm": replace_npu_rms_norm,
    "npu_layer_norm_eval": replace_npu_layer_norm_eval,
    "npu_rotary_mul": replace_npu_rotary_mul,
}
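The npu_optimize entry point re-exported by the package __init__ lives in a utils module that is not shown in this commit view. Given the map above and the npu_optimize([...]) calls in the example scripts, a plausible reading is a simple lookup-and-dispatch loop; the body below is a sketch under that assumption, not the actual utils code.

from typing import List

def npu_optimize(optimizations: List[str]) -> None:
    # Sketch only: assumes NPU_OPTIM_MAP (defined above) is importable here.
    for name in optimizations:
        if name not in NPU_OPTIM_MAP:
            raise ValueError(
                f"Unknown NPU optimization '{name}', expected one of {list(NPU_OPTIM_MAP)}"
            )
        # Each entry monkey-patches one op with its fused torch_npu kernel and logs the swap.
        NPU_OPTIM_MAP[name]()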
Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
import torch
import torch_npu
import torch.nn as nn

from diffusers.models.activations import GELU as GeluDiffuser

from ..utils import log_replace_info


class NpuFastGelu(nn.GELU):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch_npu.npu_fast_gelu(input)


class NpuFastGeluDiffuser(GeluDiffuser):
    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        return torch_npu.npu_fast_gelu(gate)


def replace_func():
    from diffusers.models import activations
    activations.GELU = NpuFastGeluDiffuser

    from torch import nn
    nn.GELU = NpuFastGelu


def replace_npu_fast_gelu():
    replace_func()
    log_replace_info("nn.GELU and GELU of Diffusers", "npu_fast_gelu")
Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@

import torch
import torch_npu

import torch.nn as nn

from ..utils import log_replace_info


class NpuLayerNorm(nn.LayerNorm):
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return torch_npu.npu_layer_norm_eval(
            inputs,
            normalized_shape=self.normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )


def replace_func():
    # from torch import nn
    # nn.LayerNorm = NpuLayerNorm

    from diffusers.models import normalization
    normalization.FP32LayerNorm = NpuLayerNorm


def replace_npu_layer_norm_eval():
    replace_func()
    log_replace_info("FP32LayerNorm of Diffusers", "npu_layer_norm_eval")
Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
import torch
import torch_npu
import torch.nn as nn

from ..utils import log_replace_info


class NpuRMSNorm(nn.RMSNorm):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch_npu.npu_rms_norm(x, self.weight, self.eps)[0]


def replace_func():
    from torch import nn
    nn.RMSNorm = NpuRMSNorm


def replace_npu_rms_norm():
    replace_func()
    log_replace_info("nn.RMSNorm", "npu_rms_norm")
Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@
from typing import Tuple, Union

import torch
import torch_npu

from ..utils import log_replace_info


def npu_apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
    sequence_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
    tensors contain rotary embeddings and are returned as real tensors.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        if sequence_dim == 2:
            cos = cos[None, None, :, :]
            sin = sin[None, None, :, :]
        elif sequence_dim == 1:
            cos = cos[None, :, None, :]
            sin = sin[None, :, None, :]
        else:
            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")

        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            rotary_mode = "interleave"
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
            rotary_mode = "half"
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
        out = torch_npu.npu_rotary_mul(x, cos, sin, rotary_mode=rotary_mode).to(x.dtype)

        return out
    else:
        # used for lumina
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)


def replace_func():
    from diffusers.models import embeddings
    from diffusers.models.transformers import transformer_flux

    embeddings.apply_rotary_emb = npu_apply_rotary_emb
    transformer_flux.apply_rotary_emb = npu_apply_rotary_emb


def replace_npu_rotary_mul():
    replace_func()
    log_replace_info("apply_rotary_emb", "npu_rotary_mul")
