
Fix: optimizer was not used in workflow with multiple fits #549


Merged: vpratz merged 3 commits into dev from fix-workflow-optimizer on Aug 5, 2025

Conversation

vpratz
Collaborator

@vpratz vpratz commented Aug 4, 2025

For the optimizer to be used, the approximator.compile function has to be called. This did not happen for repeated calls to fit_..., even with keep_optimizer set to False. I adapted the setup_optimizer function to match the description in its docstring and made the compilation conditional on its return value, which indicates whether a new optimizer was configured.
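
A minimal sketch of that control flow (simplified names and signatures, not the actual BasicWorkflow code): setup_optimizer reports whether a new optimizer was configured, and the fit method only calls compile when it did.

import keras

class WorkflowSketch:
    """Toy stand-in for the workflow, illustrating only the fixed control flow."""

    def __init__(self, approximator: keras.Model):
        self.approximator = approximator
        self.optimizer = None

    def setup_optimizer(self, epochs: int, num_batches: int) -> bool:
        """Configure an optimizer if none exists; return True only if a new one was created."""
        if self.optimizer is not None:
            return False  # keep the existing optimizer, nothing to recompile
        self.optimizer = keras.optimizers.Adam(learning_rate=1e-3)
        return True

    def fit_online(self, *, epochs: int, num_batches_per_epoch: int, **fit_kwargs):
        # Before the fix, compile() could be skipped here even though a new
        # optimizer had just been created, so that optimizer was never used.
        if self.setup_optimizer(epochs, num_batches_per_epoch):
            self.approximator.compile(optimizer=self.optimizer)
        # ... the real method then runs the training loop ...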

To reproduce, you can run the following code. In the version without the fix, the first and second prints give the same value and the third a different one. With the fix, all three produce different values, as expected.

import bayesflow as bf
import keras

epochs = 2
num_batches_per_epoch = 10
batch_size = 4

simulator = bf.simulators.TwoMoons()
validation_data = simulator.sample(100)

workflow = bf.BasicWorkflow(
    simulator=simulator,
    standardize="all",
    inference_conditions="observables",
    inference_variables="parameters",
)

workflow.fit_online(
    num_batches_per_epoch=num_batches_per_epoch, epochs=epochs, validation_data=validation_data, batch_size=batch_size
)
print("first", workflow.approximator.trainable_weights[0][0])
workflow.fit_online(
    num_batches_per_epoch=num_batches_per_epoch, epochs=epochs, validation_data=validation_data, batch_size=batch_size
)
print("second", workflow.approximator.trainable_weights[0][0])
workflow.approximator.compile(optimizer=keras.optimizers.Adam())
workflow.fit_online(
    num_batches_per_epoch=num_batches_per_epoch, epochs=epochs, validation_data=validation_data, batch_size=batch_size
)
print("third", workflow.approximator.trainable_weights[0][0])

@vpratz vpratz requested review from stefanradev93 and LarsKue August 4, 2025 11:53
@vpratz vpratz added the fix (Pull request that fixes a bug) label Aug 4, 2025
@stefanradev93
Contributor

stefanradev93 commented Aug 4, 2025

The serialization tests seem to be suddenly failing for the multimodal network? Apart from that, the fix looks good to me.

@vpratz
Collaborator Author

vpratz commented Aug 4, 2025

Thanks for taking a look! I will check if I can reproduce the issue locally with updated dependencies and try to fix the failure.

@vpratz
Collaborator Author

vpratz commented Aug 4, 2025

The error concerns all summary networks and all three backends. I could only reproduce it after updating my dependencies; my suspicion is that the handling of DType in Keras might have changed, but I have not taken steps to verify this yet. I will take a look...

The FusionNetwork just happens to be the first and JAX the fastest in the tests...

@vpratz
Collaborator Author

vpratz commented Aug 4, 2025

The issue does not exist with Keras 3.10 and appears when upgrading to Keras 3.11.

@vpratz
Collaborator Author

vpratz commented Aug 4, 2025

This commit introduces the regression: keras-team/keras@24f104e, more specifically the code that was removed in this block and not moved to the Layer class.

@vpratz
Collaborator Author

vpratz commented Aug 4, 2025

Reintroducing the removed code into Keras fixes the issue. How do we want to proceed? Do we want to try to get a fix included into Keras? With the changes in #500 we should also be able to work around the issue... We might also want to set keras <= 3.10 until the issue is resolved.

diff --git a/keras/src/ops/operation.py b/keras/src/ops/operation.py
index edda5a01d..355492a83 100644
--- a/keras/src/ops/operation.py
+++ b/keras/src/ops/operation.py
@@ -128,6 +128,17 @@ class Operation(KerasSaveable):
         arg_names = inspect.getfullargspec(cls.__init__).args
         kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))
 
+        # Explicitly serialize `dtype` to support auto_config
+        dtype = kwargs.get("dtype", None)
+        if dtype is not None and isinstance(dtype, dtype_policies.DTypePolicy):
+            # For backward compatibility, we use a str (`name`) for
+            # `DTypePolicy`
+            if dtype.quantization_mode is None:
+                kwargs["dtype"] = dtype.name
+            # Otherwise, use `dtype_policies.serialize`
+            else:
+                kwargs["dtype"] = dtype_policies.serialize(dtype)
+
         # For safety, we only rely on auto-configs for a small set of
         # serializable types.
         supported_types = (str, int, float, bool, type(None))
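
To illustrate why that removed block matters (my own example, not code from the PR or from Keras): the auto-config path only keeps constructor kwargs of a few simple types, so a DTypePolicy instance passed as dtype cannot be recorded, whereas its name string can.

import keras

# Auto-config only keeps constructor kwargs of these simple types
# (see `supported_types` in the diff above).
supported_types = (str, int, float, bool, type(None))

policy = keras.dtype_policies.DTypePolicy("float32")
print(isinstance(policy, supported_types))       # False: a policy object cannot be auto-configured
print(isinstance(policy.name, supported_types))  # True: the removed code stored the name string instead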

@stefanradev93
Contributor

Wait, does this break Keras serialization in general, or is it an artifact of using our custom monkey patch? As far as I can see, the commit simply removes the dtype from the constructor and simplifies the serialization of dtypes.

The extra call leads to the DTypePolicy being deserialized. It is then passed as an object and cannot be handled by autoconf, leading to the error discussed in #549.
@vpratz
Collaborator Author

vpratz commented Aug 5, 2025

Good call, @stefanradev93. The problem arises here because we deserialize the config before passing it to the constructor. This instantiates the DTypePolicy as an object. As the autoconf mechanism doesn't handle objects of that type, this leads to the error.

Removing the deserialize call from the SummaryNetwork seems to fix the problem.
However, this changes the default from_config for summary networks, which could cause problems downstream for anyone who overrides get_config but not from_config. To counter this, we could check whether get_config was overridden (meaning auto-config is not in use) and only apply deserialize to the config in that case, like this:

    @classmethod
    def from_config(cls, config, custom_objects=None):
        # If the subclass did not override get_config, the auto-generated config
        # contains only plain values and can be passed straight to the constructor.
        if hasattr(cls.get_config, "_is_default") and cls.get_config._is_default:
            return cls(**config)
        # Otherwise, deserialize nested objects before calling the constructor.
        return cls(**deserialize(config, custom_objects=custom_objects))
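
For context, a toy illustration of the _is_default convention this snippet relies on (my own example, not Keras or BayesFlow code): default get_config implementations carry the marker, overrides do not.

def default(method):
    # Toy version of the decorator Keras uses to tag default implementations.
    method._is_default = True
    return method

class Base:
    @default
    def get_config(self):
        return {}

class Overriding(Base):
    def get_config(self):  # custom config -> not tagged
        return {"width": 64}

print(getattr(Base.get_config, "_is_default", False))        # True: pass config straight to the constructor
print(getattr(Overriding.get_config, "_is_default", False))  # False: deserialize the config first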


codecov bot commented Aug 5, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.

Files with missing lines | Coverage Δ
bayesflow/networks/summary_network.py | 91.89% <100.00%> (-2.40%) ⬇️
bayesflow/workflows/basic_workflow.py | 69.94% <100.00%> (+0.17%) ⬆️

@vpratz
Collaborator Author

vpratz commented Aug 5, 2025

Merging this now to make the state of the dev branch fully functional again. If I interpret it correctly, not many users will encounter the issue, as it should only arise if you load a model and then save it again. Nevertheless, I think it would be good to do a bugfix release soon.
What do you think, @stefanradev93 and @LarsKue?

@vpratz vpratz merged commit 952862c into dev Aug 5, 2025
9 checks passed
@vpratz vpratz deleted the fix-workflow-optimizer branch August 5, 2025 08:14