-
Notifications
You must be signed in to change notification settings - Fork 4
Closed
Labels
enhancement (New feature or request) · feature reviewing (evaluating features to improve codebase) · good first issue (Good for newcomers) · type:feature
Description
Short Description
Drawing inspiration from Swin UNETR and leveraging resources from MONAI, we aim to introduce a lightweight decoder module to adapt VideoSwin for 3D semantic segmentation tasks.
Minimal Code
- Encoder
import keras
from keras import layers
from keras import ops
from keras import Model
from keras.layers import Conv3D, Conv3DTranspose, Concatenate
def vswin_tiny():
    """Build a VideoSwin-Tiny backbone preloaded with Kinetics-400 weights."""
    # The input shape is arbitrary here — any compatible shape works.
    backbone = VideoSwinBackbone(input_shape=(8, 224, 224, 3))
    # skip_mismatch tolerates layers whose shapes differ from the checkpoint
    # (e.g. when the input channel count changes).
    backbone.load_weights(
        'videoswin_tiny_kinetics400.weights.h5', skip_mismatch=True
    )
    return backbone
encoder = vswin_tiny()
- Basic U-Net
def upsample_block(x, skip, filters):
    """Upsample `x` spatially (H, W) by 2, merge with `skip`, and refine.

    Args:
        x: decoder feature map — assumed (batch, D, H, W, C); TODO confirm.
        skip: encoder skip-connection tensor to concatenate with `x`.
        filters: number of output channels for this decoder stage.

    Returns:
        Feature map with `filters` channels at the upsampled resolution.
    """
    x = layers.Conv3DTranspose(
        filters, (1, 2, 2), strides=(1, 2, 2), padding='same'
    )(x)
    # If the skip tensor is at a coarser resolution, upsample it by the
    # integer factor needed to match before concatenation.
    if x.shape[1:4] != skip.shape[1:4]:
        depth_factor = x.shape[1] // skip.shape[1]
        height_factor = x.shape[2] // skip.shape[2]
        width_factor = x.shape[3] // skip.shape[3]
        skip = layers.UpSampling3D(
            size=(depth_factor, height_factor, width_factor)
        )(skip)
    x = layers.Concatenate()([x, skip])
    x = layers.Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(x)
    x = layers.Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(x)
    return x


def build_unet_3d(encoder, num_classes, activation="sigmoid"):
    """Attach a lightweight 3D U-Net decoder to a VideoSwin encoder.

    Args:
        encoder: VideoSwin backbone `Model` exposing the named basic layers
            ("videoswin_basic_layer_1..3") and a "top_norm" layer.
        num_classes: number of output segmentation channels.
        activation: activation for the final 1x1x1 `Conv3D` head.

    Returns:
        A `Model` mapping `encoder.input` to per-voxel class scores.
    """
    inputs = encoder.input
    # Skip connections from progressively deeper encoder stages.
    skips = [
        encoder.get_layer("videoswin_basic_layer_1").output,
        encoder.get_layer("videoswin_basic_layer_2").output,
        encoder.get_layer("videoswin_basic_layer_3").output,
    ]
    bottleneck = encoder.get_layer("top_norm").output
    # Decoder: three skip-connected upsampling stages.
    x = upsample_block(bottleneck, skips[-1], 384)
    x = upsample_block(x, skips[-2], 192)
    x = upsample_block(x, skips[-3], 96)
    # Restore the temporal (depth) dimension halved by patch embedding...
    x = layers.Conv3DTranspose(
        96, (2, 2, 2), strides=(2, 2, 2), padding='same'
    )(x)
    # ...then the final 2x spatial upsampling back to input resolution.
    x = layers.Conv3DTranspose(
        96, (1, 2, 2), strides=(1, 2, 2), padding='same'
    )(x)
    # Per-voxel classification head.
    outputs = layers.Conv3D(
        num_classes, (1, 1, 1), activation=activation, padding="same"
    )(x)
    return Model(inputs, outputs, name="UNet3D_VideoSwin")
