@@ -102,7 +102,7 @@ def get_block(
102102 attention_head_dim : int ,
103103 norm_type : str ,
104104 act_fn : str ,
105- qkv_mutliscales : Tuple [int ] = (),
105+ qkv_mutliscales : Tuple [int , ... ] = (),
106106):
107107 if block_type == "ResBlock" :
108108 block = ResBlock (in_channels , out_channels , norm_type , act_fn )
@@ -206,8 +206,8 @@ def __init__(
206206 latent_channels : int ,
207207 attention_head_dim : int = 32 ,
208208 block_type : Union [str , Tuple [str ]] = "ResBlock" ,
209- block_out_channels : Tuple [int ] = (128 , 256 , 512 , 512 , 1024 , 1024 ),
210- layers_per_block : Tuple [int ] = (2 , 2 , 2 , 2 , 2 , 2 ),
209+ block_out_channels : Tuple [int , ... ] = (128 , 256 , 512 , 512 , 1024 , 1024 ),
210+ layers_per_block : Tuple [int , ... ] = (2 , 2 , 2 , 2 , 2 , 2 ),
211211 qkv_multiscales : Tuple [Tuple [int , ...], ...] = ((), (), (), (5 ,), (5 ,), (5 ,)),
212212 downsample_block_type : str = "pixel_unshuffle" ,
213213 out_shortcut : bool = True ,
@@ -292,8 +292,8 @@ def __init__(
292292 latent_channels : int ,
293293 attention_head_dim : int = 32 ,
294294 block_type : Union [str , Tuple [str ]] = "ResBlock" ,
295- block_out_channels : Tuple [int ] = (128 , 256 , 512 , 512 , 1024 , 1024 ),
296- layers_per_block : Tuple [int ] = (2 , 2 , 2 , 2 , 2 , 2 ),
295+ block_out_channels : Tuple [int , ... ] = (128 , 256 , 512 , 512 , 1024 , 1024 ),
296+ layers_per_block : Tuple [int , ... ] = (2 , 2 , 2 , 2 , 2 , 2 ),
297297 qkv_multiscales : Tuple [Tuple [int , ...], ...] = ((), (), (), (5 ,), (5 ,), (5 ,)),
298298 norm_type : Union [str , Tuple [str ]] = "rms_norm" ,
299299 act_fn : Union [str , Tuple [str ]] = "silu" ,
@@ -440,8 +440,8 @@ def __init__(
440440 decoder_block_types : Union [str , Tuple [str ]] = "ResBlock" ,
441441 encoder_block_out_channels : Tuple [int , ...] = (128 , 256 , 512 , 512 , 1024 , 1024 ),
442442 decoder_block_out_channels : Tuple [int , ...] = (128 , 256 , 512 , 512 , 1024 , 1024 ),
443- encoder_layers_per_block : Tuple [int ] = (2 , 2 , 2 , 3 , 3 , 3 ),
444- decoder_layers_per_block : Tuple [int ] = (3 , 3 , 3 , 3 , 3 , 3 ),
443+ encoder_layers_per_block : Tuple [int , ... ] = (2 , 2 , 2 , 3 , 3 , 3 ),
444+ decoder_layers_per_block : Tuple [int , ... ] = (3 , 3 , 3 , 3 , 3 , 3 ),
445445 encoder_qkv_multiscales : Tuple [Tuple [int , ...], ...] = ((), (), (), (5 ,), (5 ,), (5 ,)),
446446 decoder_qkv_multiscales : Tuple [Tuple [int , ...], ...] = ((), (), (), (5 ,), (5 ,), (5 ,)),
447447 upsample_block_type : str = "pixel_shuffle" ,
0 commit comments