Skip to content

Fix regression algorithms to give correct output dimensions #1335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits on Dec 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions autosklearn/pipeline/components/regression/adaboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, n_estimators, learning_rate, loss, max_depth, random_state=No
self.max_depth = max_depth
self.estimator = None

def fit(self, X, Y):
def fit(self, X, y):
import sklearn.ensemble
import sklearn.tree

Expand All @@ -32,7 +32,11 @@ def fit(self, X, Y):
loss=self.loss,
random_state=self.random_state
)
self.estimator.fit(X, Y)

if y.ndim == 2 and y.shape[1] == 1:
y = y.flatten()

self.estimator.fit(X, y)
return self

def predict(self, X):
Expand Down
37 changes: 21 additions & 16 deletions autosklearn/pipeline/components/regression/ard_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def __init__(self, n_iter, tol, alpha_1, alpha_2, lambda_1, lambda_2,
self.threshold_lambda = threshold_lambda
self.fit_intercept = fit_intercept

def fit(self, X, Y):
import sklearn.linear_model
def fit(self, X, y):
from sklearn.linear_model import ARDRegression

self.n_iter = int(self.n_iter)
self.tol = float(self.tol)
Expand All @@ -34,20 +34,25 @@ def fit(self, X, Y):
self.threshold_lambda = float(self.threshold_lambda)
self.fit_intercept = check_for_bool(self.fit_intercept)

self.estimator = sklearn.linear_model.\
ARDRegression(n_iter=self.n_iter,
tol=self.tol,
alpha_1=self.alpha_1,
alpha_2=self.alpha_2,
lambda_1=self.lambda_1,
lambda_2=self.lambda_2,
compute_score=False,
threshold_lambda=self.threshold_lambda,
fit_intercept=True,
normalize=False,
copy_X=False,
verbose=False)
self.estimator.fit(X, Y)
self.estimator = ARDRegression(
n_iter=self.n_iter,
tol=self.tol,
alpha_1=self.alpha_1,
alpha_2=self.alpha_2,
lambda_1=self.lambda_1,
lambda_2=self.lambda_2,
compute_score=False,
threshold_lambda=self.threshold_lambda,
fit_intercept=True,
normalize=False,
copy_X=False,
verbose=False
)

if y.ndim == 2 and y.shape[1] == 1:
y = y.flatten()

self.estimator.fit(X, y)
return self

def predict(self, X):
Expand Down
4 changes: 4 additions & 0 deletions autosklearn/pipeline/components/regression/decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def fit(self, X, y, sample_weight=None):
min_weight_fraction_leaf=self.min_weight_fraction_leaf,
min_impurity_decrease=self.min_impurity_decrease,
random_state=self.random_state)

if y.ndim == 2 and y.shape[1] == 1:
y = y.flatten()

self.estimator.fit(X, y, sample_weight=sample_weight)
return self

Expand Down
5 changes: 4 additions & 1 deletion autosklearn/pipeline/components/regression/extra_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,10 @@ def iterative_fit(self, X, y, n_iter=1, refit=False):
self.estimator.n_estimators = min(self.estimator.n_estimators,
self.n_estimators)

self.estimator.fit(X, y,)
if y.ndim == 2 and y.shape[1] == 1:
y = y.flatten()

self.estimator.fit(X, y)

return self

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def __init__(self, alpha, thetaL, thetaU, random_state=None):
self.thetaU = thetaU
self.random_state = random_state
self.estimator = None
self.scaler = None

def fit(self, X, y):
import sklearn.gaussian_process
Expand All @@ -38,6 +37,9 @@ def fit(self, X, y):
normalize_y=True
)

if y.ndim == 2 and y.shape[1] == 1:
y = y.flatten()

self.estimator.fit(X, y)

return self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,7 @@ def get_current_iter(self):
return self.estimator.n_iter_

def iterative_fit(self, X, y, n_iter=2, refit=False):

"""
Set n_iter=2 for the same reason as for SGD
"""
""" Set n_iter=2 for the same reason as for SGD """
import sklearn.ensemble
from sklearn.experimental import enable_hist_gradient_boosting # noqa

