
Enable use of summary networks with functional API again #434


Merged · 2 commits · Apr 23, 2025
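For context, this change re-enables wrapping a BayesFlow summary network in a Keras functional-API model, exactly as exercised by the new test added at the bottom of the diff. A minimal usage sketch (editor's illustration, not part of the PR), assuming a DeepSet summary network is available as bayesflow.networks.DeepSet and that the data are sets of shape (batch, set_size, num_features):

import keras
import bayesflow as bf

# Hypothetical example data: 32 sets, each with 10 observations of 4 features.
data = keras.ops.ones((32, 10, 4))

summary_network = bf.networks.DeepSet()  # assumption: any BayesFlow summary network works here

# Functional API: the symbolic Input carries a batch dimension of None,
# which is the case the @sanitize_input_shape decorator now handles.
inputs = keras.layers.Input(shape=keras.ops.shape(data)[1:])
outputs = summary_network(inputs)
model = keras.Model(inputs=inputs, outputs=outputs)

summaries = model(data)  # builds the model and produces the learned summary statistics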
2 changes: 2 additions & 0 deletions bayesflow/links/ordered.py
@@ -2,6 +2,7 @@
 from keras.saving import register_keras_serializable as serializable
 
 from bayesflow.utils import layer_kwargs
+from bayesflow.utils.decorators import sanitize_input_shape
 
 
 @serializable(package="links.ordered")
@@ -49,5 +50,6 @@ def call(self, inputs):
         x = keras.ops.concatenate([below, anchor_input, above], self.axis)
         return x
 
+    @sanitize_input_shape
     def compute_output_shape(self, input_shape):
         return input_shape
1 change: 1 addition & 0 deletions bayesflow/networks/summary_network.py
@@ -21,6 +21,7 @@ def build(self, input_shape):
         if self.base_distribution is not None:
             self.base_distribution.build(keras.ops.shape(z))
 
+    @sanitize_input_shape
     def compute_output_shape(self, input_shape):
         return keras.ops.shape(self.call(keras.ops.zeros(input_shape)))
 
3 changes: 3 additions & 0 deletions bayesflow/networks/transformers/mab.py
@@ -4,6 +4,7 @@
 from bayesflow.networks import MLP
 from bayesflow.types import Tensor
 from bayesflow.utils import layer_kwargs
+from bayesflow.utils.decorators import sanitize_input_shape
 from bayesflow.utils.serialization import serializable
 
 
@@ -122,8 +123,10 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) -> Tensor:
         return out
 
     # noinspection PyMethodOverriding
+    @sanitize_input_shape
     def build(self, seq_x_shape, seq_y_shape):
         self.call(keras.ops.zeros(seq_x_shape), keras.ops.zeros(seq_y_shape))
 
+    @sanitize_input_shape
     def compute_output_shape(self, seq_x_shape, seq_y_shape):
         return keras.ops.shape(self.call(keras.ops.zeros(seq_x_shape), keras.ops.zeros(seq_y_shape)))
2 changes: 2 additions & 0 deletions bayesflow/networks/transformers/pma.py
@@ -4,6 +4,7 @@
 from bayesflow.networks import MLP
 from bayesflow.types import Tensor
 from bayesflow.utils import layer_kwargs
+from bayesflow.utils.decorators import sanitize_input_shape
 from bayesflow.utils.serialization import serializable
 
 from .mab import MultiHeadAttentionBlock
@@ -125,5 +126,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         summaries = self.mab(seed_tiled, set_x_transformed, training=training, **kwargs)
         return ops.reshape(summaries, (ops.shape(summaries)[0], -1))
 
+    @sanitize_input_shape
     def compute_output_shape(self, input_shape):
         return keras.ops.shape(self.call(keras.ops.zeros(input_shape)))
3 changes: 3 additions & 0 deletions bayesflow/networks/transformers/sab.py
@@ -1,6 +1,7 @@
 import keras
 
 from bayesflow.types import Tensor
+from bayesflow.utils.decorators import sanitize_input_shape
 from bayesflow.utils.serialization import serializable
 
 from .mab import MultiHeadAttentionBlock
@@ -16,6 +17,7 @@ class SetAttentionBlock(MultiHeadAttentionBlock):
     """
 
     # noinspection PyMethodOverriding
+    @sanitize_input_shape
     def build(self, input_set_shape):
         self.call(keras.ops.zeros(input_set_shape))
 
@@ -42,5 +44,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         return super().call(input_set, input_set, training=training, **kwargs)
 
     # noinspection PyMethodOverriding
+    @sanitize_input_shape
     def compute_output_shape(self, input_set_shape):
         return keras.ops.shape(self.call(keras.ops.zeros(input_set_shape)))
7 changes: 5 additions & 2 deletions bayesflow/utils/decorators.py
@@ -114,7 +114,7 @@ def callback(x):
 
 
 def sanitize_input_shape(fn: Callable):
-    """Decorator to replace the first dimension in input_shape with a dummy batch size if it is None"""
+    """Decorator to replace the first dimension in ..._shape arguments with a dummy batch size if it is None"""
 
     # The Keras functional API passes input_shape = (None, second_dim, third_dim, ...), which
     # causes problems when constructions like self.call(keras.ops.zeros(input_shape)) are used
@@ -126,5 +126,8 @@ def callback(input_shape: Shape) -> Shape:
             return tuple(input_shape)
         return input_shape
 
-    fn = argument_callback("input_shape", callback)(fn)
+    args = inspect.getfullargspec(fn).args
+    for arg in args:
+        if arg.endswith("_shape"):
+            fn = argument_callback(arg, callback)(fn)
     return fn
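The loop above generalizes the fix from the single input_shape argument to every positional argument ending in _shape, which is what MultiHeadAttentionBlock.build(seq_x_shape, seq_y_shape) needs. A minimal, self-contained sketch of the intended behavior (editor's illustration, not the library's argument_callback machinery, and assuming a dummy batch size of 1):

import inspect
from functools import wraps


def sanitize_input_shape_sketch(fn):
    """Replace a leading None in every *_shape positional argument with a dummy batch size."""
    shape_args = [name for name in inspect.getfullargspec(fn).args if name.endswith("_shape")]

    @wraps(fn)
    def wrapper(*args, **kwargs):
        bound = inspect.signature(fn).bind(*args, **kwargs)
        for name in shape_args:
            shape = bound.arguments.get(name)
            if shape is not None and shape[0] is None:
                # A dummy batch size of 1 is an assumption; any concrete value suffices for shape inference.
                bound.arguments[name] = (1, *tuple(shape)[1:])
        return fn(*bound.args, **bound.kwargs)

    return wrapper


@sanitize_input_shape_sketch
def compute_output_shape(seq_x_shape, seq_y_shape):
    # With concrete batch sizes, constructions like keras.ops.zeros(shape) would now succeed.
    return seq_x_shape, seq_y_shape


print(compute_output_shape((None, 10, 4), (None, 5, 4)))  # ((1, 10, 4), (1, 5, 4))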
22 changes: 22 additions & 0 deletions tests/test_networks/test_summary_networks.py
@@ -25,6 +25,28 @@ def test_build(automatic, summary_network, random_set):
     assert summary_network.variables, "Model has no variables."
 
 
+@pytest.mark.parametrize("automatic", [True, False])
+def test_build_functional_api(automatic, summary_network, random_set):
+    if summary_network is None:
+        pytest.skip(reason="Nothing to do, because there is no summary network.")
+
+    assert summary_network.built is False
+
+    inputs = keras.layers.Input(shape=keras.ops.shape(random_set)[1:])
+    outputs = summary_network(inputs)
+    model = keras.Model(inputs=inputs, outputs=outputs)
+
+    if automatic:
+        model(random_set)
+    else:
+        model.build(keras.ops.shape(random_set))
+
+    assert model.built is True
+
+    # check the model has variables
+    assert summary_network.variables, "Model has no variables."
+
+
 def test_variable_batch_size(summary_network, random_set):
     if summary_network is None:
         pytest.skip(reason="Nothing to do, because there is no summary network.")