This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit cecdce3

Remove ch_mult from AutoencoderKL (#220)
* Change num_channels to Sequence
* Update tutorials
1 parent 7992e5a commit cecdce3
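
For callers, the change is only in how the per-level widths are specified. A minimal before/after sketch (the import path and argument values are illustrative assumptions, not taken from this diff; the parameter names match the signatures shown in the hunks below):

from generative.networks.nets import AutoencoderKL

# Before this commit: a base width plus per-level multipliers (ch_mult).
# autoencoder = AutoencoderKL(
#     spatial_dims=2,
#     in_channels=1,
#     out_channels=1,
#     num_channels=32,
#     latent_channels=3,
#     ch_mult=(1, 2, 2, 2),
#     num_res_blocks=2,
#     attention_levels=(False, False, True, True),
# )

# After this commit: the per-level widths are passed directly as a sequence.
autoencoder = AutoencoderKL(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=2,
    num_channels=(32, 64, 64, 64),
    attention_levels=(False, False, True, True),
    latent_channels=3,
)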

7 files changed: +115 −105 lines

generative/networks/nets/autoencoderkl.py

Lines changed: 98 additions & 76 deletions
@@ -296,37 +296,28 @@ class Encoder(nn.Module):
     Args:
         spatial_dims: number of spatial dimensions (1D, 2D, 3D).
         in_channels: number of input channels.
-        num_channels: number of filters in the first downsampling.
+        num_channels: sequence of block output channels.
         out_channels: number of channels in the bottom layer (latent space) of the autoencoder.
-        ch_mult: list of multipliers of num_channels in the initial layer and in each downsampling layer. Example: if
-            you want three downsamplings, you have to input a 4-element list. If you input [1, 1, 2, 2],
-            the first downsampling will leave num_channels to num_channels, the next will multiply num_channels by 2,
-            and the next will multiply num_channels*2 by 2 again, resulting in 8, 8, 16 and 32 channels.
         num_res_blocks: number of residual blocks (see ResBlock) per level.
         norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
         norm_eps: epsilon for the normalization.
-        attention_levels: indicate which level from ch_mult contain an attention block.
+        attention_levels: indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
     """
 
     def __init__(
         self,
         spatial_dims: int,
         in_channels: int,
-        num_channels: int,
+        num_channels: Sequence[int],
         out_channels: int,
-        ch_mult: Sequence[int],
         num_res_blocks: int,
         norm_num_groups: int,
         norm_eps: float,
-        attention_levels: Optional[Sequence[bool]] = None,
+        attention_levels: Sequence[bool],
         with_nonlocal_attn: bool = True,
     ) -> None:
         super().__init__()
-
-        if attention_levels is None:
-            attention_levels = (False,) * len(ch_mult)
-
         self.spatial_dims = spatial_dims
         self.in_channels = in_channels
         self.num_channels = num_channels
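
After this hunk, num_channels is a per-level sequence and attention_levels is required rather than optional. A hypothetical call against the new Encoder signature (the values are illustrative, borrowed from the AutoencoderKL defaults further down):

encoder = Encoder(
    spatial_dims=2,
    in_channels=1,
    num_channels=(32, 64, 64, 64),
    out_channels=3,  # latent space channels
    num_res_blocks=2,
    norm_num_groups=32,
    norm_eps=1e-6,
    attention_levels=(False, False, True, True),
)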
@@ -336,15 +327,13 @@ def __init__(
         self.norm_eps = norm_eps
         self.attention_levels = attention_levels
 
-        in_ch_mult = (1,) + tuple(ch_mult)
-
         blocks = []
         # Initial convolution
         blocks.append(
             Convolution(
                 spatial_dims=spatial_dims,
                 in_channels=in_channels,
-                out_channels=num_channels,
+                out_channels=num_channels[0],
                 strides=1,
                 kernel_size=3,
                 padding=1,
@@ -353,52 +342,73 @@ def __init__(
         )
 
         # Residual and downsampling blocks
-        for i in range(len(ch_mult)):
-            block_in_ch = num_channels * in_ch_mult[i]
-            block_out_ch = num_channels * ch_mult[i]
+        output_channel = num_channels[0]
+        for i in range(len(num_channels)):
+            input_channel = output_channel
+            output_channel = num_channels[i]
+            is_final_block = i == len(num_channels) - 1
+
             for _ in range(self.num_res_blocks):
                 blocks.append(
                     ResBlock(
                         spatial_dims=spatial_dims,
-                        in_channels=block_in_ch,
+                        in_channels=input_channel,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
-                        out_channels=block_out_ch,
+                        out_channels=output_channel,
                     )
                 )
-                block_in_ch = block_out_ch
+                input_channel = output_channel
             if attention_levels[i]:
                 blocks.append(
                     AttentionBlock(
                         spatial_dims=spatial_dims,
-                        num_channels=block_in_ch,
+                        num_channels=input_channel,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
                     )
                 )
 
-            if i != len(ch_mult) - 1:
-                blocks.append(Downsample(spatial_dims, block_in_ch))
+            if not is_final_block:
+                blocks.append(Downsample(spatial_dims=spatial_dims, in_channels=input_channel))
 
         # Non-local attention block
         if with_nonlocal_attn is True:
-            blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
             blocks.append(
-                AttentionBlock(
+                ResBlock(
                     spatial_dims=spatial_dims,
-                    num_channels=block_in_ch,
+                    in_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    out_channels=num_channels[-1],
                 )
             )
-            blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
 
+            blocks.append(
+                AttentionBlock(
+                    spatial_dims=spatial_dims,
+                    num_channels=num_channels[-1],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                )
+            )
+            blocks.append(
+                ResBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=num_channels[-1],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    out_channels=num_channels[-1],
+                )
+            )
         # Normalise and convert to latent size
-        blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
+        blocks.append(
+            nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[-1], eps=norm_eps, affine=True)
+        )
         blocks.append(
             Convolution(
                 spatial_dims=self.spatial_dims,
-                in_channels=block_in_ch,
+                in_channels=num_channels[-1],
                 out_channels=out_channels,
                 strides=1,
                 kernel_size=3,
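
The rewritten loop indexes num_channels directly instead of multiplying a base width by ch_mult. A standalone sketch (plain Python, not library code) of how the two schemes describe the same channel pyramid:

# Old scheme: base width times per-level multipliers.
base, ch_mult = 32, (1, 2, 2, 2)
old_widths = [base * m for m in ch_mult]   # [32, 64, 64, 64]

# New scheme: the per-level widths are the sequence itself.
num_channels = (32, 64, 64, 64)
new_widths = list(num_channels)            # [32, 64, 64, 64]

assert old_widths == new_widths
# A Downsample block is still appended after every level except the last (is_final_block).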
@@ -421,56 +431,47 @@ class Decoder(nn.Module):
 
     Args:
         spatial_dims: number of spatial dimensions (1D, 2D, 3D).
-        num_channels: number of filters in the last upsampling.
+        num_channels: sequence of block output channels.
         in_channels: number of channels in the bottom layer (latent space) of the autoencoder.
         out_channels: number of output channels.
-        ch_mult: list of multipliers of num_channels that make for all the upsampling layers before the last. In the
-            last layer, there will be a transition from num_channels to out_channels. In the layers before that,
-            channels will be the product of the previous number of channels by ch_mult.
         num_res_blocks: number of residual blocks (see ResBlock) per level.
         norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
         norm_eps: epsilon for the normalization.
-        attention_levels: indicate which level from ch_mult contain an attention block.
+        attention_levels: indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
     """
 
     def __init__(
         self,
         spatial_dims: int,
-        num_channels: int,
+        num_channels: Sequence[int],
         in_channels: int,
         out_channels: int,
-        ch_mult: Sequence[int],
         num_res_blocks: int,
         norm_num_groups: int,
         norm_eps: float,
-        attention_levels: Optional[Sequence[bool]] = None,
+        attention_levels: Sequence[bool],
         with_nonlocal_attn: bool = True,
     ) -> None:
         super().__init__()
-
-        if attention_levels is None:
-            attention_levels = (False,) * len(ch_mult)
-
         self.spatial_dims = spatial_dims
         self.num_channels = num_channels
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.ch_mult = ch_mult
         self.num_res_blocks = num_res_blocks
         self.norm_num_groups = norm_num_groups
         self.norm_eps = norm_eps
         self.attention_levels = attention_levels
 
-        block_in_ch = num_channels * self.ch_mult[-1]
+        reversed_block_out_channels = list(reversed(num_channels))
 
         blocks = []
         # Initial convolution
         blocks.append(
             Convolution(
                 spatial_dims=spatial_dims,
                 in_channels=in_channels,
-                out_channels=block_in_ch,
+                out_channels=reversed_block_out_channels[0],
                 strides=1,
                 kernel_size=3,
                 padding=1,
@@ -480,25 +481,53 @@ def __init__(
 
         # Non-local attention block
         if with_nonlocal_attn is True:
-            blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
+            blocks.append(
+                ResBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=reversed_block_out_channels[0],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    out_channels=reversed_block_out_channels[0],
+                )
+            )
             blocks.append(
                 AttentionBlock(
                     spatial_dims=spatial_dims,
-                    num_channels=block_in_ch,
+                    num_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
                 )
             )
-            blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
+            blocks.append(
+                ResBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=reversed_block_out_channels[0],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    out_channels=reversed_block_out_channels[0],
+                )
+            )
 
-        for i in reversed(range(len(ch_mult))):
-            block_out_ch = num_channels * self.ch_mult[i]
+        reversed_attention_levels = list(reversed(attention_levels))
+        block_out_ch = reversed_block_out_channels[0]
+        for i in range(len(reversed_block_out_channels)):
+            block_in_ch = block_out_ch
+            block_out_ch = reversed_block_out_channels[i]
+            is_final_block = i == len(num_channels) - 1
 
             for _ in range(self.num_res_blocks):
-                blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_out_ch))
+                blocks.append(
+                    ResBlock(
+                        spatial_dims=spatial_dims,
+                        in_channels=block_in_ch,
+                        norm_num_groups=norm_num_groups,
+                        norm_eps=norm_eps,
+                        out_channels=block_out_ch,
+                    )
+                )
                 block_in_ch = block_out_ch
 
-            if attention_levels[i]:
+            if reversed_attention_levels[i]:
                 blocks.append(
                     AttentionBlock(
                         spatial_dims=spatial_dims,
@@ -508,8 +537,8 @@ def __init__(
                     )
                 )
 
-            if i != 0:
-                blocks.append(Upsample(spatial_dims, block_in_ch))
+            if not is_final_block:
+                blocks.append(Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch))
 
         blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
         blocks.append(
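
Rather than taking its own multiplier list, the decoder now consumes the same num_channels and attention_levels in reverse on the way up. A quick illustration of the reversal used above (plain Python, values are the new AutoencoderKL defaults):

num_channels = (32, 64, 64, 64)
attention_levels = (False, False, True, True)

reversed_block_out_channels = list(reversed(num_channels))    # [64, 64, 64, 32]
reversed_attention_levels = list(reversed(attention_levels))  # [True, True, False, False]
# An Upsample block is appended after every level except the last one (is_final_block).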
@@ -542,50 +571,44 @@ class AutoencoderKL(nn.Module):
         spatial_dims: number of spatial dimensions (1D, 2D, 3D).
         in_channels: number of input channels.
         out_channels: number of output channels.
-        num_channels: number of filters in the first downsampling / last upsampling.
-        latent_channels: latent embedding dimension.
-        ch_mult: multiplier of the number of channels in each downsampling layer (+ initial one). i.e.: If you want 3
-            downsamplings, it should be a 4-element list.
         num_res_blocks: number of residual blocks (see ResBlock) per level.
+        num_channels: sequence of block output channels.
+        attention_levels: sequence of levels to add attention.
+        latent_channels: latent embedding dimension.
         norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
         norm_eps: epsilon for the normalization.
-        attention_levels: indicate which level from ch_mult contain an attention block.
         with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
         with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
     """
 
     def __init__(
         self,
         spatial_dims: int,
-        in_channels: int,
-        out_channels: int,
-        num_channels: int,
-        latent_channels: int,
-        ch_mult: Sequence[int],
-        num_res_blocks: int,
+        in_channels: int = 1,
+        out_channels: int = 1,
+        num_res_blocks: int = 2,
+        num_channels: Sequence[int] = (32, 64, 64, 64),
+        attention_levels: Sequence[bool] = (False, False, True, True),
+        latent_channels: int = 3,
         norm_num_groups: int = 32,
         norm_eps: float = 1e-6,
-        attention_levels: Optional[Sequence[bool]] = None,
        with_encoder_nonlocal_attn: bool = True,
         with_decoder_nonlocal_attn: bool = True,
     ) -> None:
         super().__init__()
-        if attention_levels is None:
-            attention_levels = (False,) * len(ch_mult)
 
-        # The number of channels should be multiple of num_groups
-        if (num_channels % norm_num_groups) != 0:
-            raise ValueError("AutoencoderKL expects number of channels being multiple of number of groups")
+        # All number of channels should be multiple of num_groups
+        if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
+            raise ValueError("AutoencoderKL expects all num_channels being multiple of norm_num_groups")
 
-        if len(ch_mult) != len(attention_levels):
-            raise ValueError("AutoencoderKL expects ch_mult being same size of attention_levels")
+        if len(num_channels) != len(attention_levels):
+            raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels")
 
         self.encoder = Encoder(
             spatial_dims=spatial_dims,
             in_channels=in_channels,
             num_channels=num_channels,
             out_channels=latent_channels,
-            ch_mult=ch_mult,
             num_res_blocks=num_res_blocks,
             norm_num_groups=norm_num_groups,
             norm_eps=norm_eps,
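
The validation above now applies to every entry of num_channels instead of a single integer. A short sketch of what the new checks reject (the import path and argument values are assumptions for illustration; the error messages are the ones added in this hunk):

from generative.networks.nets import AutoencoderKL

# Every entry of num_channels must be divisible by norm_num_groups (default 32).
try:
    AutoencoderKL(spatial_dims=2, num_channels=(32, 48, 64), attention_levels=(False, False, True))
except ValueError as err:
    print(err)  # AutoencoderKL expects all num_channels being multiple of norm_num_groups

# num_channels and attention_levels must have the same length.
try:
    AutoencoderKL(spatial_dims=2, num_channels=(32, 64, 64), attention_levels=(False, True))
except ValueError as err:
    print(err)  # AutoencoderKL expects num_channels being same size of attention_levels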
@@ -597,7 +620,6 @@ def __init__(
             num_channels=num_channels,
             in_channels=latent_channels,
             out_channels=out_channels,
-            ch_mult=ch_mult,
             num_res_blocks=num_res_blocks,
             norm_num_groups=norm_num_groups,
             norm_eps=norm_eps,
