Skip to content

Commit a13d063

Browse files
authored
[ENH] Test coverage for Resnet Network (aeon-toolkit#2553)
* Resnet pytest * Resnet pytest * Fixed tensorflow failing * Added Resnet in function name
1 parent 7518feb commit a13d063

File tree

1 file changed

+109
-0
lines changed

1 file changed

+109
-0
lines changed

aeon/networks/tests/test_resnet.py

+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""Tests for the ResNet Model."""
2+
3+
import pytest
4+
5+
from aeon.networks import ResNetNetwork
6+
from aeon.utils.validation._dependencies import _check_soft_dependencies
7+
8+
9+
@pytest.mark.skipif(
10+
not _check_soft_dependencies(["tensorflow"], severity="none"),
11+
reason="skip test if required soft dependency not available",
12+
)
13+
def test_resnet_default_initialization():
14+
"""Test if the network initializes with proper attributes."""
15+
model = ResNetNetwork()
16+
assert isinstance(
17+
model, ResNetNetwork
18+
), "Model initialization failed: Incorrect type"
19+
assert model.n_residual_blocks == 3, "Default residual blocks count mismatch"
20+
assert (
21+
model.n_conv_per_residual_block == 3
22+
), "Default convolution blocks count mismatch"
23+
assert model.n_filters is None, "Default n_filters should be None"
24+
assert model.kernel_size is None, "Default kernel_size should be None"
25+
assert model.strides == 1, "Default strides value mismatch"
26+
assert model.dilation_rate == 1, "Default dilation rate mismatch"
27+
assert model.activation == "relu", "Default activation mismatch"
28+
assert model.use_bias is True, "Default use_bias mismatch"
29+
assert model.padding == "same", "Default padding mismatch"
30+
31+
32+
@pytest.mark.skipif(
33+
not _check_soft_dependencies(["tensorflow"], severity="none"),
34+
reason="skip test if required soft dependency not available",
35+
)
36+
def test_resnet_custom_initialization():
37+
"""Test whether custom kwargs are correctly set."""
38+
model = ResNetNetwork(
39+
n_residual_blocks=3,
40+
n_conv_per_residual_block=3,
41+
n_filters=[64, 128, 128],
42+
kernel_size=[8, 5, 3],
43+
activation="relu",
44+
strides=1,
45+
padding="same",
46+
)
47+
model.build_network((128, 1))
48+
assert isinstance(
49+
model, ResNetNetwork
50+
), "Custom initialization failed: Incorrect type"
51+
assert model._n_filters == [64, 128, 128], "n_filters list mismatch"
52+
assert model._kernel_size == [8, 5, 3], "kernel_size list mismatch"
53+
assert model._activation == ["relu", "relu", "relu"], "activation list mismatch"
54+
assert model._strides == [1, 1, 1], "strides list mismatch"
55+
assert model._padding == ["same", "same", "same"], "padding list mismatch"
56+
57+
58+
@pytest.mark.skipif(
59+
not _check_soft_dependencies(["tensorflow"], severity="none"),
60+
reason="skip test if required soft dependency not available",
61+
)
62+
def test_resnet_invalid_initialization():
63+
"""Test if the network raises valid exceptions for invalid configurations."""
64+
with pytest.raises(ValueError, match=".*same as number of residual blocks.*"):
65+
ResNetNetwork(n_filters=[64, 128], n_residual_blocks=3).build_network((128, 1))
66+
67+
with pytest.raises(ValueError, match=".*same as number of convolution layers.*"):
68+
ResNetNetwork(kernel_size=[8, 5], n_conv_per_residual_block=3).build_network(
69+
(128, 1)
70+
)
71+
72+
with pytest.raises(ValueError, match=".*same as number of convolution layers.*"):
73+
ResNetNetwork(strides=[1, 2], n_conv_per_residual_block=3).build_network(
74+
(128, 1)
75+
)
76+
77+
78+
@pytest.mark.skipif(
79+
not _check_soft_dependencies(["tensorflow"], severity="none"),
80+
reason="skip test if required soft dependency not available",
81+
)
82+
def test_resnet_build_network():
83+
"""Test network building with various input shapes."""
84+
model = ResNetNetwork()
85+
86+
input_shapes = [(128, 1), (256, 3), (512, 1)]
87+
for shape in input_shapes:
88+
input_layer, output_layer = model.build_network(shape)
89+
assert hasattr(input_layer, "shape"), "Input layer type mismatch"
90+
assert hasattr(output_layer, "shape"), "Output layer type mismatch"
91+
assert input_layer.shape[1:] == shape, "Input shape mismatch"
92+
assert output_layer.shape[-1] == 128, "Output layer mismatch"
93+
94+
95+
@pytest.mark.skipif(
96+
not _check_soft_dependencies(["tensorflow"], severity="none"),
97+
reason="skip test if required soft dependency not available",
98+
)
99+
def test_resnet_shortcut_layer():
100+
"""Test the shortcut layer functionality."""
101+
model = ResNetNetwork()
102+
103+
input_shape = (128, 64)
104+
input_layer, output_layer = model.build_network(input_shape)
105+
106+
shortcut = model._shortcut_layer(input_layer, output_layer)
107+
108+
assert hasattr(shortcut, "shape"), "Shortcut layer output type mismatch"
109+
assert shortcut.shape[-1] == 128, "Shortcut output shape mismatch"

0 commit comments

Comments
 (0)