```diff
@@ -1,7 +1,8 @@
 import math
 
 import torch
-from torch import Tensor
+from torch import Tensor, nn
+from torch.optim.lr_scheduler import LRScheduler
 
 from autoemulate.core.device import TorchDeviceMixin
 from autoemulate.core.types import DeviceLike, DistributionLike, TensorLike, TuneParams
@@ -142,7 +143,20 @@ def __init__(
         device: DeviceLike | None = None,
         alpha: float = 0.95,
         calibration_ratio: float = 0.2,
-        **mlp_kwargs,
+        activation_cls: type[nn.Module] = nn.ReLU,
+        loss_fn_cls: type[nn.Module] = nn.MSELoss,
+        epochs: int = 100,
+        batch_size: int = 16,
+        layer_dims: list[int] | None = None,
+        weight_init: str = "default",
+        scale: float = 1.0,
+        bias_init: str = "default",
+        dropout_prob: float | None = None,
+        lr: float = 1e-2,
+        params_size: int = 1,
+        random_seed: int | None = None,
+        scheduler_cls: type[LRScheduler] | None = None,
+        scheduler_params: dict | None = None,
     ):
         """
         Initialize an ensemble of MLPs.
```
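Replacing `**mlp_kwargs` with explicit keyword arguments makes the constructor signature self-documenting and introspectable, which matters for signature-driven tooling such as hyperparameter tuning (note the `TuneParams` import above). A minimal, self-contained sketch of the difference; the function names here are illustrative stand-ins, not from the codebase:

```python
import inspect

# Illustrative stand-ins, not actual autoemulate functions.
def before(x, y, **mlp_kwargs): ...
def after(x, y, epochs: int = 100, batch_size: int = 16, lr: float = 1e-2): ...

# With **kwargs, the tunable hyperparameters are invisible to inspection;
# with explicit keywords, each one is discoverable along with its default.
print(list(inspect.signature(before).parameters))  # ['x', 'y', 'mlp_kwargs']
print(list(inspect.signature(after).parameters))   # ['x', 'y', 'epochs', 'batch_size', 'lr']
```

The second hunk applies the same change at the call site: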
```diff
@@ -167,15 +181,28 @@ def __init__(
         mlp_kwargs: dict | None
             Additional keyword arguments for the MLP constructor.
         """
-        PyTorchBackend.__init__(self)
-        self.mlp_kwargs = mlp_kwargs or {}
+        nn.Module.__init__(self)
+
         emulator = MLP(
             x,
             y,
             standardize_x=standardize_x,
             standardize_y=standardize_y,
             device=device,
-            **self.mlp_kwargs,
+            activation_cls=activation_cls,
+            loss_fn_cls=loss_fn_cls,
+            epochs=epochs,
+            batch_size=batch_size,
+            layer_dims=layer_dims,
+            weight_init=weight_init,
+            scale=scale,
+            bias_init=bias_init,
+            dropout_prob=dropout_prob,
+            lr=lr,
+            params_size=params_size,
+            random_seed=random_seed,
+            scheduler_cls=scheduler_cls,
+            scheduler_params=scheduler_params,
         )
         Conformal.__init__(
             self,
```
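After this change, callers pass MLP hyperparameters directly instead of through an opaque `**mlp_kwargs` dict, so a typo in a hyperparameter name fails loudly at construction time. A hypothetical construction sketch, assuming the edited class is exported as `EnsembleMLP`; neither the class name nor its module is visible in this diff:

```python
import torch
from torch import nn

from autoemulate.emulators import EnsembleMLP  # hypothetical name and import path

# Toy data standing in for simulator inputs/outputs.
x = torch.rand(200, 4)
y = torch.rand(200, 1)

emulator = EnsembleMLP(
    x,
    y,
    alpha=0.95,             # target conformal coverage level
    calibration_ratio=0.2,  # fraction of data held out for calibration
    epochs=50,
    batch_size=32,
    lr=1e-3,
    activation_cls=nn.GELU,
    dropout_prob=0.1,
)
```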