Fix: move additional metrics from approximator to networks #500

Open · vpratz wants to merge 16 commits into dev from fix-additional-metrics

Conversation

@vpratz (Collaborator) commented May 30, 2025

Supplying the additional metrics for inference and summary networks via the approximator's compile method caused problems during deserialization (#497). This can be resolved nicely by moving the metrics directly into the networks' constructors, analogous to how Keras normally handles custom metrics in layers.

As summary networks and inference networks inherit from the respective base classes, this change only requires minor adaptations. Calls to layer_kwargs are now only made in classes that directly inherit from keras.Layer and have been moved into InferenceNetwork and SummaryNetwork.

Fixes #497.
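For illustration, a minimal sketch of the constructor-based pattern on a plain keras.layers.Layer; the class name DemoNetwork and the exact keyword are placeholders chosen for this example, not the actual bayesflow signatures:

import keras


class DemoNetwork(keras.layers.Layer):
    """Toy stand-in for a summary/inference network that owns its metrics."""

    def __init__(self, metrics=None, **kwargs):
        super().__init__(**kwargs)
        # The metrics belong to the layer itself, so they are covered by the
        # layer's own config instead of the approximator's compile() call.
        self.custom_metrics = list(metrics) if metrics else []

    def call(self, x):
        return x


# Metrics are supplied at construction time rather than via compile().
net = DemoNetwork(metrics=[keras.metrics.RootMeanSquaredError()])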

@vpratz force-pushed the fix-additional-metrics branch from 41e2967 to 8d296e9 on May 30, 2025 09:39
@vpratz changed the title from "Fix: move additional metrics from approximator to networks" to "[WIP] Fix: move additional metrics from approximator to networks" on May 30, 2025
@vpratz force-pushed the fix-additional-metrics branch 2 times, most recently from 023db06 to 4c7a5a2 on May 30, 2025 13:22
This change makes the auto-config more capable for our purposes by allowing
any serializable value, not only the basic types supported by the stock
auto-config. We have to check whether this brings any footguns/downsides,
or whether it is fine for our setting.

It also replaces the Keras serialization functions with our custom
serialization functions.
@vpratz force-pushed the fix-additional-metrics branch from 4c7a5a2 to 51dff0d on May 30, 2025 13:24
@vpratz (Collaborator, Author) commented May 30, 2025

@LarsKue @stefanradev93 @han-ol I encountered the problem that, in order to pass metrics to the networks in their constructors, we would have to specify all get_config functions manually, as the metrics are not basic types.
I have now instead adapted the auto-config capabilities of keras.Layer to accept any object we can (de)serialize, and to use our serialization functions. This should make the auto-config functionality much more flexible. Take a look at the BaseLayer class for the implementation. It seems to me that this doesn't cause problems for now (hoping that all tests will pass).
Do you see any downsides with this approach? Or do you know something about the motivation of Keras to limit this functionality to basic types?
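For context, without such an adaptation every network whose constructor takes non-basic arguments would need hand-written (de)serialization along these lines; this is only an illustrative sketch using keras.saving.serialize_keras_object / deserialize_keras_object, not code from this PR:

import keras
from keras.saving import deserialize_keras_object, serialize_keras_object


class ManualConfigNetwork(keras.layers.Layer):
    """Sketch of the boilerplate the stock auto-config cannot generate."""

    def __init__(self, custom_metrics=None, **kwargs):
        super().__init__(**kwargs)
        self.custom_metrics = custom_metrics or []

    def get_config(self):
        config = super().get_config()
        # Non-basic objects have to be serialized by hand ...
        config["custom_metrics"] = [serialize_keras_object(m) for m in self.custom_metrics]
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config)
        # ... and deserialized by hand on the way back in.
        metrics_config = config.pop("custom_metrics", [])
        config["custom_metrics"] = [deserialize_keras_object(m) for m in metrics_config]
        return cls(**config)

Repeating this in every network class is the redundant, error-prone work that the adapted auto-config is meant to avoid.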

@vpratz changed the title from "[WIP] Fix: move additional metrics from approximator to networks" to "Fix: move additional metrics from approximator to networks" on May 30, 2025
@stefanradev93 (Contributor) commented:

I would let @LarsKue chip in on this, as I am a bit concerned about having more and more "custom" variants of basic keras containers (e.g., Sequential, Layer, ...).

@vpratz (Collaborator, Author) commented Jun 1, 2025

Thanks for the comment, I can understand this: relying too much on hacking/modifying Keras internals has the downside that our code might become more fragile with respect to changes in Keras. Still, I think the ability to pass non-basic types to our inference and summary network base classes would be nice to have; the behavior in this PR is one example of this.
The options that I see are:

  1. accept that we cannot pass non-basic arguments to the base classes, making the implementation of some features more cumbersome (or impossible)
  2. forgo the auto-config offered by keras.Layer and explicitly implement get_config in all our networks, which works but requires some redundant work and can be error-prone
  3. a mechanism like the one implemented here, which offers more flexibility, but relies on Keras not making too drastic changes to the keras.Layer internals, which we cannot control

@LarsKue (Contributor) commented Jun 5, 2025

I would be on board if it were a lightweight wrapper, but it seems that there is a lot of code copied from Keras, which we should avoid in general, imo.

@vpratz Can we solve this through monkey-patching somehow, like with the serialize and deserialize functions?

@stefanradev93 (Contributor) commented:

If this is not breaking any downstream task and only adds to the existing functionality, we may consider opening a PR to Keras, or asking them to implement it (or explain why they decided not to)?

@vpratz (Collaborator, Author) commented Jun 6, 2025

@LarsKue Thanks for taking a look. I'll explore alternative ways to achieve this behavior.

@vpratz (Collaborator, Author) commented Jun 6, 2025

@LarsKue Thanks a lot, this was a really good pointer. I have now implemented a wrapper around the constructor that is applied inside the serializable decorator. In addition, the default from_config is replaced with a variant that deserializes the config again.

There is still some copied code, but I think it is now limited to the part that is required for the feature to behave as similarly to the Keras implementation as possible, apart from the desired changes (we might not need the part regarding the dtype, but I'm not sure and would leave it in for now).

With those changes, we could remove the from_config method from most classes. Is this a change I should make?

@stefanradev93 @LarsKue Please take another look and let me know what you think (the most relevant part is in bayesflow/utils/serialization.py).
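To make the approach easier to picture, here is a heavily simplified sketch of a decorator of this kind; it is not the actual implementation in bayesflow/utils/serialization.py and glosses over positional arguments, nesting, registration, and existing get_config/from_config overrides:

import functools

from keras.saving import deserialize_keras_object, serialize_keras_object


def serializable_sketch(cls):
    """Illustrative decorator: wrap __init__ to record serialized kwargs."""
    original_init = cls.__init__

    @functools.wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        # Store the constructor kwargs in serialized form so that any
        # serializable object (e.g., a metric) can appear in the config.
        self._recorded_config = {k: serialize_keras_object(v) for k, v in kwargs.items()}

    def get_config(self):
        return dict(self._recorded_config)

    def from_config(cls, config):
        # Deserialize the stored config before handing it to the constructor.
        return cls(**{k: deserialize_keras_object(v) for k, v in config.items()})

    cls.__init__ = wrapped_init
    cls.get_config = get_config
    cls.from_config = classmethod(from_config)
    return cls

The actual decorator in the PR has to cover more cases than this sketch.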

@LarsKue (Contributor) commented Jun 6, 2025

Thanks, this looks much better. Since this is a sensitive change, I think we should extensively test it before rolling it out. Otherwise, green light from my side!

@vpratz (Collaborator, Author) commented Jun 8, 2025

I forgot to make the same changes in the ModelComparisonApproximator; this is now fixed. I have also changed the output that is supplied to its classifier metrics from predictions to probabilities. This way, it is compatible with the metrics offered by Keras.
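As a small standalone illustration of that compatibility (not code from the PR), Keras' built-in classification metrics operate on class probabilities:

import numpy as np
import keras

# CategoricalAccuracy compares the argmax of one-hot labels and probabilities.
metric = keras.metrics.CategoricalAccuracy()
y_true = np.array([[0.0, 1.0], [1.0, 0.0]])   # one-hot labels
y_prob = np.array([[0.1, 0.9], [0.8, 0.2]])   # predicted class probabilities
metric.update_state(y_true, y_prob)
print(float(metric.result()))  # 1.0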

@paul-buerkner (Contributor) commented:

@vpratz and @LarsKue, what is the status of this PR? Is this ready to be merged?

@vpratz (Collaborator, Author) commented Jun 30, 2025

@paul-buerkner I would like to get a review from @stefanradev93 before we merge this, as the changes affect multiple parts of the code. Also, I would wait until we have decided how we proceed on #525...

@vpratz requested a review from stefanradev93 on June 30, 2025 14:20
@vpratz (Collaborator, Author) commented Jul 1, 2025

I see some weird behavior due to the inheritance; I will have to give this another look...

@vpratz marked this pull request as draft on July 1, 2025 12:16
vpratz added 3 commits July 1, 2025 12:51
- get_config has to be manually specified in the base classes, so that
  the config is stored even when a subclass overrides get_config
- to preserve the auto-config behavior, we have to use the
  `python_utils.default` decorator from Keras, which marks them as default
  methods. This allows detecting whether a subclass has overridden them.
  This is the same mechanism that Keras uses
- moved setting the `custom_metrics` parameter to after the
  `super().__init__` calls, as the tracking is managed in `__setattr__`
- extended some tests to use metrics
@vpratz (Collaborator, Author) commented Jul 1, 2025

OK, I think it works. The general mechanism for getting the config of a subclass MyLayer of keras.Layer is the following. For an instance layer = MyLayer(), calling layer.get_config() is resolved in the "standard" Python fashion, so there are two cases:

a) The subclass overrides get_config

b) The subclass does not override get_config, which triggers the auto-config behavior

default and is_default are two simple functions that set/read the _is_default flag on a method:

def default(method):
    """Decorates a method to detect overrides in subclasses."""
    method._is_default = True
    return method


def is_default(method):
    """Check if a method is decorated with the `default` wrapper."""
    return getattr(method, "_is_default", False)

To adapt this setup for us, we want our base classes (InferenceNetwork, SummaryNetwork, PointInferenceNetwork) to do the same thing as keras.Layer: provide a few entries for the config, and have the auto-config used if the subclass does not override get_config.

The easiest way to achieve this is a get_config function decorated with @python_utils.default. Note that in the auto-config case, the subclass-config is provided by the last get_config call, so we cannot filter the base_config dictionaries in this setup.
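A condensed sketch of what this can look like for a base class, assuming keras.Layer's get_config behaves as described above; names are simplified, and python_utils is the Keras-internal module that provides the default/is_default helpers shown earlier:

import keras
from keras.src.utils import python_utils  # Keras-internal: default / is_default


class InferenceNetworkSketch(keras.layers.Layer):
    """Simplified stand-in for a bayesflow base class."""

    def __init__(self, custom_metrics=None, **kwargs):
        super().__init__(**kwargs)
        self.custom_metrics = custom_metrics or []

    @python_utils.default
    def get_config(self):
        base_config = super().get_config()
        if python_utils.is_default(self.get_config):
            # No subclass override: the auto-config (adapted in this PR to
            # handle serializable objects) already provides the full config,
            # so just pass it through.
            return base_config
        # A subclass overrode get_config and called super(): add the entries
        # this base class is responsible for (the PR additionally runs
        # non-basic entries through its custom serialization helpers).
        return base_config | {"custom_metrics": self.custom_metrics}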

@vpratz marked this pull request as ready for review on July 1, 2025 14:22
@vpratz (Collaborator, Author) commented Jul 1, 2025

@stefanradev93 @LarsKue I have fixed the problem I found (see above) and cleaned up the code a bit. Could you please take another look?
