
A toolbox that provides hackable building blocks for generic 1D/2D/3D UNets, in PyTorch.



archinetai/a-unet


A-UNet

A library that provides building blocks to customize UNets, in PyTorch.
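
The `T`-suffixed names imported in the usage example (`DownsampleT`, `UpsampleT`, `ResnetBlockT`) are first called with some arguments (such as `dim`), and the result is called again later with the remaining ones. The pure-Python sketch below shows one way such block templates can work, using `functools.partial`. The `T` helper and the toy `Conv` class are hypothetical illustrations, not part of a_unet, whose actual implementation may differ.

```python
from functools import partial
from typing import Any, Callable


def T(block: Callable[..., Any]) -> Callable[..., Callable[..., Any]]:
    """Hypothetical template helper: fix some keyword arguments now, the rest later."""

    def template(**fixed: Any) -> Callable[..., Any]:
        # functools.partial stores the fixed kwargs; later calls supply the remaining ones
        return partial(block, **fixed)

    return template


class Conv:
    """Toy stand-in for a dimension-generic convolution block (not a real layer)."""

    def __init__(self, dim: int, in_channels: int, out_channels: int, factor: int = 1):
        self.dim = dim
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor = factor

    def __repr__(self) -> str:
        return f"Conv{self.dim}d({self.in_channels}->{self.out_channels}, factor={self.factor})"


# Fix the dimension once, mirroring DownsampleT(dim=dim) in the usage example
Downsample2d = T(Conv)(dim=2)
block = Downsample2d(factor=2, in_channels=8, out_channels=256)
print(block)  # Conv2d(8->256, factor=2)
```

Fixing `dim` once lets the same model-building code produce 1D, 2D, or 3D networks by changing a single argument.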

Install

```bash
pip install a-unet
```


Usage

Basic UNet

A convolution-only UNet, generic to any dimension, built from A-UNet blocks:
```python
from typing import List

import torch
from a_unet import DownsampleT, Repeat, ResnetBlockT, Skip, UpsampleT
from torch import nn


def UNet(
    dim: int,
    in_channels: int,
    channels: List[int],
    factors: List[int],
    blocks: List[int],
) -> nn.Module:
    # Check that one channel count, factor, and block count is given per layer
    n_layers = len(channels)
    assert n_layers == len(factors) and n_layers == len(blocks), "lengths must match"

    # Define convolutional block types for the given dimension (1D/2D/3D)
    Downsample = DownsampleT(dim=dim)
    Upsample = UpsampleT(dim=dim)

    # Stack of n_blocks resnet blocks that preserve the channel count
    def Stack(channels: int, n_blocks: int) -> nn.Module:
        ResnetBlock = ResnetBlockT(dim=dim, in_channels=channels, out_channels=channels)
        resnet_stack = Repeat(ResnetBlock, times=n_blocks)
        return resnet_stack

    # Build the UNet recursively, from the outermost level (i = 0) inward
    def build(i: int) -> nn.Module:
        if i == n_layers:
            return nn.Identity()
        n_channels = channels[i - 1] if i > 0 else in_channels
        factor = factors[i]

        # Skip wraps this level's modules with a skip connection
        return Skip(
            Downsample(factor=factor, in_channels=n_channels, out_channels=channels[i]),
            Stack(channels=channels[i], n_blocks=blocks[i]),
            build(i + 1),
            Stack(channels=channels[i], n_blocks=blocks[i]),
            Upsample(factor=factor, in_channels=channels[i], out_channels=n_channels),
        )

    return build(0)


unet = UNet(dim=2, in_channels=8, channels=[256, 512], factors=[2, 2], blocks=[2, 2])
x = torch.randn(1, 8, 16, 16)
y = unet(x)  # [1, 8, 16, 16]
```
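
To see how the example keeps the input shape, it helps to trace the recursion in `build`: level `i` downsamples by `factors[i]` and switches to `channels[i]` channels, and its `Upsample` maps back to the previous level's channel count. The dependency-free helper below is a hypothetical sketch that traces only the downsampling path. It assumes each `Downsample` divides every spatial axis by its factor (and each `Upsample` multiplies it back), which implies that input sizes must be divisible by the product of all factors.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]


def trace_unet_shapes(
    in_shape: Shape, channels: List[int], factors: List[int]
) -> List[Tuple[int, Shape]]:
    """Return (channels, spatial shape) after each Downsample, outermost level first.

    Hypothetical helper: assumes Downsample divides every spatial axis by its factor.
    """
    assert len(channels) == len(factors), "lengths must match"
    levels = []
    spatial = tuple(in_shape)
    for out_ch, factor in zip(channels, factors):
        # Every spatial axis must divide evenly, or the upsampled output cannot match
        assert all(s % factor == 0 for s in spatial), "spatial size not divisible by factor"
        spatial = tuple(s // factor for s in spatial)
        levels.append((out_ch, spatial))
    return levels


# Same configuration as the usage example: dim=2, 16x16 input, channels=[256, 512]
for depth, (ch, spatial) in enumerate(trace_unet_shapes((16, 16), [256, 512], [2, 2])):
    print(f"level {depth}: {ch} channels at {spatial}")
# level 0: 256 channels at (8, 8)
# level 1: 512 channels at (4, 4)
```

Walking back out, level 1 upsamples to 256 channels at 8x8 and level 0 to 8 channels at 16x16, which matches the `[1, 8, 16, 16]` output shape noted in the example.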
