
Commit 3e65b06

sdxl sd vae support attn impl (#52)
* sdxl sd vae support attn impl
* ruff format
1 parent 1280c9f commit 3e65b06

9 files changed: +126 -55 lines changed


diffsynth_engine/models/basic/transformer_helper.py

Lines changed: 21 additions & 9 deletions
@@ -65,17 +65,29 @@ def forward(self, ids):
 
 
 class RMSNorm(nn.Module):
-    def __init__(self, dim, eps, device: str, dtype: torch.dtype):
+    def __init__(
+        self,
+        dim,
+        eps=1e-5,
+        elementwise_affine=True,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+    ):
         super().__init__()
-        self.weight = nn.Parameter(torch.ones((dim,), device=device, dtype=dtype))
         self.eps = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-        hidden_states = hidden_states.to(input_dtype) * self.weight
-        return hidden_states
+        self.dim = dim
+        self.elementwise_affine = elementwise_affine
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
+
+    def norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        norm_result = self.norm(x.float()).to(x.dtype)
+        if self.elementwise_affine:
+            return norm_result * self.weight
+        return norm_result
 
 
 class NewGELUActivation(nn.Module):

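For context, a minimal usage sketch of the updated RMSNorm (not part of the commit): the new elementwise_affine flag is the behavioral change; CPU/float32 are chosen here purely for illustration, while the diff's own defaults are cuda:0/bfloat16.

import torch
from diffsynth_engine.models.basic.transformer_helper import RMSNorm

# With elementwise_affine=True (the default) a learnable weight is registered;
# with False only the RMS rescaling is applied and no parameter is created.
x = torch.randn(2, 16, 64)
affine = RMSNorm(64, eps=1e-5, elementwise_affine=True, device="cpu", dtype=torch.float32)
plain = RMSNorm(64, eps=1e-5, elementwise_affine=False, device="cpu", dtype=torch.float32)
print(affine(x).shape, plain(x).shape)  # torch.Size([2, 16, 64]) for both
print(hasattr(plain, "weight"))         # False: no weight when affine is off
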
diffsynth_engine/models/components/vae.py

Lines changed: 15 additions & 2 deletions
@@ -67,6 +67,7 @@ def __init__(
         num_layers=1,
         norm_num_groups=32,
         eps=1e-5,
+        attn_impl: str = "auto",
         device: str = "cuda:0",
         dtype: torch.dtype = torch.float32,
     ):
@@ -86,6 +87,7 @@ def __init__(
             bias_q=True,
             bias_kv=True,
             bias_out=True,
+            attn_impl=attn_impl,
             device=device,
             dtype=dtype,
         )
@@ -119,6 +121,7 @@ def __init__(
         scaling_factor: float = 0.18215,
         shift_factor: float = 0,
         use_post_quant_conv: bool = True,
+        attn_impl: str = "auto",
         device: str = "cuda:0",
         dtype: torch.dtype = torch.float32,
     ):
@@ -137,7 +140,7 @@ def __init__(
             [
                 # UNetMidBlock2D
                 ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
-                VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, device=device, dtype=dtype),
+                VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, device=device, dtype=dtype, attn_impl=attn_impl),
                 ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
                 # UpDecoderBlock2D
                 ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
@@ -202,6 +205,7 @@ def from_state_dict(
         scaling_factor: float = 0.18215,
         shift_factor: float = 0,
         use_post_quant_conv: bool = True,
+        attn_impl: str = "auto",
     ):
         with no_init_weights():
             model = torch.nn.utils.skip_init(
@@ -210,6 +214,7 @@ def from_state_dict(
                 scaling_factor=scaling_factor,
                 shift_factor=shift_factor,
                 use_post_quant_conv=use_post_quant_conv,
+                attn_impl=attn_impl,
                 device=device,
                 dtype=dtype,
             )
@@ -230,6 +235,7 @@ def __init__(
         scaling_factor: float = 0.18215,
         shift_factor: float = 0,
         use_quant_conv: bool = True,
+        attn_impl: str = "auto",
         device: str = "cuda:0",
         dtype: torch.dtype = torch.float32,
     ):
@@ -263,7 +269,7 @@ def __init__(
                 ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
                 # UNetMidBlock2D
                 ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
-                VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, device=device, dtype=dtype),
+                VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, device=device, dtype=dtype, attn_impl=attn_impl),
                 ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
             ]
         )
@@ -309,6 +315,7 @@ def from_state_dict(
         scaling_factor: float = 0.18215,
         shift_factor: float = 0,
         use_quant_conv: bool = True,
+        attn_impl: str = "auto",
     ):
         with no_init_weights():
             model = torch.nn.utils.skip_init(
@@ -317,6 +324,7 @@ def from_state_dict(
                 scaling_factor=scaling_factor,
                 shift_factor=shift_factor,
                 use_quant_conv=use_quant_conv,
+                attn_impl=attn_impl,
                 device=device,
                 dtype=dtype,
             )
@@ -338,6 +346,7 @@ def __init__(
         shift_factor: float = 0,
         use_quant_conv: bool = True,
         use_post_quant_conv: bool = True,
+        attn_impl: str = "auto",
         device: str = "cuda:0",
         dtype: torch.dtype = torch.float32,
     ):
@@ -347,6 +356,7 @@ def __init__(
             scaling_factor=scaling_factor,
             shift_factor=shift_factor,
             use_quant_conv=use_quant_conv,
+            attn_impl=attn_impl,
             device=device,
             dtype=dtype,
         )
@@ -355,6 +365,7 @@ def __init__(
             scaling_factor=scaling_factor,
             shift_factor=shift_factor,
             use_post_quant_conv=use_post_quant_conv,
+            attn_impl=attn_impl,
             device=device,
             dtype=dtype,
         )
@@ -376,6 +387,7 @@ def from_state_dict(
         shift_factor: float = 0,
         use_quant_conv: bool = True,
         use_post_quant_conv: bool = True,
+        attn_impl: str = "auto",
     ):
         with no_init_weights():
             model = torch.nn.utils.skip_init(
@@ -385,6 +397,7 @@ def from_state_dict(
                 shift_factor=shift_factor,
                 use_quant_conv=use_quant_conv,
                 use_post_quant_conv=use_post_quant_conv,
+                attn_impl=attn_impl,
                 device=device,
                 dtype=dtype,
             )

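As a rough sketch of what the plumbing above enables (not part of the commit), the mid-block attention layer can now be built with an explicit attention implementation. The positional arguments mirror the constructor call in the diff, "auto" is the only value this diff confirms, and CPU is used here only for illustration.

import torch
from diffsynth_engine.models.components.vae import VAEAttentionBlock

# Same call as in the encoder/decoder mid-block above, with attn_impl appended.
block = VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, device="cpu", dtype=torch.float32, attn_impl="auto")
print(sum(p.numel() for p in block.parameters()))  # parameter count of the attention block
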
diffsynth_engine/models/flux/flux_dit.py

Lines changed: 11 additions & 4 deletions
@@ -227,7 +227,7 @@ def __init__(
             nn.Linear(dim * 4, dim, device=device, dtype=dtype),
         )
 
-    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
+    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, image_emb):
         norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
         norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
 
@@ -293,7 +293,7 @@ def process_attention(self, hidden_states, image_rotary_emb):
         hidden_states = hidden_states.to(q.dtype)
         return hidden_states
 
-    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
+    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, image_emb):
         residual = hidden_states_a
         norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
         hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
@@ -386,6 +386,7 @@ def forward(
         timestep,
         prompt_emb,
         pooled_prompt_emb,
+        image_emb,
         guidance,
         text_ids,
         image_ids=None,
@@ -421,10 +422,13 @@ def forward(
                     prompt_emb,
                     conditioning,
                     image_rotary_emb,
+                    image_emb,
                     use_reentrant=False,
                 )
             else:
-                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
+                hidden_states, prompt_emb = block(
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, image_emb
+                )
             if controlnet_double_block_output is not None:
                 interval_control = len(self.blocks) / len(controlnet_double_block_output)
                 interval_control = int(np.ceil(interval_control))
@@ -439,10 +443,13 @@ def forward(
                     prompt_emb,
                     conditioning,
                     image_rotary_emb,
+                    image_emb,
                     use_reentrant=False,
                 )
             else:
-                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
+                hidden_states, prompt_emb = block(
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, image_emb
+                )
             if controlnet_single_block_output is not None:
                 interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
                 interval_control = int(np.ceil(interval_control))

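The diff only threads image_emb through the block calls; how the blocks consume it is not shown here. A hedged sketch of the call pattern itself (run_block and use_gradient_checkpointing are illustrative names, not part of the repo): the same extra argument must be passed on both the checkpointed and the plain path, otherwise the two branches diverge.

import torch
import torch.utils.checkpoint

def run_block(block, hidden_states, prompt_emb, conditioning, image_rotary_emb, image_emb,
              use_gradient_checkpointing=False):
    # Checkpointed path: positional args are forwarded to the block, image_emb included.
    if use_gradient_checkpointing and torch.is_grad_enabled():
        return torch.utils.checkpoint.checkpoint(
            block,
            hidden_states,
            prompt_emb,
            conditioning,
            image_rotary_emb,
            image_emb,
            use_reentrant=False,
        )
    # Plain path: identical argument list, matching the updated forward signatures.
    return block(hidden_states, prompt_emb, conditioning, image_rotary_emb, image_emb)
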
diffsynth_engine/models/sd/sd_vae.py

Lines changed: 18 additions & 7 deletions
@@ -6,33 +6,44 @@
 
 
 class SDVAEEncoder(VAEEncoder):
-    def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
+    def __init__(self, attn_impl: str = "auto", device: str = "cuda:0", dtype: torch.dtype = torch.float32):
         super().__init__(
-            latent_channels=4, scaling_factor=0.18215, shift_factor=0, use_quant_conv=True, device=device, dtype=dtype
+            latent_channels=4,
+            scaling_factor=0.18215,
+            shift_factor=0,
+            use_quant_conv=True,
+            attn_impl=attn_impl,
+            device=device,
+            dtype=dtype,
         )
 
     @classmethod
-    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
+    def from_state_dict(
+        cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype, attn_impl: str = "auto"
+    ):
         with no_init_weights():
-            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
+            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, attn_impl=attn_impl)
         model.load_state_dict(state_dict)
         return model
 
 
 class SDVAEDecoder(VAEDecoder):
-    def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
+    def __init__(self, attn_impl: str = "auto", device: str = "cuda:0", dtype: torch.dtype = torch.float32):
         super().__init__(
             latent_channels=4,
             scaling_factor=0.18215,
             shift_factor=0,
             use_post_quant_conv=True,
+            attn_impl=attn_impl,
             device=device,
             dtype=dtype,
         )
 
     @classmethod
-    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
+    def from_state_dict(
+        cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype, attn_impl: str = "auto"
+    ):
         with no_init_weights():
-            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
+            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, attn_impl=attn_impl)
         model.load_state_dict(state_dict)
         return model

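A hedged construction sketch for the SD classes (not part of the commit): attn_impl is now the first keyword and defaults to "auto"; the repo's own defaults are cuda:0/float32, and device="cpu" here is just to keep the example hardware-neutral.

import torch
from diffsynth_engine.models.sd.sd_vae import SDVAEDecoder, SDVAEEncoder

# Both classes forward attn_impl to VAEEncoder/VAEDecoder, which in turn pass it
# to every VAEAttentionBlock in the mid-block.
encoder = SDVAEEncoder(attn_impl="auto", device="cpu", dtype=torch.float32)
decoder = SDVAEDecoder(attn_impl="auto", device="cpu", dtype=torch.float32)
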
diffsynth_engine/models/sdxl/sdxl_vae.py

Lines changed: 18 additions & 7 deletions
@@ -6,33 +6,44 @@
 
 
 class SDXLVAEEncoder(VAEEncoder):
-    def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
+    def __init__(self, attn_impl: str = "auto", device: str = "cuda:0", dtype: torch.dtype = torch.float32):
         super().__init__(
-            latent_channels=4, scaling_factor=0.13025, shift_factor=0, use_quant_conv=True, device=device, dtype=dtype
+            latent_channels=4,
+            scaling_factor=0.13025,
+            shift_factor=0,
+            use_quant_conv=True,
+            attn_impl=attn_impl,
+            device=device,
+            dtype=dtype,
         )
 
     @classmethod
-    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
+    def from_state_dict(
+        cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype, attn_impl: str = "auto"
+    ):
         with no_init_weights():
-            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
+            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, attn_impl=attn_impl)
         model.load_state_dict(state_dict)
         return model
 
 
 class SDXLVAEDecoder(VAEDecoder):
-    def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
+    def __init__(self, attn_impl: str = "auto", device: str = "cuda:0", dtype: torch.dtype = torch.float32):
         super().__init__(
             latent_channels=4,
             scaling_factor=0.13025,
             shift_factor=0,
             use_post_quant_conv=True,
+            attn_impl=attn_impl,
             device=device,
             dtype=dtype,
         )
 
     @classmethod
-    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
+    def from_state_dict(
+        cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype, attn_impl: str = "auto"
+    ):
         with no_init_weights():
-            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
+            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, attn_impl=attn_impl)
         model.load_state_dict(state_dict)
         return model

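For loading, a hedged sketch of the updated from_state_dict (not part of the commit): load_file comes from the safetensors package, the checkpoint path is a placeholder, and whether a raw checkpoint's keys load directly or first need the repo's state-dict conversion is not shown by this diff.

import torch
from safetensors.torch import load_file
from diffsynth_engine.models.sdxl.sdxl_vae import SDXLVAEDecoder

state_dict = load_file("/path/to/sdxl_vae.safetensors")  # placeholder path
decoder = SDXLVAEDecoder.from_state_dict(
    state_dict, device="cuda:0", dtype=torch.float32, attn_impl="auto"
)
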
diffsynth_engine/models/wan/wan_dit.py

Lines changed: 1 addition & 20 deletions
@@ -8,6 +8,7 @@
 
 from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
 from diffsynth_engine.models.basic.attention import attention, long_context_attention
+from diffsynth_engine.models.basic.transformer_helper import RMSNorm
 from diffsynth_engine.models.utils import no_init_weights
 from diffsynth_engine.utils.constants import (
     WAN_DIT_1_3B_T2V_CONFIG_FILE,
@@ -57,26 +58,6 @@ def rope_apply(x, freqs):
     return x_out.to(x.dtype).flatten(3)
 
 
-class RMSNorm(nn.Module):
-    def __init__(
-        self,
-        dim,
-        eps=1e-5,
-        device: str = "cuda:0",
-        dtype: torch.dtype = torch.bfloat16,
-    ):
-        super().__init__()
-        self.eps = eps
-        self.dim = dim
-        self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
-
-    def norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        return self.norm(x.float()).to(x.dtype) * self.weight
-
-
 class SelfAttention(nn.Module):
     def __init__(
         self,

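Since wan_dit.py now reuses the shared class, a quick hedged equivalence check (not part of the commit; CPU/float32 for simplicity): the shared RMSNorm with its default elementwise_affine=True computes the same thing as the class deleted above.

import torch
from diffsynth_engine.models.basic.transformer_helper import RMSNorm

norm = RMSNorm(8, eps=1e-5, device="cpu", dtype=torch.float32)
x = torch.randn(2, 3, 8)
# Formula of the deleted wan_dit RMSNorm: x * rsqrt(mean(x^2) + eps) * weight
old_style = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-5) * norm.weight
assert torch.allclose(norm(x), old_style)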