Skip to content

Semantic Segmentation #8

@innat

Description

@innat

Short Description

Drawing inspiration from Swin UNETR and leveraging resources from MONAI, we aim to introduce a lightweight decoder module to adapt VideoSwin for 3D semantic segmentation tasks.

Minimal Code

  1. Encoder
import keras
from keras import layers
from keras import ops
from keras import Model
from keras.layers import Conv3D, Conv3DTranspose, Concatenate

def vswin_tiny():
    """Create the VideoSwin backbone and load pretrained weights into it.

    The backbone is built for inputs of shape ``(8, 224, 224, 3)``; the
    original author notes that this shape is arbitrary and another one may
    be used instead.

    Weights are read from ``videoswin_tiny_kinetics400.weights.h5`` with
    ``skip_mismatch=True``.
    NOTE(review): per the Keras docs this skips layers whose weights do not
    match rather than raising, so a partially initialised backbone is
    possible -- confirm this is intended.

    Returns:
        The ``VideoSwinBackbone`` instance after ``load_weights`` was called.
    """
    backbone = VideoSwinBackbone(input_shape=(8, 224, 224, 3))
    backbone.load_weights(
        'videoswin_tiny_kinetics400.weights.h5',
        skip_mismatch=True,
    )
    return backbone

# Module-level backbone instance; its intermediate layer outputs are what
# build_unet_3d() looks up by name when wiring the decoder.
encoder = vswin_tiny()
  2. Basic U-Net