Expand Down Expand Up @@ -112,6 +109,9 @@ def iterative_fit(self, X, y, n_iter=2, refit=False):
self.estimator.max_iter = min(self.estimator.max_iter,
self.max_iter)

if y.ndim == 2 and y.shape[1] == 1:
y = y.flatten()

self.estimator.fit(X, y)

if (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, n_neighbors, weights, p, random_state=None):
self.p = p
self.random_state = random_state

def fit(self, X, Y):
def fit(self, X, y):
import sklearn.neighbors

self.n_neighbors = int(self.n_neighbors)
Expand All @@ -24,7 +24,11 @@ def fit(self, X, Y):
n_neighbors=self.n_neighbors,
weights=self.weights,
p=self.p)
self.estimator.fit(X, Y)

if y.ndim == 2 and y.shape[1] == 1:
y = y.flatten()

self.estimator.fit(X, y)
return self

def predict(self, X):
Expand Down
8 changes: 6 additions & 2 deletions autosklearn/pipeline/components/regression/liblinear_svr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, loss, epsilon, dual, tol, C, fit_intercept,
self.random_state = random_state
self.estimator = None

def fit(self, X, Y):
def fit(self, X, y):
import sklearn.svm

self.C = float(self.C)
Expand All @@ -42,7 +42,11 @@ def fit(self, X, Y):
fit_intercept=self.fit_intercept,
intercept_scaling=self.intercept_scaling,
random_state=self.random_state)
self.estimator.fit(X, Y)

if y.ndim == 2 and y.shape[1] == 1:
y = y.flatten()

self.estimator.fit(X, y)
return self

def predict(self, X):
Expand Down
30 changes: 23 additions & 7 deletions autosklearn/pipeline/components/regression/libsvm_svr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from ConfigSpace.hyperparameters import UniformFloatHyperparameter, \
UniformIntegerHyperparameter, CategoricalHyperparameter, \
UnParametrizedHyperparameter

from autosklearn.pipeline.components.base import AutoSklearnRegressionAlgorithm
from autosklearn.pipeline.constants import DENSE, UNSIGNED_DATA, PREDICTIONS, SPARSE
from autosklearn.util.common import check_for_bool, check_none
Expand All @@ -29,7 +28,7 @@ def __init__(self, kernel, C, epsilon, tol, shrinking, gamma=0.1,
self.random_state = random_state
self.estimator = None

def fit(self, X, Y):
def fit(self, X, y):
import sklearn.svm

# Calculate the size of the kernel cache (in MB) for sklearn's LibSVM. The cache size is
Expand Down Expand Up @@ -88,18 +87,35 @@ def fit(self, X, Y):
)
self.scaler = sklearn.preprocessing.StandardScaler(copy=True)

self.scaler.fit(Y.reshape((-1, 1)))
Y_scaled = self.scaler.transform(Y.reshape((-1, 1))).ravel()
self.estimator.fit(X, Y_scaled)
# Convert y to be at least 2d for the scaler
# [1,1,1] -> [[1], [1], [1]]
if y.ndim == 1:
y = y.reshape((-1, 1))

y_scaled = self.scaler.fit_transform(y)

# Flatten: [[0], [0], [0]] -> [0, 0, 0]
if y_scaled.ndim == 2 and y_scaled.shape[1] == 1:
y_scaled = y_scaled.flatten()

self.estimator.fit(X, y_scaled)

return self

def predict(self, X):
if self.estimator is None:
raise NotImplementedError
if self.scaler is None:
raise NotImplementedError
Y_pred = self.estimator.predict(X)
return self.scaler.inverse_transform(Y_pred)
y_pred = self.estimator.predict(X)

inverse = self.scaler.inverse_transform(y_pred)

# Flatten: [[0], [0], [0]] -> [0, 0, 0]
if inverse.ndim == 2 and inverse.shape[1] == 1:
inverse = inverse.flatten()

return inverse

@staticmethod
def get_properties(dataset_properties=None):
Expand Down
42 changes: 35 additions & 7 deletions autosklearn/pipeline/components/regression/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,36 @@ def iterative_fit(self, X, y, n_iter=2, refit=False):
# max_fun=self.max_fun
)
self.scaler = sklearn.preprocessing.StandardScaler(copy=True)
self.scaler.fit(y.reshape((-1, 1)))

