[BUG] FlowMatching checkpoints fail to load when residual connections are disabled #491

Closed
@simschaefer

Description

I trained a set of networks with residual connections deactivated in the inference network (FlowMatching):

import bayesflow as bf
import keras

# Inference network with residual connections disabled in the subnet
inference_net = bf.networks.FlowMatching(subnet_kwargs={"residual": False})

summary_net = bf.networks.SetTransformer(dropout=0.01, num_seeds=2, summary_dim=32, embed_dim=(128, 128))

# simulator, adapter, network_dir, network_name, model_specs and val_data are defined elsewhere
workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    initial_learning_rate=0.0005,
    inference_network=inference_net,
    summary_network=summary_net,
    checkpoint_filepath=network_dir,
    checkpoint_name=network_name,
    inference_variables=model_specs['simulation_settings']['param_names']
)

history = workflow.fit_online(epochs=200, num_batches_per_epoch=1000, batch_size=16, validation_data=val_data)

After training, inference works fine, but if I try to load the checkpoints:
approximator = keras.saving.load_model(network_dir)
I get this error:


---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[70], line 12
      1 workflow = bf.BasicWorkflow(
      2     simulator=simulator,
      3     adapter=adapter,
   (...)      9     inference_variables=model_specs['simulation_settings']['param_names']
     10 )
---> 12 approximator = keras.saving.load_model(network_dir)

File ~/miniforge3/envs/bf_new/lib/python3.11/site-packages/keras/src/saving/saving_api.py:189, in load_model(filepath, custom_objects, compile, safe_mode)
    186         is_keras_zip = True
    188 if is_keras_zip or is_keras_dir or is_hf:
--> 189     return saving_lib.load_model(
    190         filepath,
    191         custom_objects=custom_objects,
    192         compile=compile,
    193         safe_mode=safe_mode,
    194     )
    195 if str(filepath).endswith((".h5", ".hdf5")):
    196     return legacy_h5_format.load_model_from_hdf5(
    197         filepath, custom_objects=custom_objects, compile=compile
    198     )

File ~/miniforge3/envs/bf_new/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:367, in load_model(filepath, custom_objects, compile, safe_mode)
    362     raise ValueError(
    363         "Invalid filename: expected a `.keras` extension. "
    364         f"Received: filepath={filepath}"
    365     )
    366 with open(filepath, "rb") as f:
--> 367     return _load_model_from_fileobj(
    368         f, custom_objects, compile, safe_mode
    369     )

File ~/miniforge3/envs/bf_new/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:509, in _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode)
    506             extract_dir.cleanup()
    508     if failed_saveables:
--> 509         _raise_loading_failure(error_msgs)
    510 return model

File ~/miniforge3/envs/bf_new/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:631, in _raise_loading_failure(error_msgs, warn_only)
    629     warnings.warn(msg)
    630 else:
--> 631     raise ValueError(msg)

ValueError: A total of 5 objects could not be loaded. Example error message for object <Dense name=dense_343, built=False>:

Layer 'dense_343' was never built and thus it doesn't have any variables. However the weights file lists 2 variables for this layer.
In most cases, this error indicates that either:

1. The layer is owned by a parent layer that implements a `build()` method, but calling the parent's `build()` method did NOT create the state of the child layer 'dense_343'. A `build()` method must create ALL state for the layer, including the state of any children layers.

2. You need to implement the `def build_from_config(self, config)` method on layer 'dense_343', to specify how to rebuild it during loading. In this case, you might also want to implement the method that generates the build config at saving time, `def get_build_config(self)`. The method `build_from_config()` is meant to create the state of the layer (i.e. its variables) upon deserialization.

List of objects that could not be loaded:
[<Dense name=dense_343, built=False>, <Dense name=dense_344, built=False>, <Dense name=dense_345, built=False>, <Dense name=dense_346, built=False>, <Dense name=dense_347, built=False>]
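
For triage, the names of the unbuilt layers can be pulled out of the error text and compared across runs with and without residual connections. The following is a minimal sketch, not part of Keras or BayesFlow: the helper name `failed_layer_names` is hypothetical, and it assumes the message keeps the "List of objects that could not be loaded:" format shown above.

```python
import re


def failed_layer_names(message: str) -> list[str]:
    """Extract layer names from a Keras loading-failure message.

    Hypothetical helper: assumes entries look like
    '<Dense name=dense_343, built=False>' and follow the marker line.
    """
    marker = "List of objects that could not be loaded:"
    _, found, tail = message.partition(marker)
    if not found:
        # The message has a different shape; report nothing rather than guess.
        return []
    return re.findall(r"<\w+ name=([^,>]+), built=False>", tail)


# Shortened copy of the message from the traceback above
message = (
    "A total of 5 objects could not be loaded.\n"
    "List of objects that could not be loaded:\n"
    "[<Dense name=dense_343, built=False>, <Dense name=dense_344, built=False>]"
)
names = failed_layer_names(message)
print(names)
```

In practice the message would come from wrapping the `keras.saving.load_model(...)` call in `try`/`except ValueError as err` and passing `str(err)` to the helper.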


The error appears to be caused by disabling the residual connections: when they are kept enabled during training, the checkpoints load without any issues.

Environment
• OS: Debian GNU/Linux 12 (bookworm)
• Keras: 3.9.2
• Python Version: 3.11.11
• Backend: Torch 2.6.0
• BayesFlow Version: 2.0.3
