@@ -296,37 +296,28 @@ class Encoder(nn.Module):
     Args:
         spatial_dims: number of spatial dimensions (1D, 2D, 3D).
         in_channels: number of input channels.
-        num_channels: number of filters in the first downsampling.
+        num_channels: sequence of block output channels.
         out_channels: number of channels in the bottom layer (latent space) of the autoencoder.
-        ch_mult: list of multipliers of num_channels in the initial layer and in each downsampling layer. Example: if
-            you want three downsamplings, you have to input a 4-element list. If you input [1, 1, 2, 2],
-            the first downsampling will leave num_channels to num_channels, the next will multiply num_channels by 2,
-            and the next will multiply num_channels*2 by 2 again, resulting in 8, 8, 16 and 32 channels.
         num_res_blocks: number of residual blocks (see ResBlock) per level.
         norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
         norm_eps: epsilon for the normalization.
-        attention_levels: indicate which level from ch_mult contain an attention block.
+        attention_levels: indicate which levels of num_channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
     """
 
     def __init__(
         self,
         spatial_dims: int,
         in_channels: int,
-        num_channels: int,
+        num_channels: Sequence[int],
         out_channels: int,
-        ch_mult: Sequence[int],
         num_res_blocks: int,
         norm_num_groups: int,
         norm_eps: float,
-        attention_levels: Optional[Sequence[bool]] = None,
+        attention_levels: Sequence[bool],
         with_nonlocal_attn: bool = True,
     ) -> None:
         super().__init__()
-
-        if attention_levels is None:
-            attention_levels = (False,) * len(ch_mult)
-
         self.spatial_dims = spatial_dims
         self.in_channels = in_channels
         self.num_channels = num_channels
@@ -336,15 +327,13 @@ def __init__(
         self.norm_eps = norm_eps
         self.attention_levels = attention_levels
 
-        in_ch_mult = (1,) + tuple(ch_mult)
-
         blocks = []
         # Initial convolution
         blocks.append(
             Convolution(
                 spatial_dims=spatial_dims,
                 in_channels=in_channels,
-                out_channels=num_channels,
+                out_channels=num_channels[0],
                 strides=1,
                 kernel_size=3,
                 padding=1,
@@ -353,52 +342,73 @@ def __init__(
             )
         )
 
         # Residual and downsampling blocks
-        for i in range(len(ch_mult)):
-            block_in_ch = num_channels * in_ch_mult[i]
-            block_out_ch = num_channels * ch_mult[i]
+        output_channel = num_channels[0]
+        for i in range(len(num_channels)):
+            input_channel = output_channel
+            output_channel = num_channels[i]
+            is_final_block = i == len(num_channels) - 1
+
             for _ in range(self.num_res_blocks):
                 blocks.append(
                     ResBlock(
                         spatial_dims=spatial_dims,
-                        in_channels=block_in_ch,
+                        in_channels=input_channel,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
-                        out_channels=block_out_ch,
+                        out_channels=output_channel,
                     )
                 )
-                block_in_ch = block_out_ch
+                input_channel = output_channel
             if attention_levels[i]:
                 blocks.append(
                     AttentionBlock(
                         spatial_dims=spatial_dims,
-                        num_channels=block_in_ch,
+                        num_channels=input_channel,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
                     )
                 )
 
-            if i != len(ch_mult) - 1:
-                blocks.append(Downsample(spatial_dims, block_in_ch))
+            if not is_final_block:
+                blocks.append(Downsample(spatial_dims=spatial_dims, in_channels=input_channel))
 
         # Non-local attention block
         if with_nonlocal_attn is True:
-            blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
             blocks.append(
-                AttentionBlock(
+                ResBlock(
                     spatial_dims=spatial_dims,
-                    num_channels=block_in_ch,
+                    in_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    out_channels=num_channels[-1],
                 )
             )
-            blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
 
+            blocks.append(
+                AttentionBlock(
+                    spatial_dims=spatial_dims,
+                    num_channels=num_channels[-1],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                )
+            )
+            blocks.append(
+                ResBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=num_channels[-1],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    out_channels=num_channels[-1],
+                )
+            )
         # Normalise and convert to latent size
-        blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
+        blocks.append(
+            nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[-1], eps=norm_eps, affine=True)
+        )
         blocks.append(
             Convolution(
                 spatial_dims=self.spatial_dims,
-                in_channels=block_in_ch,
+                in_channels=num_channels[-1],
                 out_channels=out_channels,
                 strides=1,
                 kernel_size=3,
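Note: the old multiplier scheme and the new explicit sequence describe the same architectures. A minimal migration sketch, assuming the post-change Encoder signature above; the channel values reuse the example from the removed ch_mult docstring, and the remaining arguments are illustrative:

# Old API: num_channels=8, ch_mult=[1, 1, 2, 2] -> per-level channels 8, 8, 16, 32.
# New API: pass those per-level channel counts directly.
encoder = Encoder(
    spatial_dims=2,
    in_channels=1,
    num_channels=(8, 8, 16, 32),  # was num_channels=8, ch_mult=[1, 1, 2, 2]
    out_channels=3,
    num_res_blocks=2,
    norm_num_groups=8,  # each entry of num_channels must be divisible by this
    norm_eps=1e-6,
    attention_levels=(False, False, False, True),
)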
@@ -421,56 +431,47 @@ class Decoder(nn.Module):
 
     Args:
         spatial_dims: number of spatial dimensions (1D, 2D, 3D).
-        num_channels: number of filters in the last upsampling.
+        num_channels: sequence of block output channels.
         in_channels: number of channels in the bottom layer (latent space) of the autoencoder.
         out_channels: number of output channels.
-        ch_mult: list of multipliers of num_channels that make for all the upsampling layers before the last. In the
-            last layer, there will be a transition from num_channels to out_channels. In the layers before that,
-            channels will be the product of the previous number of channels by ch_mult.
         num_res_blocks: number of residual blocks (see ResBlock) per level.
         norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
         norm_eps: epsilon for the normalization.
-        attention_levels: indicate which level from ch_mult contain an attention block.
+        attention_levels: indicate which levels of num_channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
     """
 
     def __init__(
         self,
         spatial_dims: int,
-        num_channels: int,
+        num_channels: Sequence[int],
         in_channels: int,
         out_channels: int,
-        ch_mult: Sequence[int],
         num_res_blocks: int,
         norm_num_groups: int,
         norm_eps: float,
-        attention_levels: Optional[Sequence[bool]] = None,
+        attention_levels: Sequence[bool],
         with_nonlocal_attn: bool = True,
     ) -> None:
         super().__init__()
-
-        if attention_levels is None:
-            attention_levels = (False,) * len(ch_mult)
-
         self.spatial_dims = spatial_dims
         self.num_channels = num_channels
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.ch_mult = ch_mult
         self.num_res_blocks = num_res_blocks
         self.norm_num_groups = norm_num_groups
         self.norm_eps = norm_eps
         self.attention_levels = attention_levels
 
-        block_in_ch = num_channels * self.ch_mult[-1]
+        reversed_block_out_channels = list(reversed(num_channels))
 
         blocks = []
         # Initial convolution
         blocks.append(
             Convolution(
                 spatial_dims=spatial_dims,
                 in_channels=in_channels,
-                out_channels=block_in_ch,
+                out_channels=reversed_block_out_channels[0],
                 strides=1,
                 kernel_size=3,
                 padding=1,
@@ -480,25 +481,53 @@ def __init__(
 
         # Non-local attention block
         if with_nonlocal_attn is True:
-            blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
+            blocks.append(
+                ResBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=reversed_block_out_channels[0],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    out_channels=reversed_block_out_channels[0],
+                )
+            )
             blocks.append(
                 AttentionBlock(
                     spatial_dims=spatial_dims,
-                    num_channels=block_in_ch,
+                    num_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
                 )
             )
-            blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
+            blocks.append(
+                ResBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=reversed_block_out_channels[0],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    out_channels=reversed_block_out_channels[0],
+                )
+            )
 
-        for i in reversed(range(len(ch_mult))):
-            block_out_ch = num_channels * self.ch_mult[i]
+        reversed_attention_levels = list(reversed(attention_levels))
+        block_out_ch = reversed_block_out_channels[0]
+        for i in range(len(reversed_block_out_channels)):
+            block_in_ch = block_out_ch
+            block_out_ch = reversed_block_out_channels[i]
+            is_final_block = i == len(num_channels) - 1
 
             for _ in range(self.num_res_blocks):
-                blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_out_ch))
+                blocks.append(
+                    ResBlock(
+                        spatial_dims=spatial_dims,
+                        in_channels=block_in_ch,
+                        norm_num_groups=norm_num_groups,
+                        norm_eps=norm_eps,
+                        out_channels=block_out_ch,
+                    )
+                )
                 block_in_ch = block_out_ch
 
-            if attention_levels[i]:
+            if reversed_attention_levels[i]:
                 blocks.append(
                     AttentionBlock(
                         spatial_dims=spatial_dims,
@@ -508,8 +537,8 @@ def __init__(
                     )
                 )
 
-            if i != 0:
-                blocks.append(Upsample(spatial_dims, block_in_ch))
+            if not is_final_block:
+                blocks.append(Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch))
 
         blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
         blocks.append(
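Note: the decoder now mirrors the encoder by walking the channel and attention sequences in reverse. A small sketch of the indexing, with illustrative values:

num_channels = (8, 8, 16, 32)
attention_levels = (False, False, False, True)
reversed_block_out_channels = list(reversed(num_channels))    # [32, 16, 8, 8]
reversed_attention_levels = list(reversed(attention_levels))  # [True, False, False, False]
# Level i widens/narrows block_in_ch (the previous level's width) to
# reversed_block_out_channels[i]; Upsample is appended after every level
# except the final one, matching the encoder's Downsample placement.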
@@ -542,50 +571,44 @@ class AutoencoderKL(nn.Module):
         spatial_dims: number of spatial dimensions (1D, 2D, 3D).
         in_channels: number of input channels.
         out_channels: number of output channels.
-        num_channels: number of filters in the first downsampling / last upsampling.
-        latent_channels: latent embedding dimension.
-        ch_mult: multiplier of the number of channels in each downsampling layer (+ initial one). i.e.: If you want 3
-            downsamplings, it should be a 4-element list.
         num_res_blocks: number of residual blocks (see ResBlock) per level.
+        num_channels: sequence of block output channels.
+        attention_levels: sequence of levels to add attention.
+        latent_channels: latent embedding dimension.
         norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
         norm_eps: epsilon for the normalization.
-        attention_levels: indicate which level from ch_mult contain an attention block.
         with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
         with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
     """
 
     def __init__(
         self,
         spatial_dims: int,
-        in_channels: int,
-        out_channels: int,
-        num_channels: int,
-        latent_channels: int,
-        ch_mult: Sequence[int],
-        num_res_blocks: int,
+        in_channels: int = 1,
+        out_channels: int = 1,
+        num_res_blocks: int = 2,
+        num_channels: Sequence[int] = (32, 64, 64, 64),
+        attention_levels: Sequence[bool] = (False, False, True, True),
+        latent_channels: int = 3,
         norm_num_groups: int = 32,
         norm_eps: float = 1e-6,
-        attention_levels: Optional[Sequence[bool]] = None,
         with_encoder_nonlocal_attn: bool = True,
         with_decoder_nonlocal_attn: bool = True,
     ) -> None:
         super().__init__()
-        if attention_levels is None:
-            attention_levels = (False,) * len(ch_mult)
 
-        # The number of channels should be multiple of num_groups
-        if (num_channels % norm_num_groups) != 0:
-            raise ValueError("AutoencoderKL expects number of channels being multiple of number of groups")
+        # All number of channels should be multiple of num_groups
+        if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
+            raise ValueError("AutoencoderKL expects all num_channels being multiple of norm_num_groups")
 
-        if len(ch_mult) != len(attention_levels):
-            raise ValueError("AutoencoderKL expects ch_mult being same size of attention_levels")
+        if len(num_channels) != len(attention_levels):
+            raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels")
 
         self.encoder = Encoder(
             spatial_dims=spatial_dims,
             in_channels=in_channels,
             num_channels=num_channels,
             out_channels=latent_channels,
-            ch_mult=ch_mult,
             num_res_blocks=num_res_blocks,
             norm_num_groups=norm_num_groups,
             norm_eps=norm_eps,
@@ -597,7 +620,6 @@ def __init__(
             num_channels=num_channels,
             in_channels=latent_channels,
             out_channels=out_channels,
-            ch_mult=ch_mult,
             num_res_blocks=num_res_blocks,
             norm_num_groups=norm_num_groups,
             norm_eps=norm_eps,
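Note: with the new defaults the model can be built without any channel bookkeeping. A usage sketch; the input shape is illustrative and the three-tuple return is assumed from this module's forward:

import torch

# Defaults: num_channels=(32, 64, 64, 64), attention_levels=(False, False, True, True).
# A 4-entry sequence implies three downsamplings, so a 64x64 input gives an 8x8 latent.
model = AutoencoderKL(spatial_dims=2)  # in_channels=1, out_channels=1, latent_channels=3
x = torch.randn(1, 1, 64, 64)
reconstruction, z_mu, z_sigma = model(x)  # assumed to return (reconstruction, z_mu, z_sigma)
print(z_mu.shape)  # expected: torch.Size([1, 3, 8, 8])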