diff --git a/fulu/bnn_aug.py b/fulu/bnn_aug.py index 0834049..3cd04d7 100644 --- a/fulu/bnn_aug.py +++ b/fulu/bnn_aug.py @@ -36,7 +36,7 @@ def __init__( kl_weight=0.1, optimizer="Adam", debug=0, - device="cpu", + device="cpu", weight_decay=0, ): self.model = None @@ -142,7 +142,7 @@ def __init__( lr=0.01, kl_weight=0.0001, optimizer="Adam", - device="cpu", + device="cpu", weight_decay=0, ): super().__init__(passband2lam) @@ -199,8 +199,8 @@ def fit(self, t, flux, flux_err, passband): lr=self.lr, kl_weight=self.kl_weight, optimizer=self.optimizer, - device=self.device, - weight_decay=self.weight_decay + device=self.device, + weight_decay=self.weight_decay, ) self.reg.fit(X_ss, y_ss) return self diff --git a/fulu/mlp_reg_aug.py b/fulu/mlp_reg_aug.py index 4a624e3..2bd76fc 100644 --- a/fulu/mlp_reg_aug.py +++ b/fulu/mlp_reg_aug.py @@ -49,7 +49,7 @@ def __init__( activation="tanh", learning_rate_init=0.001, max_iter=90, - batch_size=1, + batch_size=1, alpha=0.0001, ): super().__init__(passband2lam) @@ -110,7 +110,7 @@ def fit(self, t, flux, flux_err, passband): activation=self.activation, learning_rate_init=self.learning_rate_init, max_iter=self.max_iter, - batch_size=self.batch_size, + batch_size=self.batch_size, alpha=self.alpha, ) self.reg.fit(X_ss, y_ss) diff --git a/fulu/nf_aug.py b/fulu/nf_aug.py index b65fdee..81863e5 100644 --- a/fulu/nf_aug.py +++ b/fulu/nf_aug.py @@ -109,7 +109,7 @@ def __init__( n_epochs=10, lr=0.0001, randomize_x=True, - device="cpu", + device="cpu", weight_decay=0, ): @@ -288,7 +288,7 @@ def fit(self, t, flux, flux_err, passband): n_epochs=self.n_epochs, lr=self.lr, randomize_x=True, - device=self.device, + device=self.device, weight_decay=self.weight_decay, ) self.reg.fit(X_ss, flux, flux_err) diff --git a/fulu/single_layer_aug.py b/fulu/single_layer_aug.py index 9ba7ae0..be06b91 100644 --- a/fulu/single_layer_aug.py +++ b/fulu/single_layer_aug.py @@ -142,7 +142,7 @@ def __init__( batch_size=500, lr=0.01, optimizer="Adam", - device="auto", + device="auto", weight_decay=0, ): super().__init__(passband2lam) @@ -204,8 +204,8 @@ def fit(self, t, flux, flux_err, passband): batch_size=self.batch_size, lr=self.lr, optimizer=self.optimizer, - device=self.device, - weight_decay=self.weight_decay + device=self.device, + weight_decay=self.weight_decay, ) self.reg.fit(X_ss, y_ss)