Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
82b3ab4
allow tensor in DiagonalNormal dimension
arrjon Sep 6, 2025
8fbf737
fix sum dims
arrjon Sep 7, 2025
5c27246
fix batch_shape for sample
arrjon Sep 7, 2025
c684bca
dims to tuple
arrjon Sep 7, 2025
0697634
first draft compositional
arrjon Sep 8, 2025
b8e849e
first draft compositional
arrjon Sep 8, 2025
a280af3
first draft compositional
arrjon Sep 8, 2025
b9faf31
first draft compositional
arrjon Sep 8, 2025
9b7eb16
fix shapes
arrjon Sep 8, 2025
e79aac1
fix shapes
arrjon Sep 8, 2025
8a80240
fix shapes
arrjon Sep 8, 2025
00fbc61
fix shapes
arrjon Sep 8, 2025
e6158e7
fix shapes
arrjon Sep 8, 2025
1ac39b2
fix shapes
arrjon Sep 8, 2025
9fd9cf8
add minibatch
arrjon Sep 8, 2025
830e929
add compositional_bridge
arrjon Sep 8, 2025
f97594b
fix mini batch randomness
arrjon Sep 8, 2025
7219a71
fix mini batch randomness
arrjon Sep 8, 2025
a10026a
fix mini batch randomness
arrjon Sep 8, 2025
457eb5d
add prior score
arrjon Sep 8, 2025
7de4736
add prior score
arrjon Sep 8, 2025
1ee0e78
add prior score draft
arrjon Sep 8, 2025
f71359b
add prior score draft
arrjon Sep 8, 2025
6210c07
add prior score draft
arrjon Sep 8, 2025
bcb9f60
add prior score draft
arrjon Sep 8, 2025
455f03c
fix dtype
arrjon Sep 8, 2025
89523a9
fix docstring
arrjon Sep 9, 2025
e55631d
fix batch_shape in sample
arrjon Sep 9, 2025
3eaff24
fix batch_shape for point approximator
arrjon Sep 9, 2025
5601d20
Merge branch 'normal_distribution_dimension' into compositional_sampl…
arrjon Sep 9, 2025
6b9671b
Merge branch 'dev' into compositional_sampling_diffusion
arrjon Sep 10, 2025
e97e375
fix docstring
arrjon Sep 10, 2025
caa2d67
fix float32
arrjon Sep 10, 2025
1ac9bff
reorganize
arrjon Sep 12, 2025
df23f89
add annealed_langevin
arrjon Sep 12, 2025
0a87694
fix annealed_langevin
arrjon Sep 12, 2025
64d4373
add predictor corrector sampling
arrjon Sep 12, 2025
5b42368
add predictor corrector sampling
arrjon Sep 12, 2025
9402941
add predictor corrector sampling
arrjon Sep 12, 2025
e0b3bd5
add predictor corrector sampling
arrjon Sep 12, 2025
89361f7
add predictor corrector sampling
arrjon Sep 12, 2025
5969bd3
robust mean scores
arrjon Sep 12, 2025
e983cf7
add some tests
arrjon Sep 12, 2025
eac9aaf
minor fixes
arrjon Sep 12, 2025
2a9b0e1
minor fixes
arrjon Sep 12, 2025
9a1ba32
add test for compute_prior_score_pre
arrjon Sep 12, 2025
93b59ba
fix order of prior scores
arrjon Sep 12, 2025
922040d
fix prior scores standardize
arrjon Sep 13, 2025
b2991d1
better standard values for compositional
arrjon Sep 13, 2025
d2a36a8
better compositional_bridge
arrjon Sep 13, 2025
0ff960f
fix integrate_kwargs
arrjon Sep 13, 2025
b2ef755
fix integrate_kwargs
arrjon Sep 13, 2025
ca7f3bd
fix kwargs in sample
arrjon Sep 16, 2025
09df093
Merge branch 'dev' into fix_sampling_method_kwargs
arrjon Sep 16, 2025
2c161c6
fix kwargs in set transformer
arrjon Sep 16, 2025
9d4c1a1
fix kwargs in set transformer
arrjon Sep 16, 2025
ea0659d
remove print
arrjon Sep 16, 2025
922412f
Merge branch 'fix_sampling_method_kwargs' into compositional_sampling…
arrjon Sep 22, 2025
9220816
new class for compositional diffusion
arrjon Sep 22, 2025
ee1c320
fix import
arrjon Sep 22, 2025
c977959
Merge branch 'dev' into compositional_sampling_diffusion
arrjon Sep 23, 2025
9fee1d4
Merge branch 'dev' into compositional_sampling_diffusion
arrjon Sep 23, 2025
d3f639d
Merge branch 'dev' into compositional_sampling_diffusion
arrjon Sep 25, 2025
7d15b49
Merge branch 'dev' into compositional_sampling_diffusion
arrjon Sep 25, 2025
e6513c1
add import
arrjon Sep 26, 2025
e87f9d1
fix mini_batch_size
arrjon Sep 26, 2025
983cb8d
fix mini_batch_size
arrjon Sep 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add test for compute_prior_score_pre
  • Loading branch information
