Skip to content

Commit 2cbf790

Browse files
committed
Update conformal MLP with kwargs
1 parent 652db10 commit 2cbf790

File tree

1 file changed

+32
-5
lines changed

1 file changed

+32
-5
lines changed

autoemulate/emulators/conformal.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import math
22

33
import torch
4-
from torch import Tensor
4+
from torch import Tensor, nn
5+
from torch.optim.lr_scheduler import LRScheduler
56

67
from autoemulate.core.device import TorchDeviceMixin
78
from autoemulate.core.types import DeviceLike, DistributionLike, TensorLike, TuneParams
@@ -142,7 +143,20 @@ def __init__(
142143
device: DeviceLike | None = None,
143144
alpha: float = 0.95,
144145
calibration_ratio: float = 0.2,
145-
**mlp_kwargs,
146+
activation_cls: type[nn.Module] = nn.ReLU,
147+
loss_fn_cls: type[nn.Module] = nn.MSELoss,
148+
epochs: int = 100,
149+
batch_size: int = 16,
150+
layer_dims: list[int] | None = None,
151+
weight_init: str = "default",
152+
scale: float = 1.0,
153+
bias_init: str = "default",
154+
dropout_prob: float | None = None,
155+
lr: float = 1e-2,
156+
params_size: int = 1,
157+
random_seed: int | None = None,
158+
scheduler_cls: type[LRScheduler] | None = None,
159+
scheduler_params: dict | None = None,
146160
):
147161
"""
148162
Initialize an ensemble of MLPs.
@@ -167,15 +181,28 @@ def __init__(
167181
mlp_kwargs: dict | None
168182
Additional keyword arguments for the MLP constructor.
169183
"""
170-
PyTorchBackend.__init__(self)
171-
self.mlp_kwargs = mlp_kwargs or {}
184+
nn.Module.__init__(self)
185+
172186
emulator = MLP(
173187
x,
174188
y,
175189
standardize_x=standardize_x,
176190
standardize_y=standardize_y,
177191
device=device,
178-
**self.mlp_kwargs,
192+
activation_cls=activation_cls,
193+
loss_fn_cls=loss_fn_cls,
194+
epochs=epochs,
195+
batch_size=batch_size,
196+
layer_dims=layer_dims,
197+
weight_init=weight_init,
198+
scale=scale,
199+
bias_init=bias_init,
200+
dropout_prob=dropout_prob,
201+
lr=lr,
202+
params_size=params_size,
203+
random_seed=random_seed,
204+
scheduler_cls=scheduler_cls,
205+
scheduler_params=scheduler_params,
179206
)
180207
Conformal.__init__(
181208
self,

0 commit comments

Comments
 (0)