Unlocking Fine-Grained Details with Wavelet-based High-Frequency Enhancement in Transformers
MICCAI 2023 MLMI Workshop

arXiv

Medical image segmentation is a critical task that plays a vital role in diagnosis, treatment planning, and disease monitoring. Accurate segmentation of anatomical structures and abnormalities from medical images can aid in the early detection and treatment of various diseases. In this paper, we address the local feature deficiency of the Transformer model by carefully re-designing the self-attention map to produce accurate dense prediction in medical images. To this end, we first apply the wavelet transformation to decompose the input feature map into low-frequency (LF) and high-frequency (HF) subbands. The LF segment is associated with coarse-grained features, while the HF components preserve fine-grained features such as texture and edge information. Next, we reformulate the self-attention operation using the efficient Transformer to perform both spatial and context attention on top of the frequency representation. Furthermore, to intensify the importance of the boundary information, we impose an additional attention map by creating a Gaussian pyramid on top of the HF components. Moreover, we propose a multi-scale context enhancement block within skip connections to adaptively model inter-scale dependencies to overcome the semantic gap among stages of the encoder and decoder modules. Through comprehensive experiments, we demonstrate the effectiveness of our strategy on multi-organ and skin lesion segmentation benchmarks.
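
The frequency decomposition described above can be illustrated with PyWavelets and OpenCV, both of which are listed in requirements.txt. The snippet below is a minimal sketch of the idea on a single-channel map, not the repository's implementation; the Gaussian pyramid step mirrors the boundary-attention idea on the HF subbands.

import numpy as np
import pywt
import cv2

feat = np.random.rand(64, 64).astype(np.float32)  # stand-in for one feature-map channel

# One-level 2D discrete wavelet transform: cA is the low-frequency (LF) subband
# (coarse structure); cH, cV, cD are the high-frequency (HF) subbands carrying
# edge and texture detail. Each subband is half-resolution: (32, 32).
cA, (cH, cV, cD) = pywt.dwt2(feat, 'haar')

# Combine the HF magnitudes and build a small Gaussian pyramid over them,
# analogous to the additional attention map that emphasizes boundaries.
hf = np.abs(cH) + np.abs(cV) + np.abs(cD)
pyramid = [hf]
for _ in range(2):
    pyramid.append(cv2.pyrDown(pyramid[-1]))  # 32x32 -> 16x16 -> 8x8

print(cA.shape, [p.shape for p in pyramid])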

Frequency Enhanced Transformer (FET) model

[Figure: overview of the FET model]

[Figure: (a) FET block, (b) Multi-Scale Context Enhancement (MSCE) module]
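
The FET block builds its spatial and context attention on the efficient-attention formulation, whose cost is linear rather than quadratic in the number of tokens. The following is a minimal PyTorch sketch of that attention variant; the tensor names and shapes are illustrative, not the repository's code.

import torch
import torch.nn.functional as F

def efficient_attention(q, k, v):
    # q, k, v: (batch, tokens, dim). Normalizing queries over the feature axis
    # and keys over the token axis lets the context be aggregated into a small
    # (dim x dim) matrix instead of a (tokens x tokens) attention map.
    q = F.softmax(q, dim=-1)
    k = F.softmax(k, dim=1)
    context = k.transpose(1, 2) @ v  # (batch, dim, dim)
    return q @ context               # (batch, tokens, dim)

x = torch.randn(2, 196, 64)  # e.g., 14x14 tokens with 64 channels
print(efficient_attention(x, x, x).shape)  # torch.Size([2, 196, 64])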

Citation

@inproceedings{azad2023unlocking,
  title={Unlocking Fine-Grained Details with Wavelet-based High-Frequency Enhancement in Transformers},
  author={Azad, Reza and Kazerouni, Amirhossein and Sulaiman, Alaa and Bozorgpour, Afshin and Aghdam, Ehsan Khodapanah and Jose, Abin and Merhof, Dorit},
  maintitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  booktitle={Workshop on Machine Learning on Medical Imaging},
  year={2023},
  organization={Springer}
}

News

  • Aug 18, 2023: Accepted at the MICCAI 2023 MLMI Workshop! 🥳

How to use

Requirements

  • Ubuntu 16.04 or higher
  • CUDA 11.1 or higher
  • Python v3.7 or higher
  • PyTorch v1.7 or higher
  • Hardware: a single GPU with 12 GB of memory or more (we used an RTX 3090)
Python package dependencies (contents of requirements.txt):

einops
h5py
imgaug
matplotlib
MedPy
numpy
opencv_python
pandas
PyWavelets
scipy
SimpleITK
tensorboardX
timm
torch
torchvision
tqdm

Install them with:

pip install -r requirements.txt

Model weights

You can download the pretrained weights from the table below.

Dataset    Model    Download link
Synapse    FET      [Download]

Training

To train, run train.py with your desired arguments, or use the bash script runs/run_tr_n01.sh, adjusting its variables and arguments as needed. A brief description of the arguments follows; an example invocation and a sketch of the loss weighting appear after the list.

usage: train.py [-h] [--root_path ROOT_PATH] [--test_path TEST_PATH] [--dataset DATASET] [--dstr_fast DSTR_FAST] [--en_lnum EN_LNUM] [--br_lnum BR_LNUM] [--de_lnum DE_LNUM]
              [--compact COMPACT] [--continue_tr CONTINUE_TR] [--optimizer OPTIMIZER] [--dice_loss_weight DICE_LOSS_WEIGHT] [--list_dir LIST_DIR] [--num_classes NUM_CLASSES]
              [--output_dir OUTPUT_DIR] [--max_iterations MAX_ITERATIONS] [--max_epochs MAX_EPOCHS] [--batch_size BATCH_SIZE] [--num_workers NUM_WORKERS]
              [--eval_interval EVAL_INTERVAL] [--model_name MODEL_NAME] [--n_gpu N_GPU] [--bridge_layers BRIDGE_LAYERS] [--deterministic DETERMINISTIC] [--base_lr BASE_LR]
              [--img_size IMG_SIZE] [--z_spacing Z_SPACING] [--seed SEED] [--opts OPTS [OPTS ...]] [--zip] [--cache-mode {no,full,part}] [--resume RESUME]
              [--accumulation-steps ACCUMULATION_STEPS] [--use-checkpoint] [--amp-opt-level {O0,O1,O2}] [--tag TAG] [--eval] [--throughput]

optional arguments:
-h, --help            show this help message and exit
--root_path ROOT_PATH
                      root dir for data
--test_path TEST_PATH
                      root dir for test data
--dataset DATASET     experiment_name
--dstr_fast DSTR_FAST
                      SynapseDatasetFast: will load all data into RAM
--en_lnum EN_LNUM     en_lnum: Laplacian layers (Pyramid) for the encoder
--br_lnum BR_LNUM     br_lnum: Laplacian layers (Pyramid) for the bridge
--de_lnum DE_LNUM     de_lnum: Laplacian layers (Pyramid) for the decoder
--compact COMPACT     compact with 3 blocks instead of 4 blocks
--continue_tr CONTINUE_TR
                      continue the training from the last saved epoch
--optimizer OPTIMIZER
                      optimizer: [SGD, AdamW]
--dice_loss_weight DICE_LOSS_WEIGHT
                      weight x of the Dice term (default 0.6): loss = (1-x)*ce_loss + x*dice_loss
--list_dir LIST_DIR   list dir
--num_classes NUM_CLASSES
                      output channel of the network
--output_dir OUTPUT_DIR
                      output dir
--max_iterations MAX_ITERATIONS
                      maximum iteration number to train
--max_epochs MAX_EPOCHS
                      maximum epoch number to train
--batch_size BATCH_SIZE
                      batch_size per GPU
--num_workers NUM_WORKERS
                      num_workers
--eval_interval EVAL_INTERVAL
                      eval_interval
--model_name MODEL_NAME
                      model_name
--n_gpu N_GPU         total gpu
--bridge_layers BRIDGE_LAYERS
                      number of bridge layers
--deterministic DETERMINISTIC
                      whether using deterministic training
--base_lr BASE_LR     segmentation network learning rate
--img_size IMG_SIZE   input patch size of network input
--z_spacing Z_SPACING
                      z_spacing
--seed SEED           random seed
--opts OPTS [OPTS ...]
                      Modify config options by adding 'KEY VALUE' pairs.
--zip                 use zipped dataset instead of folder dataset
--cache-mode {no,full,part}
                      no: no cache; full: cache all data; part: shard the dataset into non-overlapping pieces and cache only one piece
--resume RESUME       resume from checkpoint
--accumulation-steps ACCUMULATION_STEPS
                      gradient accumulation steps
--use-checkpoint      whether to use gradient checkpointing to save memory
--amp-opt-level {O0,O1,O2}
                      mixed precision opt level, if O0, no amp is used
--tag TAG             tag of experiment
--eval                Perform evaluation only
--throughput          Test throughput only
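
For example, a run on Synapse might look like the following; the paths and hyperparameter values are placeholders to adjust to your setup.

python train.py \
  --root_path ./data/Synapse/train_npz \
  --test_path ./data/Synapse/test_vol_h5 \
  --dataset Synapse \
  --output_dir ./results \
  --batch_size 24 \
  --max_epochs 400 \
  --base_lr 0.05 \
  --img_size 224 \
  --dice_loss_weight 0.6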

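The --dice_loss_weight argument combines cross-entropy and Dice as loss = (1-x)*ce_loss + x*dice_loss. Below is a generic sketch of that weighting; the soft-Dice here is a common formulation, not necessarily the repository's exact implementation.

import torch.nn.functional as F

def soft_dice_loss(logits, target, eps=1e-5):
    # logits: (B, C, H, W); target: (B, H, W) integer class labels
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)
    onehot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    inter = (probs * onehot).sum(dim=(0, 2, 3))
    denom = probs.sum(dim=(0, 2, 3)) + onehot.sum(dim=(0, 2, 3))
    return 1 - ((2 * inter + eps) / (denom + eps)).mean()

def combined_loss(logits, target, x=0.6):
    ce = F.cross_entropy(logits, target)
    return (1 - x) * ce + x * soft_dice_loss(logits, target)
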
Inference

For inference, run test.py. Most of its parameters match those of train.py. You can also use runs/run_te_n01.sh as an example.

To run with arbitrary weights, pass --weights_fpath with the absolute path to the model weights file.
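
For instance (the weights filename below is hypothetical):

python test.py \
  --output_dir ./results \
  --weights_fpath /absolute/path/to/fet_synapse.pth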

Experiments

Synapse Dataset

[Figures: Synapse sample images and segmentation results]

ISIC 2018 Dataset

[Figures: ISIC 2018 sample images and segmentation results]
