Skip to content

Commit

Permalink
All test passed
Browse files Browse the repository at this point in the history
  • Loading branch information
KoraTMontemagno committed Aug 14, 2024
1 parent 41aa7ce commit 14f22e6
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 84 deletions.
6 changes: 3 additions & 3 deletions docs/source/notebooks/0.2-Creating_networks.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@
"\n",
"attributes = (parameters, parameters, parameters)\n",
"edges = (\n",
" AdjacencyLists(0, (1,), None, None, None),\n",
" AdjacencyLists(2, None, (2,), (0,), None),\n",
" AdjacencyLists(2, None, None, None, (1,)),\n",
" AdjacencyLists(0, (1,), None, None, None, (None,)),\n",
" AdjacencyLists(2, None, (2,), (0,), None, (None,)),\n",
" AdjacencyLists(2, None, None, None, (1,), (None,)),\n",
")"
]
},
Expand Down
12 changes: 6 additions & 6 deletions src/pyhgf/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,10 +428,10 @@ def add_nodes(
"'DP-state', 'continuous-state', 'binary-state', 'ef-normal'."
)
)

# assess children number
children_number = 1
if value_children == None:
if value_children is None:
children_number = 0
elif isinstance(value_children, int):
children_number = 1
Expand Down Expand Up @@ -660,10 +660,10 @@ def add_nodes(
if children_number != len(coupling_fn):
if coupling_fn == (None,):
coupling_fn = children_number * coupling_fn
else:
raise ValueError("""The number of coupling functions
and value children do not match""")

else:
raise ValueError(
"The number of coupling fn and value children do not match"
)

# add a new edge
edges_as_list.append(
Expand Down
91 changes: 16 additions & 75 deletions tests/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from jax.tree_util import Partial

from pyhgf import load_data
from pyhgf.model import Network
from pyhgf.math import gaussian_surprise
from pyhgf.model import Network
from pyhgf.typing import AdjacencyLists, Inputs
from pyhgf.updates.posterior.continuous import (
continuous_node_update,
Expand Down Expand Up @@ -131,22 +131,16 @@ def test_continuous_input_update(nodes_attributes):
# one value parent with one volatility parent #
###############################################
attributes = nodes_attributes

def identity(x):
return(x)
return x

edges = (
AdjacencyLists(0, (1,), None, None, None, (None,)),
AdjacencyLists(2, None, (2,), (0,), None, (None,)),
AdjacencyLists(2, None, None, None, (1,), (None,)),
)

# define a non linear network behaving like the linear one
edges_nonlinear = (
AdjacencyLists(0, (1,), None, None, None, (None,)),
AdjacencyLists(2, None, (2,), (0,), None, (None,)),
AdjacencyLists(2, None, None, None, (1,), (identity,)),
)

# create update sequence
sequence1 = 1, continuous_node_prediction
sequence2 = 2, continuous_node_prediction
Expand Down Expand Up @@ -175,14 +169,6 @@ def identity(x):
input_data=(data, time_steps, observed),
)

#repeat for non-linear network
new_attributes_nonlinear, _ = beliefs_propagation(
structure=(inputs, edges_nonlinear),
attributes=attributes,
update_sequence=update_sequence,
input_data=(data, time_steps, observed),
)

for idx, val in zip(["time_step", "values"], [1.0, 0.2]):
assert jnp.isclose(new_attributes[0][idx], val)
for idx, val in zip(
Expand All @@ -196,26 +182,15 @@ def identity(x):
):
assert jnp.isclose(new_attributes[2][idx], val)

for idx, val in zip(["time_step", "values"], [1.0, 0.2]):
assert jnp.isclose(new_attributes_nonlinear[0][idx], val)
for idx, val in zip(
["precision", "expected_precision", "mean", "expected_mean"],
[10000.881, 0.880797, 0.20007047, 1.0],
):
assert jnp.isclose(new_attributes_nonlinear[1][idx], val)
for idx, val in zip(
["precision", "expected_precision", "mean", "expected_mean"],
[0.9794834, 0.95257413, 0.97345114, 1.0],
):
assert jnp.isclose(new_attributes_nonlinear[2][idx], val)

def test_continuous_input_update_nonlinear(nodes_attributes):
###############################################
# one value parent with one volatility parent #
###############################################
attributes = nodes_attributes

def identity(x):
return(x)
return x

# define a non linear network behaving like the linear one
edges_nonlinear = (
Expand Down Expand Up @@ -278,16 +253,6 @@ def test_scan_loop(nodes_attributes):
AdjacencyLists(2, None, (2,), (0,), None, (None,)),
AdjacencyLists(2, None, None, None, (1,), (None,)),
)

#testing it with a non-linear coupling
def identity(x):
return(x)

edges_nonlinear = (
AdjacencyLists(0, (1,), None, None, None, (None,)),
AdjacencyLists(2, None, (2,), (0,), None, (None,)),
AdjacencyLists(2, None, None, None, (1,), (identity,)),
)

# create update sequence
sequence1 = 2, continuous_node_prediction
Expand All @@ -314,12 +279,6 @@ def identity(x):
structure=(inputs, edges),
)

scan_fn_nonlinear = Partial(
beliefs_propagation,
update_sequence=update_sequence,
structure=(inputs, edges),
)

# Create the data (value and time steps vectors)
time_steps = jnp.ones((len(timeserie), 1))
observed = jnp.ones((len(timeserie), 1))
Expand All @@ -339,22 +298,6 @@ def identity(x):
):
assert jnp.isclose(last[2][idx], val)

