@@ -67,6 +67,7 @@ def __init__(
         num_layers=1,
         norm_num_groups=32,
         eps=1e-5,
+        attn_impl: str = "auto",
         device: str = "cuda:0",
         dtype: torch.dtype = torch.float32,
     ):
@@ -86,6 +87,7 @@ def __init__(
             bias_q=True,
             bias_kv=True,
             bias_out=True,
+            attn_impl=attn_impl,
             device=device,
             dtype=dtype,
         )
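The hunk above forwards attn_impl into the underlying attention module's constructor. The diff doesn't show how that module consumes the flag, but a common dispatch pattern looks like the sketch below; the function and argument names are illustrative, not taken from this repository.

import torch
import torch.nn.functional as F

def run_attention(q, k, v, attn_impl: str = "auto"):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    if attn_impl in ("auto", "sdpa"):
        # Let PyTorch pick the fastest available backend
        # (FlashAttention, memory-efficient, or the math fallback).
        return F.scaled_dot_product_attention(q, k, v)
    if attn_impl == "eager":
        # Plain matmul + softmax fallback.
        scale = q.shape[-1] ** -0.5
        weights = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
        return weights @ v
    raise ValueError(f"unknown attn_impl: {attn_impl!r}")

# Example: q = k = v = torch.randn(1, 8, 64, 40); run_attention(q, k, v)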
@@ -119,6 +121,7 @@ def __init__(
         scaling_factor: float = 0.18215,
         shift_factor: float = 0,
         use_post_quant_conv: bool = True,
+        attn_impl: str = "auto",
         device: str = "cuda:0",
         dtype: torch.dtype = torch.float32,
     ):
@@ -137,7 +140,7 @@ def __init__(
             [
                 # UNetMidBlock2D
                 ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
-                VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, device=device, dtype=dtype),
+                VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, device=device, dtype=dtype, attn_impl=attn_impl),
                 ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
                 # UpDecoderBlock2D
                 ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
@@ -202,6 +205,7 @@ def from_state_dict(
         scaling_factor: float = 0.18215,
         shift_factor: float = 0,
         use_post_quant_conv: bool = True,
+        attn_impl: str = "auto",
     ):
         with no_init_weights():
             model = torch.nn.utils.skip_init(
@@ -210,6 +214,7 @@ def from_state_dict(
                 scaling_factor=scaling_factor,
                 shift_factor=shift_factor,
                 use_post_quant_conv=use_post_quant_conv,
+                attn_impl=attn_impl,
                 device=device,
                 dtype=dtype,
             )
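Each from_state_dict variant patched here follows the same pattern: construct the module with weight initialization skipped, then load pretrained weights over it. A minimal standalone sketch of that pattern using a plain torch.nn.Linear; no_init_weights is this repository's own context manager and is omitted:

import torch

# skip_init builds the module without running its random-init code,
# which is wasted work when a full state dict is loaded immediately after.
model = torch.nn.utils.skip_init(torch.nn.Linear, 512, 512, device="cpu")
model.load_state_dict({
    "weight": torch.zeros(512, 512),
    "bias": torch.zeros(512),
})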
@@ -230,6 +235,7 @@ def __init__(
         scaling_factor: float = 0.18215,
         shift_factor: float = 0,
         use_quant_conv: bool = True,
+        attn_impl: str = "auto",
         device: str = "cuda:0",
         dtype: torch.dtype = torch.float32,
     ):
@@ -263,7 +269,7 @@ def __init__(
                 ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
                 # UNetMidBlock2D
                 ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
-                VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, device=device, dtype=dtype),
+                VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, device=device, dtype=dtype, attn_impl=attn_impl),
                 ResnetBlock(512, 512, eps=1e-6, device=device, dtype=dtype),
             ]
         )
@@ -309,6 +315,7 @@ def from_state_dict(
         scaling_factor: float = 0.18215,
         shift_factor: float = 0,
         use_quant_conv: bool = True,
+        attn_impl: str = "auto",
     ):
         with no_init_weights():
             model = torch.nn.utils.skip_init(
@@ -317,6 +324,7 @@ def from_state_dict(
                 scaling_factor=scaling_factor,
                 shift_factor=shift_factor,
                 use_quant_conv=use_quant_conv,
+                attn_impl=attn_impl,
                 device=device,
                 dtype=dtype,
             )
@@ -338,6 +346,7 @@ def __init__(
         shift_factor: float = 0,
         use_quant_conv: bool = True,
         use_post_quant_conv: bool = True,
+        attn_impl: str = "auto",
         device: str = "cuda:0",
         dtype: torch.dtype = torch.float32,
     ):
@@ -347,6 +356,7 @@ def __init__(
             scaling_factor=scaling_factor,
             shift_factor=shift_factor,
             use_quant_conv=use_quant_conv,
+            attn_impl=attn_impl,
             device=device,
             dtype=dtype,
         )
@@ -355,6 +365,7 @@ def __init__(
             scaling_factor=scaling_factor,
             shift_factor=shift_factor,
             use_post_quant_conv=use_post_quant_conv,
+            attn_impl=attn_impl,
             device=device,
             dtype=dtype,
         )
@@ -376,6 +387,7 @@ def from_state_dict(
         shift_factor: float = 0,
         use_quant_conv: bool = True,
         use_post_quant_conv: bool = True,
+        attn_impl: str = "auto",
     ):
         with no_init_weights():
             model = torch.nn.utils.skip_init(
@@ -385,6 +397,7 @@ def from_state_dict(
                 shift_factor=shift_factor,
                 use_quant_conv=use_quant_conv,
                 use_post_quant_conv=use_post_quant_conv,
+                attn_impl=attn_impl,
                 device=device,
                 dtype=dtype,
             )
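Taken together, the change threads a single attn_impl switch from the top-level VAE down through both the encoder and decoder to every VAEAttentionBlock. A hypothetical call site; the class name SDVAE and the checkpoint path are placeholders, since the diff only shows the patched method bodies:

import torch

state_dict = torch.load("vae.pth", map_location="cpu")  # hypothetical checkpoint path
vae = SDVAE.from_state_dict(  # SDVAE is a placeholder name for the patched class
    state_dict,
    device="cuda:0",
    dtype=torch.float16,
    attn_impl="auto",  # the new flag; reaches both encoder and decoder attention blocks
)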