
Add a custom Sequential network to avoid issues with building and serialization in keras #493


Merged 5 commits on May 26, 2025
1 change: 1 addition & 0 deletions bayesflow/networks/__init__.py
@@ -12,6 +12,7 @@
 from .point_inference_network import PointInferenceNetwork
 from .mlp import MLP
 from .fusion_network import FusionNetwork
+from .sequential import Sequential
 from .summary_network import SummaryNetwork
 from .time_series_network import TimeSeriesNetwork
 from .transformers import SetTransformer, TimeSeriesTransformer, FusionTransformer
42 changes: 11 additions & 31 deletions bayesflow/networks/mlp/mlp.py
@@ -3,14 +3,15 @@

 import keras

-from bayesflow.utils import sequential_kwargs
-from bayesflow.utils.serialization import deserialize, serializable, serialize
+from bayesflow.utils import layer_kwargs
+from bayesflow.utils.serialization import serializable, serialize

+from ..sequential import Sequential
 from ..residual import Residual


 @serializable("bayesflow.networks")
-class MLP(keras.Sequential):
+class MLP(Sequential):
     """
     Implements a simple configurable MLP with optional residual connections and dropout.
@@ -67,40 +68,19 @@ def __init__(
         self.norm = norm
         self.spectral_normalization = spectral_normalization

-        layers = []
+        blocks = []

         for width in widths:
-            layer = self._make_layer(
+            block = self._make_block(
                 width, activation, kernel_initializer, residual, dropout, norm, spectral_normalization
             )
-            layers.append(layer)
-
-        super().__init__(layers, **sequential_kwargs(kwargs))
-
-    def build(self, input_shape=None):
-        if self.built:
-            # building when the network is already built can cause issues with serialization
-            # see https://github.com/keras-team/keras/issues/21147
-            return
-
-        # we only care about the last dimension, and using ... signifies to keras.Sequential
-        # that any number of batch dimensions is valid (which is what we want for all sublayers)
-        # we also have to avoid calling super().build() because this causes
-        # shape errors when building on non-sets but doing inference on sets
-        # this is a work-around for https://github.com/keras-team/keras/issues/21158
-        input_shape = (..., input_shape[-1])
-
-        for layer in self._layers:
-            layer.build(input_shape)
-            input_shape = layer.compute_output_shape(input_shape)
-
-    @classmethod
-    def from_config(cls, config, custom_objects=None):
-        return cls(**deserialize(config, custom_objects=custom_objects))
+            blocks.append(block)
+
+        super().__init__(*blocks, **kwargs)

     def get_config(self):
         base_config = super().get_config()
-        base_config = sequential_kwargs(base_config)
+        base_config = layer_kwargs(base_config)

         config = {
             "widths": self.widths,
@@ -115,7 +95,7 @@ def get_config(self):
         return base_config | serialize(config)

     @staticmethod
-    def _make_layer(width, activation, kernel_initializer, residual, dropout, norm, spectral_normalization):
+    def _make_block(width, activation, kernel_initializer, residual, dropout, norm, spectral_normalization):
         layers = []

         dense = keras.layers.Dense(width, kernel_initializer=kernel_initializer)
@@ -148,4 +128,4 @@ def _make_layer(width, activation, kernel_initializer, residual, dropout, norm, spectral_normalization):
         if residual:
             return Residual(*layers)

-        return keras.Sequential(layers)
+        return Sequential(layers)
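
As a minimal illustration of what this refactor enables (not part of the diff; the widths and shapes below are assumptions), an MLP built on flat inputs should also run on set-valued inputs, because the custom Sequential never pins the batch dimensions:

import keras
from bayesflow.networks import MLP

mlp = MLP([64, 64])
mlp.build((..., 8))  # `...` stands for any number of leading batch dimensions

flat = keras.ops.ones((2, 8))     # (batch, features)
sets = keras.ops.ones((2, 5, 8))  # (batch, set_size, features)
assert mlp(flat).shape == (2, 64)
assert mlp(sets).shape == (2, 5, 64)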
4 changes: 3 additions & 1 deletion bayesflow/networks/residual/residual.py
@@ -6,9 +6,11 @@
 from bayesflow.utils import sequential_kwargs
 from bayesflow.utils.serialization import deserialize, serializable, serialize

+from ..sequential import Sequential
+

 @serializable("bayesflow.networks")
-class Residual(keras.Sequential):
+class Residual(Sequential):
     def __init__(self, *layers: keras.Layer, **kwargs):
         if len(layers) == 1 and isinstance(layers[0], Sequence):
             layers = layers[0]
1 change: 1 addition & 0 deletions bayesflow/networks/sequential/__init__.py
@@ -0,0 +1 @@
+from .sequential import Sequential
88 changes: 88 additions & 0 deletions bayesflow/networks/sequential/sequential.py
@@ -0,0 +1,88 @@
+from collections.abc import Sequence
+import keras
+
+from bayesflow.utils import layer_kwargs
+from bayesflow.utils.serialization import deserialize, serializable, serialize
+
+
+@serializable("bayesflow.networks")
+class Sequential(keras.Layer):
+    """
+    A custom sequential container for managing a sequence of Keras layers.
+
+    This class extends `keras.Layer` and provides functionality for building,
+    calling, and serializing a sequence of layers. Unlike `keras.Sequential`,
+    this implementation does not eagerly check input shapes, so it is
+    compatible with both single inputs and sets.
+
+    Parameters
+    ----------
+    layers : keras.Layer | Sequence[keras.Layer]
+        A sequence of Keras layers to be managed by this container.
+        Can be passed unpacked or as a single sequence.
+    **kwargs :
+        Additional keyword arguments passed to the base `keras.Layer` class.
+    """
+
+    def __init__(self, *layers: keras.Layer | Sequence[keras.Layer], **kwargs):
+        super().__init__(**layer_kwargs(kwargs))
+        if len(layers) == 1 and isinstance(layers[0], Sequence):
+            layers = layers[0]
+
+        self._layers = layers
+
+    def build(self, input_shape):
+        if self.built:
+            # building when the network is already built can cause issues with serialization
+            # see https://github.com/keras-team/keras/issues/21147
+            return
+
+        for layer in self._layers:
+            layer.build(input_shape)
+            input_shape = layer.compute_output_shape(input_shape)
+
+    def call(self, inputs, training=None, mask=None):
+        x = inputs
+        for layer in self._layers:
+            kwargs = self._make_kwargs_for_layer(layer, training, mask)
+            x = layer(x, **kwargs)
+        return x
+
+    def compute_output_shape(self, input_shape):
+        for layer in self._layers:
+            input_shape = layer.compute_output_shape(input_shape)
+
+        return input_shape
+
+    def get_config(self):
+        base_config = super().get_config()
+        base_config = layer_kwargs(base_config)
+
+        config = {
+            "layers": [serialize(layer) for layer in self._layers],
+        }
+
+        return base_config | config
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        return cls(**deserialize(config, custom_objects=custom_objects))
+
+    @property
+    def layers(self):
+        return self._layers
+
+    @staticmethod
+    def _make_kwargs_for_layer(layer, training, mask):
+        # forward `training` and `mask` only to sublayers whose call signature accepts them
+        kwargs = {}
+        if layer._call_has_mask_arg:
+            kwargs["mask"] = mask
+        if layer._call_has_training_arg and training is not None:
+            kwargs["training"] = training
+        return kwargs
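
As a usage sketch for the new container (illustrative only; the sublayers and shapes are assumptions, not taken from the PR), note how call() forwards training and mask only to sublayers whose call signature accepts them:

import keras
from bayesflow.networks import Sequential

net = Sequential(
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(16),
)
net.build((..., 8))  # no batch shape is baked in at build time

x = keras.ops.ones((4, 8))
# `training` reaches only the Dropout layer; Dense.call takes no such argument
y_train = net(x, training=True)   # dropout active
y_infer = net(x, training=False)  # dropout disabled
print(y_train.shape, y_infer.shape)  # (4, 16) (4, 16)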
19 changes: 17 additions & 2 deletions tests/test_networks/test_mlp/conftest.py
@@ -3,9 +3,24 @@
 from bayesflow.networks import MLP


+@pytest.fixture(params=[None, 0.0, 0.1])
+def dropout(request):
+    return request.param
+
+
+@pytest.fixture(params=[None, "batch"])
+def norm(request):
+    return request.param
+
+
+@pytest.fixture(params=[False, True])
+def residual(request):
+    return request.param
+
+
 @pytest.fixture()
-def mlp():
-    return MLP([64, 64])
+def mlp(dropout, norm, residual):
+    return MLP([64, 64], dropout=dropout, norm=norm, residual=residual)


 @pytest.fixture()
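
For context on the fixtures above: pytest expands parametrized fixtures into a cartesian product, so every test that requests mlp now runs 3 (dropout) * 2 (norm) * 2 (residual) = 12 configurations. A minimal self-contained sketch of that mechanism (the names are illustrative, not from the repository):

import pytest

@pytest.fixture(params=[None, 0.1])
def dropout(request):
    return request.param

@pytest.fixture(params=[False, True])
def residual(request):
    return request.param

def test_product(dropout, residual):
    # pytest generates 2 * 2 = 4 test cases, one per parameter combination
    assert dropout in (None, 0.1)
    assert residual in (False, True)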
4 changes: 2 additions & 2 deletions tests/test_networks/test_mlp/test_mlp.py
@@ -2,7 +2,7 @@

 from bayesflow.utils.serialization import deserialize, serialize

-from ...utils import assert_models_equal
+from ...utils import assert_layers_equal


 def test_serialize_deserialize(mlp, build_shapes):
@@ -21,4 +21,4 @@ def test_save_and_load(tmp_path, mlp, build_shapes):
     keras.saving.save_model(mlp, tmp_path / "model.keras")
     loaded = keras.saving.load_model(tmp_path / "model.keras")

-    assert_models_equal(mlp, loaded)
+    assert_layers_equal(mlp, loaded)