Skip to content

Commit 859b828

Browse files
committed
Updated to include depth parameter
1 parent 45b4bcd commit 859b828

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

images/unet_diagram.png

-658 KB
Loading

sm00thix_unet_U-Net.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ accelerator: "cuda-optional"
1919
import torch
2020

2121
# These are the default parameters. They are written out for clarity. Currently no pretrained weights are available.
22-
model = torch.hub.load('sm00thix/unet', 'unet', pretrained=False, in_channels=3, out_channels=1, pad=True, bilinear=True, normalization=None)
22+
model = torch.hub.load('sm00thix/unet', 'unet', pretrained=False, in_channels=3, out_channels=1, pad=True, bilinear=True, normalization=None, depth=5)
2323
# or
24-
# model = torch.hub.load('sm00thix/unet', 'unet_bn', **kwargs) # Convenience function equivalent to torch.hub.load('sm00thix/unet', 'unet', normalization='bn', **kwargs)
24+
# model = torch.hub.load('sm00thix/unet', 'unet_bn', **kwargs) # Convenience function equivalent to torch.hub.load('sm00thix/unet', 'unet', normalization='bn' **kwargs)
2525
# or
2626
# model = torch.hub.load('sm00thix/unet', 'unet_ln', **kwargs) # Convenience function equivalent to torch.hub.load('sm00thix/unet', 'unet', normalization='ln', **kwargs)
2727
# or
@@ -34,6 +34,7 @@ model = torch.hub.load('sm00thix/unet', 'unet', pretrained=False, in_channels=3,
3434
This is an implementation of U-Net [[1]](#references). It comes with the following options for customization.
3535

3636
1. Number of input and output channels
37+
3738
`in_channels` is the number of channels in the input image.
3839
`out_channels` is the number of channels in the output image.
3940
2. Upsampling
@@ -46,8 +47,11 @@ This is an implementation of U-Net [[1]](#references). It comes with the followi
4647
1. `normalization = None`: Applies no normalization.
4748
2. `normalization = "bn"`: Applies batch normalization [[2]](#references).
4849
3. `normalization = "ln"`: Applies layer normalization [[3]](#references). A permutation of dimensions is performed before the layer to ensure normalization is applied over the channel dimension. Afterward, the dimensions are permuted back to their original order.
50+
5. Depth
51+
52+
`depth` is the The depth of the U-Net. This is the number of steps in the encoder and decoder paths. This is one less than the number of downsampling and upsampling blocks. The number of intermediate channels is 64*2**`depth`, i.e. [64, 128, 256, 512, 1024] for `depth` = 5.
4953

50-
In particular, setting bilinear = False, pad = False, and normalization = None will yield the U-Net as originally designed. Generally, however, bilinear = True is recommended to avoid checkerboard artifacts.
54+
In particular, setting `bilinear = False`, `pad = False`, `normalization = None`, and `depth = 5` will yield the U-Net as originally designed. Generally, however, `bilinear = True` is recommended to avoid checkerboard artifacts.
5155

5256
As in the original implementation, all weights are initialized by sampling from a Kaiming He Normal Distribution [[4]](#references), and all biases are initialized to zero. If Batch Normalization or Layer Normalization is used, the weights of those layers are initialized to one and their biases to zero.
5357

0 commit comments

Comments
 (0)