# Bug fix: `model` was local to vswin_tiny(); the top-level binding is `encoder`.
unet_model = build_unet_3d(encoder, num_classes=4)
unet_model.summary(line_length=100)
Model: "UNet3D_VideoSwin"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ videos (InputLayer) │ (None, 8, 224, 224, 1) │ 0 │ - │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ patching_and_embedding │ (None, 4, 56, 56, 96) │ 3,360 │ videos[0][0] │
│ (VideoSwinPatchingAndEmbed… │ │ │ │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ pos_drop (Dropout) │ (None, 4, 56, 56, 96) │ 0 │ patching_and_embedding… │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ videoswin_basic_layer_1 │ (None, 4, 28, 28, 192) │ 305,274 │ pos_drop[0][0] │
│ (VideoSwinBasicLayer) │ │ │ │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ videoswin_basic_layer_2 │ (None, 4, 14, 14, 384) │ 1,200,372 │ videoswin_basic_layer_… │
│ (VideoSwinBasicLayer) │ │ │ │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ videoswin_basic_layer_3 │ (None, 4, 7, 7, 768) │ 11,914,680 │ videoswin_basic_layer_… │
│ (VideoSwinBasicLayer) │ │ │ │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ videoswin_basic_layer_4 │ (None, 4, 7, 7, 768) │ 14,232,528 │ videoswin_basic_layer_… │
│ (VideoSwinBasicLayer) │ │ │ │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ top_norm │ (None, 4, 7, 7, 768) │ 1,536 │ videoswin_basic_layer_… │
│ (LayerNormalization) │ │ │ │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_transpose │ (None, 4, 14, 14, 384) │ 1,180,032 │ top_norm[0][0] │
│ (Conv3DTranspose) │ │ │ │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ up_sampling3d │ (None, 4, 14, 14, 768) │ 0 │ videoswin_basic_layer_… │
│ (UpSampling3D) │ │ │ │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ concatenate (Concatenate) │ (None, 4, 14, 14, 1152) │ 0 │ conv3d_transpose[0][0], │
│ │ │ │ up_sampling3d[0][0] │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d (Conv3D) │ (None, 4, 14, 14, 384) │ 11,944,320 │ concatenate[0][0] │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_1 (Conv3D) │ (None, 4, 14, 14, 384) │ 3,981,696 │ conv3d[0][0] │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_transpose_1 │ (None, 4, 28, 28, 192) │ 295,104 │ conv3d_1[0][0] │
│ (Conv3DTranspose) │ │ │ │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ up_sampling3d_1 │ (None, 4, 28, 28, 384) │ 0 │ videoswin_basic_layer_… │
│ (UpSampling3D) │ │ │ │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ concatenate_1 (Concatenate) │ (None, 4, 28, 28, 576) │ 0 │ conv3d_transpose_1[0][… │
│ │ │ │ up_sampling3d_1[0][0] │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_2 (Conv3D) │ (None, 4, 28, 28, 192) │ 2,986,176 │ concatenate_1[0][0] │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_3 (Conv3D) │ (None, 4, 28, 28, 192) │ 995,520 │ conv3d_2[0][0] │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_transpose_2 │ (None, 4, 56, 56, 96) │ 73,824 │ conv3d_3[0][0] │
│ (Conv3DTranspose) │ │ │ │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ up_sampling3d_2 │ (None, 4, 56, 56, 192) │ 0 │ videoswin_basic_layer_… │
│ (UpSampling3D) │ │ │ │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ concatenate_2 (Concatenate) │ (None, 4, 56, 56, 288) │ 0 │ conv3d_transpose_2[0][… │
│ │ │ │ up_sampling3d_2[0][0] │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_4 (Conv3D) │ (None, 4, 56, 56, 96) │ 746,592 │ concatenate_2[0][0] │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_5 (Conv3D) │ (None, 4, 56, 56, 96) │ 248,928 │ conv3d_4[0][0] │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_transpose_3 │ (None, 8, 112, 112, 96) │ 73,824 │ conv3d_5[0][0] │
│ (Conv3DTranspose) │ │ │ │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_transpose_4 │ (None, 8, 224, 224, 96) │ 36,960 │ conv3d_transpose_3[0][… │
│ (Conv3DTranspose) │ │ │ │
├─────────────────────────────┼─────────────────────────┼────────────────┼─────────────────────────┤
│ conv3d_6 (Conv3D) │ (None, 8, 224, 224, 4) │ 388 │ conv3d_transpose_4[0][… │
└─────────────────────────────┴─────────────────────────┴────────────────┴─────────────────────────┘
Total params: 50,221,114 (191.58 MB)
Trainable params: 50,221,114 (191.58 MB)
Non-trainable params: 0 (0.00 B)
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
enhancement (New feature or request) · feature reviewing (evaluating features to improve codebase) · good first issue (Good for newcomers) · type:feature