
Refactor additions from Generative module #7227

Open
@marksgraham

Description

After merging MONAI Generative into core, some refactoring is needed to reduce repeated code; see discussion here.

Until this is done, any blocks from Generative that are likely to be removed will be marked as private.

Will update the list of items needing refactoring here as I go along:

  • Upsample/Downsample/Resblock/AttentionBlock in the autoencoderkl network

    • The Upsample block can be replaced by MONAI's UpSample block; we might need to implement a conv_only=True option in MONAI's block
    • The Downsample block can be replaced by MONAI's Convolution, see comment here and the sketch after this list
    • The AttentionBlock should be able to make use of MONAI's self-attention block
    • I don't see a simple way to replace the ResBlock; many other networks implement their own versions, but maybe we should rename it to AutoencoderKLResblock to differentiate it from other ResBlocks in the codebase, as is done in e.g. UNETR
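
For the Downsample case, a minimal sketch of the kind of drop-in replacement being proposed, assuming the merged code keeps core's Convolution block (channel counts are illustrative):

```python
import torch
from monai.networks.blocks import Convolution

# Possible drop-in for the Generative Downsample block: a strided convolution
# with no norm/act, which Convolution already supports via conv_only=True.
downsample = Convolution(
    spatial_dims=2,
    in_channels=64,
    out_channels=64,
    strides=2,
    kernel_size=3,
    conv_only=True,  # skip the norm/activation/dropout Convolution adds by default
)

x = torch.randn(1, 64, 32, 32)
print(downsample(x).shape)  # torch.Size([1, 64, 16, 16])
```
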
  • SABlock and TransformerBlock used by DecoderOnlyTransformer

    • Use the internal MONAI versions; they could also make use of the new cross-attention transformer block we plan to make, as described below (usage sketch after this item)
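
The reuse here would be direct, since the core blocks already exist. A minimal sketch with illustrative sizes; note that a decoder-only model additionally needs causal masking, which these blocks would have to support or which would be handled separately:

```python
import torch
from monai.networks.blocks import SABlock, TransformerBlock

# Existing core blocks the DecoderOnlyTransformer could reuse.
attn = SABlock(hidden_size=256, num_heads=8, dropout_rate=0.1)
block = TransformerBlock(hidden_size=256, mlp_dim=1024, num_heads=8, dropout_rate=0.1)

tokens = torch.randn(2, 77, 256)  # (batch, sequence, hidden)
print(attn(tokens).shape)   # torch.Size([2, 77, 256])
print(block(tokens).shape)  # torch.Size([2, 77, 256])
```
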
  • Upsample/Downsample/BasicTransformerBlock/ResnetBlock/AttentionBlock in the diffusion_model_unet

    • The CrossAttention block could be added as a block under monai.networks.blocks, similarly to the SelfAttention block that already exists there (a minimal sketch follows this list)
    • The SelfAttention block can be replaced with the block we already have
    • The TransformerBlock is actually a cross-attention transformer block; we could make a new MONAI block for it
    • Downsample block - maybe we need to add an option to the existing MONAI downsample block to either conv or pool; then we could use it here
    • Upsample block - we can use MONAI's version if we add a post_conv option to perform a convolution after the interpolation
    • Resnet block - once again, I think it is OK to keep this and rename it to DiffusionUnetResnetBlock
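
For the proposed cross-attention block, a minimal sketch of what it could look like; the class name, signature, and use of torch's built-in attention are all assumptions, not the eventual MONAI API:

```python
import torch
import torch.nn as nn

class CrossAttentionBlock(nn.Module):
    """Hypothetical sketch of a monai.networks.blocks cross-attention block:
    queries come from the image tokens, keys/values from the conditioning
    context (e.g. text embeddings)."""

    def __init__(self, hidden_size: int, context_size: int, num_heads: int) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            kdim=context_size,
            vdim=context_size,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual attention: queries from x, keys/values from context.
        out, _ = self.attn(self.norm(x), context, context, need_weights=False)
        return x + out

block = CrossAttentionBlock(hidden_size=320, context_size=768, num_heads=8)
x = torch.randn(2, 64, 320)    # flattened image tokens
ctx = torch.randn(2, 77, 768)  # conditioning tokens
print(block(x, ctx).shape)     # torch.Size([2, 64, 320])
```
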
  • SpadeNorm block - use get_norm_layer here and here (usage sketch below)
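
A minimal sketch of the factory call that could replace the hard-coded normalisation inside SpadeNorm, assuming group norm with the usual SPADE affine=False setting (arguments illustrative):

```python
from monai.networks.layers.utils import get_norm_layer

# SPADE applies a parameter-free norm before the learned modulation, so the
# layer can come from the existing factory instead of being hard-coded.
norm = get_norm_layer(
    name=("group", {"num_groups": 32, "affine": False}),
    spatial_dims=2,
    channels=64,
)
print(norm)  # GroupNorm(32, 64, eps=1e-05, affine=False)
```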

  • SPADEAutoencoder - merge with AutoencoderKL, as the only difference is in the decoder. It might make more sense to just inherit from AutoencoderKL (which will also get the benefit of the new load_old_state_dict method); a sketch of the inheritance follows
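
A minimal sketch of the inheritance route, assuming AutoencoderKL lands under monai.networks.nets and exposes a decode method; the subclass name and the seg argument are assumptions:

```python
import torch
from monai.networks.nets import AutoencoderKL  # assumed import path after the merge

class SPADEAutoencoderKL(AutoencoderKL):
    """Hypothetical sketch: inherit everything from AutoencoderKL and override
    only decoding so it can take a semantic map."""

    def decode(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
        # A real implementation would thread `seg` through SPADE norm layers in
        # the decoder; calling the parent decoder is a placeholder.
        return super().decode(z)
```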

  • Had to add some calls to .contiguous() in the diffusion model UNet to stop issues with the inferer tests (PR here) - need to dig deeper and find out whether these are really necessary, as these calls make copies and consume memory; a small debugging sketch follows
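
One way to dig into this: temporarily wrap the call sites so they report which tensors actually arrive non-contiguous (the wrapper name is hypothetical; .contiguous() already returns self without copying when the tensor is contiguous):

```python
import torch

def check_contiguous(x: torch.Tensor, tag: str) -> torch.Tensor:
    # Temporary instrumentation: flag the call sites where the tensor really is
    # non-contiguous, i.e. where .contiguous() would perform a copy.
    if not x.is_contiguous():
        print(f"{tag}: non-contiguous, strides={x.stride()}")
    return x.contiguous()

x = torch.randn(2, 4, 8).permute(0, 2, 1)  # strided view, not contiguous
y = check_contiguous(x, "after attention reshape")
print(y.is_contiguous())  # True
```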

  • ControlNet refactor suggested by @ericspod to tidy up some of the code here

  • Neater init on the PatchGAN discriminator suggested by @ericspod here (one possible shape sketched below)
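
One possible shape for a tidier init, where a single apply() call replaces per-layer init code; the constants are the usual pix2pix-style defaults and the layer coverage is illustrative:

```python
import torch.nn as nn

def _init_weights(m: nn.Module) -> None:
    # Conventional DCGAN/PatchGAN initialisation.
    if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d)):
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias)

# discriminator.apply(_init_weights)  # applied once after construction
```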

  • Schedulers - refactor some calculations into the base class, as suggested by @KumoLiu here (6676 port diffusion schedulers #7332 (comment)). These calculations aren't common to all the classes that inherit from Scheduler (e.g. PNDM), so I'm not sure they should be moved to the base class; one middle-ground sketch follows
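
A middle-ground sketch: keep the shared quantities as an opt-in helper on the base class rather than baking them into every scheduler's step. Class and method names here are illustrative, not the MONAI API:

```python
import torch

class Scheduler:
    """Illustrative base class holding the shared noise schedule."""

    def __init__(self, num_train_timesteps: int = 1000) -> None:
        betas = torch.linspace(1e-4, 2e-2, num_train_timesteps)
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    def _get_alpha_prod(self, timestep: int, prev_timestep: int):
        # Common DDPM/DDIM-style quantities; schedulers that step differently
        # (e.g. PNDM) simply would not call this helper.
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = (
            self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else torch.tensor(1.0)
        )
        return alpha_prod_t, alpha_prod_t_prev
```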

  • Deferred to future after discussion with @ericspod: Inferers - create a new base class for the Generative inferers, see discussion here

  • Deferred to future after discussion with @ericspod: Engines - refactor GANTrainer into a more general base class that can replace AdversarialTrainer here

Metadata

Labels: enhancement (New feature or request), refactor (Non-breaking feature enhancements)

Status: In progress