implement compile_from_config and get_compile_config #442

Merged
3 commits merged on Apr 25, 2025
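Background for the diff below: in Keras 3, `Model.save()` calls `get_compile_config()` to persist compile-time state (optimizer, loss, metrics), and `keras.saving.load_model()` calls `compile_from_config()` to restore it. Implementing both hooks on the approximators lets their custom metric objects survive a save/load round trip. A minimal usage sketch, not taken from this PR: the `save_and_restore` helper is hypothetical, and the `inference_metrics` keyword and the `bf.metrics.RootMeanSquaredError` export path are assumptions inferred from the diff.

```python
import keras
import bayesflow as bf  # assumed import alias


def save_and_restore(approximator, path="approximator.keras"):
    """Round-trip an already-built approximator, including its compile state."""
    # Compile with metric objects; `inference_metrics` mirrors the keyword that
    # get_compile_config() serializes in the diff below (assumption).
    approximator.compile(
        optimizer="adam",
        inference_metrics=[bf.metrics.RootMeanSquaredError()],
    )

    # Model.save() invokes get_compile_config(), so the metrics are stored
    # alongside the architecture and weights ...
    approximator.save(path)

    # ... and load_model() invokes compile_from_config(), which re-compiles the
    # model and rebuilds the optimizer variables (see compile_from_config below).
    return keras.saving.load_model(path)
```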
16 changes: 16 additions & 0 deletions bayesflow/approximators/continuous_approximator.py
@@ -104,6 +104,12 @@

        return super().compile(*args, **kwargs)

    def compile_from_config(self, config):
        self.compile(**deserialize(config))
        if hasattr(self, "optimizer") and self.built:
            # Create optimizer variables.
            self.optimizer.build(self.trainable_variables)

    def compute_metrics(
        self,
        inference_variables: Tensor,
@@ -213,6 +219,16 @@

        return base_config | serialize(config)

    def get_compile_config(self):
        base_config = super().get_compile_config() or {}

        config = {
            "inference_metrics": self.inference_network._metrics,
            "summary_metrics": self.summary_network._metrics if self.summary_network is not None else None,
        }

        return base_config | serialize(config)

    def estimate(
        self,
        conditions: Mapping[str, np.ndarray],

16 changes: 16 additions & 0 deletions bayesflow/approximators/model_comparison_approximator.py
@@ -118,6 +118,12 @@

        return super().compile(*args, **kwargs)

    def compile_from_config(self, config):
        self.compile(**deserialize(config))
        if hasattr(self, "optimizer") and self.built:
            # Create optimizer variables.
            self.optimizer.build(self.trainable_variables)

    def compute_metrics(
        self,
        *,
@@ -262,6 +268,16 @@

        return base_config | serialize(config)

    def get_compile_config(self):
        base_config = super().get_compile_config() or {}

        config = {
            "classifier_metrics": self.classifier_network._metrics,
            "summary_metrics": self.summary_network._metrics if self.summary_network is not None else None,
        }

        return base_config | serialize(config)

    def predict(
        self,
        *,

2 changes: 2 additions & 0 deletions bayesflow/metrics/maximum_mean_discrepancy.py
@@ -2,9 +2,11 @@

import keras

from bayesflow.utils.serialization import serializable
from .functional import maximum_mean_discrepancy


@serializable
class MaximumMeanDiscrepancy(keras.Metric):
    def __init__(
        self,

3 changes: 2 additions & 1 deletion bayesflow/metrics/root_mean_squard_error.py
@@ -1,10 +1,11 @@
from functools import partial
import keras


from bayesflow.utils.serialization import serializable
from .functional import root_mean_squared_error


@serializable
class RootMeanSquaredError(keras.metrics.MeanMetricWrapper):
    def __init__(self, name="root_mean_squared_error", dtype=None, **kwargs):
        fn = partial(root_mean_squared_error, **kwargs)
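
The `@serializable` decorator added to both metric classes registers them for (de)serialization, which is what lets the `serialize(...)` / `deserialize(...)` calls in `get_compile_config` / `compile_from_config` round-trip metric instances. A minimal sketch of that round trip, assuming `@serializable` registers the class as a Keras custom object, that the metrics are exported from `bayesflow.metrics`, and that default constructor arguments suffice:

```python
import keras
from bayesflow.metrics import MaximumMeanDiscrepancy  # assumed export path

metric = MaximumMeanDiscrepancy()  # assumes the defaults are sufficient

# Serialize to a plain, JSON-friendly config dict ...
config = keras.saving.serialize_keras_object(metric)

# ... and reconstruct the metric; this relies on @serializable having
# registered the class so Keras can resolve it from the config.
restored = keras.saving.deserialize_keras_object(config)
assert isinstance(restored, MaximumMeanDiscrepancy)
```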