# Convert y to be at least 2d for the StandardScaler
# [1,1,1] -> [[1], [1], [1]]
if y.ndim == 1:
y = y.reshape((-1, 1))

self.scaler.fit(y)
else:
new_max_iter = min(self.max_iter - self.estimator.n_iter_, n_iter)
self.estimator.max_iter = new_max_iter

Y_scaled = self.scaler.transform(y.reshape((-1, 1))).ravel()
self.estimator.fit(X, Y_scaled)
if self.estimator.n_iter_ >= self.max_iter or \
self.estimator._no_improvement_count > self.n_iter_no_change:
# Convert y to be at least 2d for the scaler
# [1,1,1] -> [[1], [1], [1]]
if y.ndim == 1:
y = y.reshape((-1, 1))

y_scaled = self.scaler.transform(y)

# Flatten: [[0], [0], [0]] -> [0, 0, 0]
if y_scaled.ndim == 2 and y_scaled.shape[1] == 1:
y_scaled = y_scaled.flatten()

self.estimator.fit(X, y_scaled)

if (
self.estimator.n_iter_ >= self.max_iter
or self.estimator._no_improvement_count > self.n_iter_no_change
):
self._fully_fit = True

return self

def configuration_fully_fitted(self):
Expand All @@ -160,8 +180,16 @@ def configuration_fully_fitted(self):
def predict(self, X):
if self.estimator is None:
raise NotImplementedError
Y_pred = self.estimator.predict(X)
return self.scaler.inverse_transform(Y_pred)

y_pred = self.estimator.predict(X)

inverse = self.scaler.inverse_transform(y_pred)

# Flatten: [[0], [0], [0]] -> [0, 0, 0]
if inverse.ndim == 2 and inverse.shape[1] == 1:
inverse = inverse.flatten()

return inverse

@staticmethod
def get_properties(dataset_properties=None):
Expand Down
3 changes: 3 additions & 0 deletions autosklearn/pipeline/components/regression/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def iterative_fit(self, X, y, n_iter=1, refit=False):
self.estimator.n_estimators = min(self.estimator.n_estimators,
self.n_estimators)

if y.ndim == 2 and y.shape[1] == 1:
y = y.flatten()

self.estimator.fit(X, y)
return self

Expand Down
29 changes: 24 additions & 5 deletions autosklearn/pipeline/components/regression/sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,36 @@ def iterative_fit(self, X, y, n_iter=2, refit=False):
warm_start=True)

self.scaler = sklearn.preprocessing.StandardScaler(copy=True)
self.scaler.fit(y.reshape((-1, 1)))
Y_scaled = self.scaler.transform(y.reshape((-1, 1))).ravel()
self.estimator.fit(X, Y_scaled)

if y.ndim == 1:
y = y.reshape((-1, 1))

y_scaled = self.scaler.fit_transform(y)

# Flatten: [[0], [0], [0]] -> [0, 0, 0]
if y_scaled.ndim == 2 and y_scaled.shape[1] == 1:
y_scaled = y_scaled.flatten()

self.estimator.fit(X, y_scaled)
self.n_iter_ = self.estimator.n_iter_
else:
self.estimator.max_iter += n_iter
self.estimator.max_iter = min(self.estimator.max_iter, self.max_iter)
Y_scaled = self.scaler.transform(y.reshape((-1, 1))).ravel()

# Convert y to be at least 2d for the scaler
# [1,1,1] -> [[1], [1], [1]]
if y.ndim == 1:
y = y.reshape((-1, 1))

y_scaled = self.scaler.transform(y)

# Flatten: [[0], [0], [0]] -> [0, 0, 0]
if y_scaled.ndim == 2 and y_scaled.shape[1] == 1:
y_scaled = y_scaled.flatten()

self.estimator._validate_params()
self.estimator._partial_fit(
X, Y_scaled,
X, y_scaled,
alpha=self.estimator.alpha,
C=1.0,
loss=self.estimator.loss,
Expand Down
Loading