arrjon committed Sep 12, 2025
commit 9a1ba32dc6e28b49b97cdb87ad0e41d8bbe518bd
13 changes: 5 additions & 8 deletions bayesflow/approximators/continuous_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor:
_samples, forward=False, log_det_jac=True
)
else:
log_det_jac_standardize = 0
log_det_jac_standardize = keras.ops.cast(0.0, dtype="float32")
_samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples})
adapted_samples, log_det_jac = self.adapter(
_samples, inverse=True, strict=False, log_det_jac=True, **kwargs
Expand All @@ -708,15 +708,12 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor:
for key in adapted_samples:
if isinstance(prior_score[key], np.ndarray):
prior_score[key] = prior_score[key].astype("float32")
if len(log_det_jac) > 0:
prior_score[key] += log_det_jac[key]
if len(log_det_jac) > 0 and key in log_det_jac:
prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key])

prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score)
# make a tensor
out = keras.ops.concatenate(
list(prior_score.values()), axis=-1
) # todo: assumes same order, might be incorrect
return out + expand_right_as(log_det_jac_standardize, out)
out = keras.ops.concatenate(list(prior_score.values()), axis=-1)
return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1)

# Test prior score function, useful for debugging
test = self.inference_network.base_distribution.sample((n_datasets, num_samples))
Expand Down
53 changes: 53 additions & 0 deletions tests/test_approximators/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,56 @@ def approximator_with_summaries(request):
)
case _:
raise ValueError("Invalid param for approximator class.")


@pytest.fixture
def simple_log_simulator():
    """Simulator fixture yielding loc/scale parameters plus dummy conditions."""
    import numpy as np
    from bayesflow.simulators import Simulator
    from bayesflow.utils.decorators import allow_batch_size
    from bayesflow.types import Shape, Tensor

    class SimpleSimulator(Simulator):
        """Draws location and (positive) scale parameters in original space."""

        @allow_batch_size
        def sample(self, batch_shape: Shape) -> dict[str, Tensor]:
            # Location parameters: standard-normal draws.
            loc = np.random.normal(0.0, 1.0, size=batch_shape + (2,))
            # Scale parameters: log-normal draws, hence strictly positive.
            scale = np.random.lognormal(0.0, 0.5, size=batch_shape + (2,))
            # Unrelated conditioning variables.
            conditions = np.random.normal(0.0, 1.0, size=batch_shape + (3,))

            draws = {"loc": loc, "scale": scale, "conditions": conditions}
            return {name: values.astype("float32") for name, values in draws.items()}

    return SimpleSimulator()


@pytest.fixture
def transforming_adapter():
    """Create an adapter that applies log transformation to scale parameters."""
    from bayesflow.adapters import Adapter

    transform = Adapter()
    transform.to_array()
    transform.convert_dtype("float64", "float32")

    # Log-transform the positive scale parameters so they become unbounded.
    transform.log(["scale"])

    # Pack variables into the keys the approximator expects.
    transform.concatenate(["loc", "scale"], into="inference_variables")
    transform.concatenate(["conditions"], into="inference_conditions")
    transform.keep(["inference_variables", "inference_conditions"])
    return transform


@pytest.fixture
def diffusion_network():
    """Provide a small diffusion network for compositional sampling tests."""
    from bayesflow.networks import DiffusionModel, MLP

    subnet = MLP(widths=[32, 32])
    return DiffusionModel(subnet=subnet)
