Skip to content

Commit 0dfe3ca

Browse files
authored
test file for FCNNetwork added (aeon-toolkit#2559)
1 parent a0c58dc commit 0dfe3ca

File tree

1 file changed

+196
-0
lines changed

1 file changed

+196
-0
lines changed

aeon/networks/tests/test_fcn.py

+196
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
"""Test for the FCNNetwork class."""
2+
3+
import pytest
4+
5+
from aeon.networks import FCNNetwork
6+
from aeon.utils.validation._dependencies import _check_soft_dependencies
7+
8+
9+
@pytest.mark.skipif(
10+
not _check_soft_dependencies(["tensorflow"], severity="none"),
11+
reason="Tensorflow soft dependency unavailable.",
12+
)
13+
def test_fcnnetwork_valid():
14+
"""Test FCNNetwork with valid configurations."""
15+
input_shape = (100, 5)
16+
model = FCNNetwork(n_layers=3)
17+
input_layer, output_layer = model.build_network(input_shape)
18+
19+
assert hasattr(input_layer, "shape")
20+
assert hasattr(output_layer, "shape")
21+
22+
23+
@pytest.mark.skipif(
24+
not _check_soft_dependencies(["tensorflow"], severity="none"),
25+
reason="Tensorflow soft dependency unavailable.",
26+
)
27+
@pytest.mark.parametrize(
28+
"activation, should_raise",
29+
[
30+
(["relu", "sigmoid", "tanh"], False),
31+
(["relu", "sigmoid"], True),
32+
(
33+
["relu", "sigmoid", "tanh", "softmax"],
34+
True,
35+
),
36+
("relu", False),
37+
("sigmoid", False),
38+
("tanh", False),
39+
("softmax", False),
40+
],
41+
)
42+
def test_fcnnetwork_activation(activation, should_raise):
43+
"""Test FCNNetwork with valid and invalid activation configurations."""
44+
input_shape = (100, 5)
45+
if should_raise:
46+
with pytest.raises(ValueError):
47+
model = FCNNetwork(activation=activation)
48+
model.build_network(input_shape)
49+
else:
50+
model = FCNNetwork(activation=activation)
51+
input_layer, output_layer = model.build_network(input_shape)
52+
53+
assert hasattr(input_layer, "shape")
54+
55+
assert hasattr(output_layer, "shape")
56+
57+
58+
@pytest.mark.skipif(
59+
not _check_soft_dependencies(["tensorflow"], severity="none"),
60+
reason="Tensorflow soft dependency unavailable.",
61+
)
62+
@pytest.mark.parametrize(
63+
"kernel_size, should_raise",
64+
[
65+
([3, 1, 2], False),
66+
([1, 3], True),
67+
([3, 1, 1, 3], True),
68+
(3, False),
69+
],
70+
)
71+
def test_fcnnetwork_kernel_size(kernel_size, should_raise):
72+
"""Test FCNNetwork with valid and invalid kernel_size configurations."""
73+
input_shape = (100, 5)
74+
if should_raise:
75+
with pytest.raises(ValueError):
76+
model = FCNNetwork(kernel_size=kernel_size, n_layers=3)
77+
model.build_network(input_shape)
78+
else:
79+
model = FCNNetwork(kernel_size=kernel_size, n_layers=3)
80+
input_layer, output_layer = model.build_network(input_shape)
81+
82+
assert hasattr(input_layer, "shape")
83+
assert hasattr(output_layer, "shape")
84+
85+
86+
@pytest.mark.skipif(
87+
not _check_soft_dependencies(["tensorflow"], severity="none"),
88+
reason="Tensorflow soft dependency unavailable.",
89+
)
90+
@pytest.mark.parametrize(
91+
"dilation_rate, should_raise",
92+
[
93+
([1, 2, 1], False),
94+
([1, 4], True),
95+
([1, 2, 4, 1], True),
96+
(1, False),
97+
],
98+
)
99+
def test_fcnnetwork_dilation_rate(dilation_rate, should_raise):
100+
"""Test FCNNetwork with valid and invalid dilation_rate configurations."""
101+
input_shape = (100, 5)
102+
if should_raise:
103+
with pytest.raises(ValueError):
104+
model = FCNNetwork(dilation_rate=dilation_rate, n_layers=3)
105+
model.build_network(input_shape)
106+
else:
107+
model = FCNNetwork(dilation_rate=dilation_rate, n_layers=3)
108+
input_layer, output_layer = model.build_network(input_shape)
109+
110+
assert hasattr(input_layer, "shape")
111+
assert hasattr(output_layer, "shape")
112+
113+
114+
@pytest.mark.skipif(
115+
not _check_soft_dependencies(["tensorflow"], severity="none"),
116+
reason="Tensorflow soft dependency unavailable.",
117+
)
118+
@pytest.mark.parametrize(
119+
"strides, should_raise",
120+
[
121+
([1, 2, 3], False),
122+
([1, 1], True),
123+
([1, 2, 2, 1], True),
124+
(1, False),
125+
],
126+
)
127+
def test_fcnnetwork_strides(strides, should_raise):
128+
"""Test FCNNetwork with valid and invalid strides configurations."""
129+
input_shape = (100, 5)
130+
if should_raise:
131+
with pytest.raises(ValueError):
132+
model = FCNNetwork(strides=strides, n_layers=3)
133+
model.build_network(input_shape)
134+
else:
135+
model = FCNNetwork(strides=strides, n_layers=3)
136+
input_layer, output_layer = model.build_network(input_shape)
137+
138+
assert hasattr(input_layer, "shape")
139+
assert hasattr(output_layer, "shape")
140+
141+
142+
@pytest.mark.skipif(
143+
not _check_soft_dependencies(["tensorflow"], severity="none"),
144+
reason="Tensorflow soft dependency unavailable.",
145+
)
146+
@pytest.mark.parametrize(
147+
"padding, should_raise",
148+
[
149+
(["same", "same", "valid"], False),
150+
(["valid", "same"], True),
151+
(["same", "valid", "same", "valid"], True),
152+
("same", False),
153+
("valid", False),
154+
],
155+
)
156+
def test_fcnnetwork_padding(padding, should_raise):
157+
"""Test FCNNetwork with valid and invalid padding configurations."""
158+
input_shape = (100, 5)
159+
if should_raise:
160+
with pytest.raises(ValueError):
161+
model = FCNNetwork(padding=padding, n_layers=3)
162+
model.build_network(input_shape)
163+
else:
164+
model = FCNNetwork(padding=padding, n_layers=3)
165+
input_layer, output_layer = model.build_network(input_shape)
166+
167+
assert hasattr(input_layer, "shape")
168+
assert hasattr(output_layer, "shape")
169+
170+
171+
@pytest.mark.skipif(
172+
not _check_soft_dependencies(["tensorflow"], severity="none"),
173+
reason="Tensorflow soft dependency unavailable.",
174+
)
175+
@pytest.mark.parametrize(
176+
"n_filters, should_raise",
177+
[
178+
([32, 64, 128], False), # Valid case with a list of filters
179+
([32, 64], True), # Invalid case with fewer filters than layers
180+
([32, 64, 128, 256], True), # Invalid case with more filters than layers
181+
(32, False), # Valid case with a single filter value
182+
],
183+
)
184+
def test_fcnnetwork_n_filters(n_filters, should_raise):
185+
"""Test FCNNetwork with valid and invalid n_filters configurations."""
186+
input_shape = (100, 5)
187+
if should_raise:
188+
with pytest.raises(ValueError):
189+
model = FCNNetwork(n_filters=n_filters, n_layers=3)
190+
model.build_network(input_shape)
191+
else:
192+
model = FCNNetwork(n_filters=n_filters, n_layers=3)
193+
input_layer, output_layer = model.build_network(input_shape)
194+
195+
assert hasattr(input_layer, "shape")
196+
assert hasattr(output_layer, "shape")

0 commit comments

Comments
 (0)