Skip to content

Commit

Permalink
Add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
gitttt-1234 committed May 25, 2024
1 parent 9c0d164 commit b7ad774
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 3 deletions.
19 changes: 16 additions & 3 deletions sleap_nn/architectures/encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,22 @@ def __init__(
if middle_block:
if convs_per_block > 1:
# First convs are one exponent higher than the last encoder block.
block_filters = int(
filters * (filters_rate ** (down_blocks + self.stem_blocks))
)

if block_contraction:
# Contract the channels with an exponent lower than the last encoder block.
block_filters = int(
self.filters
* (
self.filters_rate
** (self.down_blocks + self.stem_blocks - 1)
)
)
else:
# Keep the block output filters the same.
block_filters = int(
self.filters
* (self.filters_rate ** (self.down_blocks + self.stem_blocks))
)
self.encoder_stack.append(
SimpleConvBlock(
in_channels=after_block_filters,
Expand Down
29 changes: 29 additions & 0 deletions tests/architectures/test_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,32 @@ def test_simple_upsampling_block():
z = block(x, feature=feature)

assert z.shape == (5, 64, 200, 200)

block = SimpleUpsamplingBlock(
x_in_shape=10,
current_stride=1,
upsampling_stride=2,
interp_method="bilinear",
refine_convs=2,
refine_convs_filters=64,
refine_convs_kernel_size=3,
refine_convs_use_bias=True,
refine_convs_batch_norm=True,
refine_convs_batch_norm_before_activation=False,
refine_convs_activation="relu",
up_interpolate=False,
transpose_convs_filters=5,
transpose_convs_batch_norm=True,
transpose_convs_batch_norm_before_activation=True,
)
print(block)

block = block.to(device)
block.eval()

x = torch.rand(5, 5, 100, 100).to(device)
feature = torch.rand(5, 5, 200, 200).to(device)

z = block(x, feature=feature)

assert z.shape == (5, 64, 200, 200)
27 changes: 27 additions & 0 deletions tests/architectures/test_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,30 @@ def test_unet_reference():
with torch.no_grad():
z = conv2d(y["outputs"][-1])
assert z.shape == (1, 13, 192, 192)

# Test number of intermediate features outputted from encoder.
enc = Encoder(
in_channels=1,
filters=filters,
down_blocks=down_blocks,
filters_rate=filters_rate,
current_stride=2,
convs_per_block=convs_per_block,
kernel_size=kernel_size,
block_contraction=True,
)
print(enc)

enc = enc.to(device)
enc.eval()

x = torch.rand(1, 1, 192, 192).to(device)
with torch.no_grad():
y, features = enc(x)

assert y.shape == (1, 128, 12, 12)
assert len(features) == 4
assert features[0].shape == (1, 128, 24, 24)
assert features[1].shape == (1, 64, 48, 48)
assert features[2].shape == (1, 32, 96, 96)
assert features[3].shape == (1, 16, 192, 192)

0 comments on commit b7ad774

Please sign in to comment.