Changes from all commits
38 commits
b392472
Add validation_data to Emulator API
sgreenbury Nov 4, 2025
c630317
Add output_to_tensor conversion, update usage
sgreenbury Nov 4, 2025
87a2a0b
Add conformal module
sgreenbury Nov 4, 2025
88153bf
Add initial conformal impl and test
sgreenbury Nov 4, 2025
ef684d3
Add ConformalMLP to registry and re-export
sgreenbury Nov 4, 2025
e76ed72
Add support for case when no cal data provided
sgreenbury Nov 4, 2025
6e99ffb
Update test_grads for Conformal emulators
sgreenbury Nov 4, 2025
31cbdd3
Update LightGBM with validation data
sgreenbury Nov 4, 2025
340d1e7
Update conformal MLP with kwargs
sgreenbury Nov 4, 2025
905a6d6
Update docstring
sgreenbury Nov 4, 2025
051ed55
Add conformal quantile regression
sgreenbury Nov 5, 2025
3c609fd
Add n_samples to APIs
sgreenbury Nov 5, 2025
e1db3c7
Add test passing validation data
sgreenbury Nov 5, 2025
c7146d9
Add comment
sgreenbury Nov 7, 2025
76ec652
Remove obsolete return type
sgreenbury Nov 7, 2025
0be09d5
Remove separate args
sgreenbury Nov 7, 2025
afefe7e
Merge remote-tracking branch 'origin/main' into 848-conformal-prediction
sgreenbury Nov 7, 2025
cbc9821
Update docstring and kwarg order
sgreenbury Nov 7, 2025
483309f
Fix indentation of docstring
sgreenbury Nov 7, 2025
edbf1d2
Fix docstrings
sgreenbury Nov 7, 2025
d1bc291
Add comment
sgreenbury Nov 7, 2025
52fa84c
Fix docstring
sgreenbury Nov 7, 2025
dbd9f2d
Revise type hints
sgreenbury Nov 7, 2025
ac897c8
Fix assertion in test
sgreenbury Nov 7, 2025
0728857
Add with_grad to docstring
sgreenbury Nov 13, 2025
bd9b57c
Update
sgreenbury Nov 13, 2025
ba2ce49
Revise return docstring
sgreenbury Nov 13, 2025
183dc76
Update method str to "constant"
sgreenbury Nov 13, 2025
5cf4f6e
Add option to customise distribution
sgreenbury Nov 13, 2025
65c326a
Fix error message
sgreenbury Nov 13, 2025
e0cf73c
Extend docstring and add reference
sgreenbury Nov 13, 2025
dcbfe72
Update docstring to explain adding new methods
sgreenbury Nov 13, 2025
b87c96d
Add conformal MLP subclass factory and subclasses
sgreenbury Nov 13, 2025
0d6bd9f
Fix test
sgreenbury Nov 13, 2025
012776d
Fix subclassing and defaults
sgreenbury Nov 14, 2025
7e786b5
Ensure bounds ordering is valid
sgreenbury Nov 17, 2025
f5c5f2d
Merge remote-tracking branch 'origin/main' into 848-conformal-prediction
sgreenbury Nov 18, 2025
6489aaa
Update tuning to support MetricParams
sgreenbury Nov 18, 2025
16 changes: 14 additions & 2 deletions autoemulate/core/compare.py
@@ -71,6 +71,7 @@ def __init__(
log_level: str = "progress_bar",
tuning_metric: str | Metric = "r2",
evaluation_metrics: list[str | Metric] | None = None,
n_samples: int = 1000,
):
"""
Initialize the AutoEmulate class.
@@ -130,6 +131,9 @@ def __init__(
Each entry can be a string shortcut or a MetricConfig object.
IMPORTANT: The first metric in the list is used to
determine the best model.
n_samples: int
Number of samples to draw to estimate the predictive mean when the
emulator does not expose a mean directly. Defaults to 1000.
"""
Results.__init__(self)
self.random_seed = random_seed
@@ -187,6 +191,7 @@ def __init__(
# Set up logger and ModelSerialiser for saving models
self.logger, self.progress_bar = get_configured_logger(log_level)
self.model_serialiser = ModelSerialiser(self.logger)
self.n_samples = n_samples

# Run compare
self.compare()
@@ -417,6 +422,9 @@ def compare(self):
n_splits=self.n_splits,
shuffle=self.shuffle,
transformed_emulator_params=self.transformed_emulator_params,
metric_params=MetricParams(
n_samples=self.n_samples
),
)
mean_scores = [
np.mean(score).item() for score in scores
@@ -484,7 +492,9 @@ def compare(self):
n_bootstraps=self.n_bootstraps,
device=self.device,
metrics=self.evaluation_metrics,
metric_params=MetricParams(y_train=train_val_y),
metric_params=MetricParams(
n_samples=self.n_samples, y_train=train_val_y
),
)
test_metrics = bootstrap(
transformed_emulator,
@@ -493,7 +503,9 @@
n_bootstraps=self.n_bootstraps,
device=self.device,
metrics=self.evaluation_metrics,
metric_params=MetricParams(y_train=train_val_y),
metric_params=MetricParams(
n_samples=self.n_samples, y_train=train_val_y
),
)

# Log all test metrics from test_metrics dictionary
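Reviewer note: the pattern introduced above bundles `n_samples` (and, for the held-out evaluation, the training targets) into a single `MetricParams` object rather than passing them as separate arguments. A minimal sketch of constructing these objects, assuming only the import path and field names visible in this diff; the tensor is a placeholder:

```python
import torch

from autoemulate.core.metrics import MetricParams

# Placeholder standing in for train_val_y.
train_val_y = torch.randn(64, 2)

# Cross-validation scoring only needs n_samples...
cv_params = MetricParams(n_samples=1000)

# ...while bootstrap evaluation also attaches the training targets,
# for metrics that require them.
test_params = MetricParams(n_samples=1000, y_train=train_val_y)
```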
13 changes: 12 additions & 1 deletion autoemulate/core/model_selection.py
@@ -1,5 +1,6 @@
import inspect
import logging
from dataclasses import replace

import torch
from sklearn.model_selection import BaseCrossValidator
@@ -61,6 +62,7 @@ def cross_validate(
device: DeviceLike = "cpu",
random_seed: int | None = None,
metrics: list[Metric] | None = None,
metric_params: MetricParams | None = None,
):
"""
Cross validate model performance using the given `cv` strategy.
@@ -85,6 +87,8 @@
Optional random seed for reproducibility.
metrics: list[Metric] | None
List of metrics to compute. If None, uses r2 and rmse.
metric_params: MetricParams | None
Additional parameters to pass to the metrics. Defaults to None.

Returns
-------
@@ -94,6 +98,7 @@
transformed_emulator_params = transformed_emulator_params or {}
x_transforms = x_transforms or []
y_transforms = y_transforms or []
metric_params = metric_params or MetricParams()

# Setup metrics
if metrics is None:
@@ -143,7 +148,13 @@
# compute and save results
y_pred = transformed_emulator.predict(x_val)
for metric in metrics:
score = evaluate(y_pred, y_val, metric)
score = evaluate(
y_pred,
y_val,
metric,
# Update metric_params with y_train in case it is required by the metric
metric_params=replace(metric_params, y_train=y),
)
cv_results[metric.name].append(score)
return cv_results

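Reviewer note: `replace` here is `dataclasses.replace`, which returns a copy of a dataclass instance with the named fields overridden, so each fold injects its own `y_train` without mutating the caller's `MetricParams`. A small sketch of that behaviour, assuming `MetricParams` is a plain dataclass as used in this diff:

```python
from dataclasses import replace

import torch

from autoemulate.core.metrics import MetricParams

base = MetricParams(n_samples=500)

# Per-fold copy with the fold's training targets attached; base is untouched.
fold_params = replace(base, y_train=torch.randn(48, 2))

print(fold_params.n_samples)      # 500, carried over from base
print(fold_params.y_train.shape)  # torch.Size([48, 2])
```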
6 changes: 5 additions & 1 deletion autoemulate/core/tuner.py
@@ -6,7 +6,7 @@
from torch.distributions import Transform

from autoemulate.core.device import TorchDeviceMixin
from autoemulate.core.metrics import Metric, get_metric
from autoemulate.core.metrics import Metric, MetricParams, get_metric
from autoemulate.core.model_selection import cross_validate
from autoemulate.core.types import (
DeviceLike,
@@ -74,6 +74,7 @@ def run(
n_splits: int = 5,
seed: int | None = None,
shuffle: bool = True,
metric_params: MetricParams | None = None,
) -> tuple[list[list[float]], list[ModelParams]]:
"""
Run randomised hyperparameter search for a given model.
@@ -97,6 +98,8 @@
shuffle: bool
Whether to shuffle data before splitting into cross validation folds.
Defaults to True.
metric_params: MetricParams | None
Additional parameters to pass to the metrics. Defaults to None.

Returns
-------
@@ -130,6 +133,7 @@
device=self.device,
random_seed=None,
metrics=[self.tuning_metric],
metric_params=metric_params,
)

# Reset retries following a successful cross_validation call
36 changes: 36 additions & 0 deletions autoemulate/data/utils.py
@@ -192,6 +192,42 @@ def _denormalize(
) -> TensorLike:
return (x * x_std) + x_mean

def output_to_tensor(
self,
output: OutputLike,
n_samples: int = 1000,
with_grad: bool = False,
) -> torch.Tensor:
"""Convert an output to a tensor (returns the mean if output is a distribution).
Parameters
----------
output: OutputLike
The output to convert to a tensor.
n_samples: int
Number of samples to draw from the distribution. Defaults to 1000.
with_grad: bool
Whether to enable gradient calculation. Defaults to False.
Returns
-------
TensorLike
Tensor of shape `(n_batch, n_targets)` as input or the mean of the output if
output is a distribution.
"""
if isinstance(output, TensorLike):
return output
try:
return output.mean
except Exception:
# Use sampling to get a mean if mean property not available
samples = (
output.rsample(torch.Size([n_samples]))
if with_grad
else output.sample(torch.Size([n_samples]))
)
return samples.mean(dim=0)


def set_random_seed(seed: int = 42, deterministic: bool = True):
"""
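Reviewer note: a self-contained illustration of the fallback path in `output_to_tensor`, using only `torch`. A `TransformedDistribution` does not implement `.mean`, so accessing it raises `NotImplementedError` and the `except` branch estimates the mean from `n_samples` draws instead:

```python
import torch
from torch.distributions import ExpTransform, Normal, TransformedDistribution

# Closed-form mean available: returned directly by the try branch.
normal = Normal(torch.zeros(4, 2), torch.ones(4, 2))
assert normal.mean.shape == (4, 2)

# A log-normal built via TransformedDistribution has no .mean property,
# so the mean must be estimated by Monte Carlo.
log_normal = TransformedDistribution(normal, [ExpTransform()])
try:
    _ = log_normal.mean
except NotImplementedError:
    mc_mean = log_normal.sample(torch.Size([1000])).mean(dim=0)
    print(mc_mean.shape)  # torch.Size([4, 2])
```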
2 changes: 2 additions & 0 deletions autoemulate/emulators/__init__.py
@@ -1,4 +1,5 @@
from .base import Emulator
from .conformal import ConformalMLP
from .ensemble import EnsembleMLP, EnsembleMLPDropout
from .gaussian_process.exact import (
GaussianProcessCorrelatedMatern32,
@@ -26,6 +27,7 @@

__all__ = [
"MLP",
"ConformalMLP",
"Emulator",
"EnsembleMLP",
"EnsembleMLPDropout",
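Reviewer note: with the re-export above, the new emulator is importable from the package root alongside the existing ones:

```python
from autoemulate.emulators import MLP, ConformalMLP, EnsembleMLP
```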
43 changes: 26 additions & 17 deletions autoemulate/emulators/base.py
@@ -39,9 +39,19 @@ class Emulator(ABC, ValidationMixin, ConversionMixin, TorchDeviceMixin):
supports_uq: bool = False

@abstractmethod
def _fit(self, x: TensorLike, y: TensorLike): ...
def _fit(
self,
x: TensorLike,
y: TensorLike,
validation_data: tuple[TensorLike, TensorLike] | None = None,
): ...

def fit(self, x: TensorLike, y: TensorLike):
def fit(
self,
x: TensorLike,
y: TensorLike,
validation_data: tuple[TensorLike, TensorLike] | None = None,
):
"""Fit the emulator to the provided data."""
# Ensure x and y are tensors and 2D
x, y = self._convert_to_tensors(x, y)
@@ -58,7 +68,7 @@ def fit(self, x: TensorLike, y: TensorLike):
y = self.y_transform(y) if self.y_transform is not None else y

# Fit emulator
self._fit(x, y)
self._fit(x, y, validation_data)
self.is_fitted_ = True

@abstractmethod
@@ -152,18 +162,7 @@ def predict_mean(
"""
x = self._ensure_with_grad(x, with_grad)
y_pred = self._predict(x, with_grad)
if isinstance(y_pred, TensorLike):
return y_pred
try:
return y_pred.mean
except Exception:
# Use sampling to get a mean if mean property not available
samples = (
y_pred.rsample(torch.Size([n_samples]))
if with_grad
else y_pred.sample(torch.Size([n_samples]))
)
return samples.mean(dim=0)
return self.output_to_tensor(y_pred, n_samples)

def predict_mean_and_variance(
self, x: TensorLike, with_grad: bool = False, n_samples: int = 100
@@ -559,7 +558,12 @@ def loss_func(self, y_pred, y_true):
"""Loss function to be used for training the model."""
return nn.MSELoss()(y_pred, y_true)

def _fit(self, x: TensorLike, y: TensorLike):
def _fit(
self,
x: TensorLike,
y: TensorLike,
validation_data: tuple[TensorLike, TensorLike] | None = None, # noqa: ARG002
):
"""
Train a PyTorchBackend model.

@@ -683,7 +687,12 @@ class SklearnBackend(DeterministicEmulator):
def _model_specific_check(self, x: NumpyLike, y: NumpyLike):
_, _ = x, y

def _fit(self, x: TensorLike, y: TensorLike):
def _fit(
self,
x: TensorLike,
y: TensorLike,
validation_data: tuple[TensorLike, TensorLike] | None = None, # noqa: ARG002
):
if self.normalize_y:
y, y_mean, y_std = self._normalize(y)
self.y_mean = y_mean
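Reviewer note: to make the updated `_fit` contract concrete, here is a standalone toy sketch (not the library's class) of a backend that accepts the optional `(x_val, y_val)` tuple. Backends that don't need it ignore it, as `PyTorchBackend` and `SklearnBackend` do above, while a conformal emulator would consume it as calibration data:

```python
import torch

class ToyBackend:
    """Toy stand-in mirroring the updated _fit signature."""

    def _fit(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        validation_data: tuple[torch.Tensor, torch.Tensor] | None = None,
    ):
        # Closed-form linear fit standing in for real training.
        self.coef_ = torch.linalg.lstsq(x, y).solution
        if validation_data is not None:
            # E.g. a conformal emulator would calibrate on this split.
            x_val, y_val = validation_data
            self.val_rmse_ = ((x_val @ self.coef_ - y_val) ** 2).mean().sqrt()

backend = ToyBackend()
backend._fit(
    torch.randn(50, 3),
    torch.randn(50, 2),
    validation_data=(torch.randn(10, 3), torch.randn(10, 2)),
)
print(backend.val_rmse_)  # validation RMSE of the toy fit
```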