# non linear coupling
# Run the entire for loop
last_nonlinear, _ = scan(scan_fn_nonlinear, attributes,
(timeserie, time_steps, observed))
for idx, val in zip(["time_step", "values"], [1.0, 0.8241]):
assert jnp.isclose(last_nonlinear[0][idx], val)
for idx, val in zip(
["precision", "expected_precision", "mean", "expected_mean"],
[24557.84, 14557.839, 0.8041823, 0.79050046],
):
assert jnp.isclose(last_nonlinear[1][idx], val)
for idx, val in zip(
["precision", "expected_precision", "mean", "expected_mean"],
[1.3334407, 1.3493799, -7.1686087, -7.509615],
):
assert jnp.isclose(last_nonlinear[2][idx], val)

def test_scan_loop_nonlinear(nodes_attributes):
timeserie = load_data("continuous")
Expand All @@ -364,10 +307,10 @@ def test_scan_loop_nonlinear(nodes_attributes):
###############################################
attributes = nodes_attributes

#defining an identity function
# defining an identity function
def identity(x):
return(x)
return x

edges_nonlinear = (
AdjacencyLists(0, (1,), None, None, None, (None,)),
AdjacencyLists(2, None, (2,), (0,), None, (None,)),
Expand Down Expand Up @@ -404,8 +347,9 @@ def identity(x):
observed = jnp.ones((len(timeserie), 1))

# Run the entire for loop
last_nonlinear, _ = scan(scan_fn_nonlinear, attributes,
(timeserie, time_steps, observed))
last_nonlinear, _ = scan(
scan_fn_nonlinear, attributes, (timeserie, time_steps, observed)
)
for idx, val in zip(["time_step", "values"], [1.0, 0.8241]):
assert jnp.isclose(last_nonlinear[0][idx], val)
for idx, val in zip(
Expand All @@ -421,34 +365,31 @@ def identity(x):


def test_coupling_fn_multiple_children():
""" Tests if the coupling function is passed correctly
"""Tests if the coupling function is passed correctly
into the network"""

# creating a simple coupling function
def identity(x):
return x

# I create a network with a node with 2 value children
test_HGF = (
Network()
.add_nodes(kind="continuous-input")
.add_nodes(kind="continuous-input")
.add_nodes(value_children=0, n_nodes=1)
.add_nodes(value_children=[1, 2], n_nodes=1,
coupling_fn=(None,identity)
)
.add_nodes(value_children=[1, 2], n_nodes=1, coupling_fn=(None, identity))
)

# check if the number of coupling fn matches the number of children
coupling_fn_length = []
children_number = []
for node_idx in range(2,len(test_HGF.edges)):
for node_idx in range(2, len(test_HGF.edges)):
coupling_fn_length.append(len(test_HGF.edges[node_idx].coupling_fn))
children_number.append(len(test_HGF.edges[node_idx].value_children))

assert children_number == coupling_fn_length



if __name__ == "__main__":
unittest.main(argv=["first-arg-is-ignored"], exit=False)

0 comments on commit 14f22e6

Please sign in to comment.