109 changes: 109 additions & 0 deletions tests/test_approximators/test_compositional_prior_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Tests for compositional sampling and prior score computation with adapters."""

import numpy as np
import keras

from bayesflow import ContinuousApproximator
from bayesflow.utils import expand_right_as


def mock_prior_score_original_space(data_dict):
    """Mock prior score evaluated in the original (loc, scale) space.

    Prior: loc ~ N(0, 1), scale ~ LogNormal(0, 0.5). Returns the gradient of
    the log-prior with respect to each parameter group, keyed like the input.
    This mirrors the dict format produced by compute_prior_score_pre after the
    inverse adapter pass.
    """
    loc, scale = data_dict["loc"], data_dict["scale"]
    # d/dloc log N(loc; 0, 1) = -loc
    # d/dscale log LogNormal(scale; 0, 0.5) = -1/scale - log(scale) / (0.5**2 * scale)
    return {
        "loc": -loc,
        "scale": -(1.0 / scale) - np.log(scale) / (0.25 * scale),
    }


def test_prior_score_transforming_adapter(simple_log_simulator, transforming_adapter, diffusion_network):
    """Test that prior scores work correctly with transforming adapter (log transformation)."""

    # Approximator wired with the log-transforming adapter.
    approximator = ContinuousApproximator(
        adapter=transforming_adapter,
        inference_network=diffusion_network,
    )

    # Build the approximator from a small adapted batch.
    raw_data = simple_log_simulator.sample((2,))
    approximator.build_from_data(transforming_adapter(raw_data))

    # Compositional conditions: (n_datasets, n_compositional, condition_dim).
    n_datasets, n_compositional = 3, 5
    cond_values = np.random.normal(0.0, 1.0, (n_datasets, n_compositional, 3)).astype("float32")
    conditions = {"conditions": cond_values}

    # compute_prior_score_pre must handle the inverse (exp) transformation internally.
    samples = approximator.compositional_sample(
        num_samples=10,
        conditions=conditions,
        compute_prior_score=mock_prior_score_original_space,
    )

    for key in ("loc", "scale"):
        assert key in samples
        assert samples[key].shape == (n_datasets, 10, 2)


def test_prior_score_jacobian_correction(simple_log_simulator, transforming_adapter, diffusion_network):
    """Test that Jacobian correction is applied correctly in compute_prior_score_pre."""

    # Create approximator with transforming adapter; standardize=[] disables the
    # standardization layer so only the adapter's log-transform Jacobian is in play.
    approximator = ContinuousApproximator(
        adapter=transforming_adapter, inference_network=diffusion_network, standardize=[]
    )

    # Build with dummy data
    dummy_data_dict = simple_log_simulator.sample((1,))
    adapted_dummy_data = transforming_adapter(dummy_data_dict)
    approximator.build_from_data(adapted_dummy_data)

    # Get the internal compute_prior_score_pre function
    # NOTE(review): this re-implements the approximator-internal logic instead of
    # calling it; keep it in sync with ContinuousApproximator's implementation.
    def get_compute_prior_score_pre():
        def compute_prior_score_pre(_samples):
            # With standardize=[] the first branch is skipped and the
            # standardization log-det term is a float32 zero.
            if "inference_variables" in approximator.standardize:
                _samples, log_det_jac_standardize = approximator.standardize_layers["inference_variables"](
                    _samples, forward=False, log_det_jac=True
                )
            else:
                log_det_jac_standardize = keras.ops.cast(0.0, dtype="float32")

            # Map network-space samples back to original (loc, scale) space and
            # collect the adapter's inverse-transform log-det-Jacobian per key.
            _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples})
            adapted_samples, log_det_jac = approximator.adapter(_samples, inverse=True, strict=False, log_det_jac=True)

            prior_score = mock_prior_score_original_space(adapted_samples)
            for key in adapted_samples:
                if isinstance(prior_score[key], np.ndarray):
                    prior_score[key] = prior_score[key].astype("float32")
                # Change-of-variables correction: subtract the (right-expanded)
                # log-det-Jacobian for keys the adapter actually transformed.
                if len(log_det_jac) > 0 and key in log_det_jac:
                    prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key])

            prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score)
            # Concatenation relies on dict insertion order matching the
            # inference_variables layout (loc first, then scale).
            out = keras.ops.concatenate(list(prior_score.values()), axis=-1)
            return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1)

        return compute_prior_score_pre

    compute_prior_score_pre = get_compute_prior_score_pre()

    # Test with a known transformation
    y_samples = adapted_dummy_data["inference_variables"]
    scores = compute_prior_score_pre(y_samples)
    scores_np = keras.ops.convert_to_numpy(scores)[0]  # Remove batch dimension

    # With Jacobian correction: score_transformed = score_original - log|J|
    old_scores = mock_prior_score_original_space(dummy_data_dict)
    # Columns 2: hold the log-scale entries; their sum is the total log|J| of the
    # log transform. NOTE(review): summing applies one scalar correction to every
    # scale component — confirm this matches the per-key (not per-element)
    # convention of the adapter's log_det_jac output.
    det_jac_scale = y_samples[0, 2:].sum()
    expected_scores = np.array([old_scores["loc"][0], old_scores["scale"][0] - det_jac_scale]).flatten()

    # Check that scores are reasonably close
    np.testing.assert_allclose(scores_np, expected_scores, rtol=1e-5, atol=1e-6)
47 changes: 47 additions & 0 deletions tests/test_networks/test_diffusion_model/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import keras


@pytest.fixture()
Expand All @@ -21,3 +22,49 @@ def edm_noise_schedule():
)
def noise_schedule(request):
return request.getfixturevalue(request.param)


@pytest.fixture
def simple_diffusion_model():
    """Small diffusion model used by the compositional sampling tests."""
    from bayesflow.networks.diffusion_model import DiffusionModel
    from bayesflow.networks import MLP

    # Noise-prediction parameterization with a cosine schedule; tiny MLP keeps tests fast.
    settings = dict(
        noise_schedule="cosine",
        prediction_type="noise",
        loss_type="noise",
    )
    return DiffusionModel(subnet=MLP(widths=[32, 32]), **settings)


@pytest.fixture
def compositional_conditions():
    """Random conditions tensor for compositional sampling tests."""
    # (batch_size, n_compositional, n_samples, condition_dim)
    shape = (2, 3, 4, 5)
    return keras.random.normal(shape)


@pytest.fixture
def compositional_state():
    """Random state tensor for compositional sampling tests."""
    # (batch_size, n_samples, param_dim)
    shape = (2, 4, 3)
    return keras.random.normal(shape)


@pytest.fixture
def mock_prior_score():
    """Score function of a simple quadratic prior -0.5 * ||theta||^2."""

    def prior_score_fn(theta):
        # Gradient of -0.5 * ||theta||^2 with respect to theta.
        return -theta

    return prior_score_fn
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,6 @@
import pytest


@pytest.fixture
def simple_diffusion_model():
    """Create a simple diffusion model for testing compositional sampling."""
    from bayesflow.networks.diffusion_model import DiffusionModel
    from bayesflow.networks import MLP

    # Small MLP subnet keeps tests fast; noise-prediction parameterization.
    return DiffusionModel(
        subnet=MLP(widths=[32, 32]),
        noise_schedule="cosine",
        prediction_type="noise",
        loss_type="noise",
    )


@pytest.fixture
def compositional_conditions():
    """Create test conditions for compositional sampling."""
    # Resulting shape: (batch_size, n_compositional, n_samples, condition_dim).
    batch_size = 2
    n_compositional = 3
    n_samples = 4
    condition_dim = 5

    return keras.random.normal((batch_size, n_compositional, n_samples, condition_dim))


@pytest.fixture
def compositional_state():
    """Create test state for compositional sampling."""
    # Resulting shape: (batch_size, n_samples, param_dim).
    batch_size = 2
    n_samples = 4
    param_dim = 3

    return keras.random.normal((batch_size, n_samples, param_dim))


@pytest.fixture
def mock_prior_score():
    """Create a mock prior score function for testing."""

    def prior_score_fn(theta):
        # Simple quadratic prior: -0.5 * ||theta||^2
        # (its gradient with respect to theta is -theta)
        return -theta

    return prior_score_fn


def test_compositional_score_shape(
simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score
):
Expand Down
Loading