def _skip_scale_factors(target_dims, source_dims):
    """Return per-axis integer factors that enlarge ``source_dims`` to ``target_dims``.

    Raises:
        ValueError: If any dimension is unknown (``None``), zero, or if a
            target dimension is not an exact multiple of the source one.
    """
    factors = []
    for target, source in zip(target_dims, source_dims):
        # `not target` / `not source` covers both None (dynamic) and 0.
        if not target or not source or target % source:
            raise ValueError(
                "Cannot align skip connection of spatial shape "
                f"{tuple(source_dims)} with upsampled features of spatial "
                f"shape {tuple(target_dims)}: every upsampled dimension must "
                "be a known, non-zero integer multiple of the skip dimension."
            )
        factors.append(target // source)
    return tuple(factors)


def upsample_block(x, skip, filters):
    """Upsample ``x``, fuse it with an encoder skip tensor, and refine it.

    ``x`` goes through a ``Conv3DTranspose`` with kernel and strides
    ``(1, 2, 2)``. If axes 1-3 of the result differ from axes 1-3 of
    ``skip``, ``skip`` is enlarged with ``UpSampling3D`` by the integer
    ratio of the two shapes. Both tensors are then concatenated (default
    ``Concatenate`` axis) and passed through two 3x3x3 ReLU convolutions.

    Args:
        x: Feature tensor to upsample. Axes 1-3 are treated as the
            dimensions to align (assumes channels-last layout -- TODO confirm).
        skip: Encoder feature tensor to merge with the upsampled ``x``.
        filters: Number of filters for the transposed convolution and for
            both refinement convolutions.

    Returns:
        The output tensor of the second refinement convolution.

    Raises:
        ValueError: If the shapes on axes 1-3 differ and ``skip`` cannot be
            brought to the size of the upsampled ``x`` by an integer factor
            (unknown dimension, zero dimension, or non-multiple sizes).
    """
    x = layers.Conv3DTranspose(
        filters, (1, 2, 2), strides=(1, 2, 2), padding='same'
    )(x)

    target_dims = tuple(x.shape[1:4])
    source_dims = tuple(skip.shape[1:4])
    if target_dims != source_dims:
        # Validate up front so a bad pairing fails here with a clear message
        # instead of deep inside UpSampling3D / Concatenate.
        skip = layers.UpSampling3D(
            size=_skip_scale_factors(target_dims, source_dims)
        )(skip)

    x = layers.Concatenate()([x, skip])
    x = layers.Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(x)
    x = layers.Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(x)
    return x
def build_unet_3d(encoder, num_classes, activation="sigmoid"):
    """Attach a U-Net style decoder to ``encoder`` and return the full model.

    The output of the encoder layer named ``top_norm`` is the bottleneck.
    It is passed through three ``upsample_block`` stages (384, 192 and 96
    filters) that merge the outputs of ``videoswin_basic_layer_3``,
    ``videoswin_basic_layer_2`` and ``videoswin_basic_layer_1`` in that
    order. Two further transposed convolutions (strides ``(2, 2, 2)`` then
    ``(1, 2, 2)``, 96 filters each) follow, and a 1x1x1 convolution emits
    ``num_classes`` channels.

    Args:
        encoder: Model exposing ``input`` and the layer names listed above
            via ``get_layer``.
        num_classes: Number of channels produced by the output convolution.
        activation: Activation of the output convolution.

    Returns:
        A ``Model`` named ``"UNet3D_VideoSwin"`` mapping the encoder input
        to the segmentation output.
    """
    model_input = encoder.input

    # Shallow-to-deep encoder stages; consumed deepest-first by the decoder.
    stage_names = (
        "videoswin_basic_layer_1",
        "videoswin_basic_layer_2",
        "videoswin_basic_layer_3",
    )
    stage_outputs = [encoder.get_layer(name).output for name in stage_names]
    features = encoder.get_layer("top_norm").output

    # decoder: one upsample stage per skip, deepest skip first
    for skip_tensor, width in zip(reversed(stage_outputs), (384, 192, 96)):
        features = upsample_block(features, skip_tensor, width)

    # Final upsampling: kernel size equals stride for both layers.
    for step in ((2, 2, 2), (1, 2, 2)):
        features = layers.Conv3DTranspose(
            96, step, strides=step, padding='same'
        )(features)

    # output layer
    head = layers.Conv3D(
        num_classes, (1, 1, 1), activation=activation, padding="same"
    )
    return Model(model_input, head(features), name="UNet3D_VideoSwin")

# `encoder` is the module-level backbone returned by vswin_tiny(); it is the
# only backbone instance in scope here.
unet_model = build_unet_3d(encoder, num_classes=4)
unet_model.summary(line_length=100)
Model: "UNet3D_VideoSwin"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                ┃ Output Shape            ┃        Param # ┃ Connected to            ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ videos (InputLayer)         │ (None, 8, 224, 224, 1)  │              0 │ -                       │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ patching_and_embedding      │ (None, 4, 56, 56, 96)   │          3,360 │ videos[0][0]            │
│ (VideoSwinPatchingAndEmbed… │                         │                │                         │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ pos_drop (Dropout)          │ (None, 4, 56, 56, 96)   │              0 │ patching_and_embedding… │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ videoswin_basic_layer_1     │ (None, 4, 28, 28, 192)  │        305,274 │ pos_drop[0][0]          │
│ (VideoSwinBasicLayer)       │                         │                │                         │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ videoswin_basic_layer_2     │ (None, 4, 14, 14, 384)  │      1,200,372 │ videoswin_basic_layer_… │
│ (VideoSwinBasicLayer)       │                         │                │                         │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ videoswin_basic_layer_3     │ (None, 4, 7, 7, 768)    │     11,914,680 │ videoswin_basic_layer_… │
│ (VideoSwinBasicLayer)       │                         │                │                         │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ videoswin_basic_layer_4     │ (None, 4, 7, 7, 768)    │     14,232,528 │ videoswin_basic_layer_… │
│ (VideoSwinBasicLayer)       │                         │                │                         │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ top_norm                    │ (None, 4, 7, 7, 768)    │          1,536 │ videoswin_basic_layer_… │
│ (LayerNormalization)        │                         │                │                         │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_transpose            │ (None, 4, 14, 14, 384)  │      1,180,032 │ top_norm[0][0]          │
│ (Conv3DTranspose)           │                         │                │                         │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ up_sampling3d               │ (None, 4, 14, 14, 768)  │              0 │ videoswin_basic_layer_… │
│ (UpSampling3D)              │                         │                │                         │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ concatenate (Concatenate)   │ (None, 4, 14, 14, 1152) │              0 │ conv3d_transpose[0][0], │
│                             │                         │                │ up_sampling3d[0][0]     │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d (Conv3D)             │ (None, 4, 14, 14, 384)  │     11,944,320 │ concatenate[0][0]       │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_1 (Conv3D)           │ (None, 4, 14, 14, 384)  │      3,981,696 │ conv3d[0][0]            │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_transpose_1          │ (None, 4, 28, 28, 192)  │        295,104 │ conv3d_1[0][0]          │
│ (Conv3DTranspose)           │                         │                │                         │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ up_sampling3d_1             │ (None, 4, 28, 28, 384)  │              0 │ videoswin_basic_layer_… │
│ (UpSampling3D)              │                         │                │                         │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ concatenate_1 (Concatenate) │ (None, 4, 28, 28, 576)  │              0 │ conv3d_transpose_1[0][… │
│                             │                         │                │ up_sampling3d_1[0][0]   │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_2 (Conv3D)           │ (None, 4, 28, 28, 192)  │      2,986,176 │ concatenate_1[0][0]     │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_3 (Conv3D)           │ (None, 4, 28, 28, 192)  │        995,520 │ conv3d_2[0][0]          │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_transpose_2          │ (None, 4, 56, 56, 96)   │         73,824 │ conv3d_3[0][0]          │
│ (Conv3DTranspose)           │                         │                │                         │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ up_sampling3d_2             │ (None, 4, 56, 56, 192)  │              0 │ videoswin_basic_layer_… │
│ (UpSampling3D)              │                         │                │                         │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ concatenate_2 (Concatenate) │ (None, 4, 56, 56, 288)  │              0 │ conv3d_transpose_2[0][… │
│                             │                         │                │ up_sampling3d_2[0][0]   │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_4 (Conv3D)           │ (None, 4, 56, 56, 96)   │        746,592 │ concatenate_2[0][0]     │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_5 (Conv3D)           │ (None, 4, 56, 56, 96)   │        248,928 │ conv3d_4[0][0]          │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_transpose_3          │ (None, 8, 112, 112, 96) │         73,824 │ conv3d_5[0][0]          │
│ (Conv3DTranspose)           │                         │                │                         │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_transpose_4          │ (None, 8, 224, 224, 96) │         36,960 │ conv3d_transpose_3[0][… │
│ (Conv3DTranspose)           │                         │                │                         │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_6 (Conv3D)           │ (None, 8, 224, 224, 4)  │            388 │ conv3d_transpose_4[0][… │
└─────────────────────────────┴─────────────────────────┴────────────────┴─────────────────────────┘
 Total params: 50,221,114 (191.58 MB)
 Trainable params: 50,221,114 (191.58 MB)
 Non-trainable params: 0 (0.00 B)

Metadata

Metadata

Assignees

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions