tests sample_weights linearreg, ridge
giorgiop committed Nov 10, 2015
1 parent 04ca448 commit 1d16ec4
Showing 2 changed files with 102 additions and 64 deletions.
90 changes: 55 additions & 35 deletions sklearn/linear_model/tests/test_base.py
@@ -5,12 +5,16 @@

 import numpy as np
 from scipy import sparse
+from scipy import linalg

 from sklearn.utils.testing import assert_array_almost_equal
+from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_equal

 from sklearn.linear_model.base import LinearRegression
-from sklearn.linear_model.base import center_data, sparse_center_data, _rescale_data
+from sklearn.linear_model.base import center_data
+from sklearn.linear_model.base import sparse_center_data
+from sklearn.linear_model.base import _rescale_data
 from sklearn.utils import check_random_state
 from sklearn.utils.testing import assert_greater
 from sklearn.datasets.samples_generator import make_sparse_uncorrelated
@@ -23,48 +27,64 @@ def test_linear_regression():
     X = [[1], [2]]
     Y = [1, 2]

-    clf = LinearRegression()
-    clf.fit(X, Y)
+    reg = LinearRegression()
+    reg.fit(X, Y)

-    assert_array_almost_equal(clf.coef_, [1])
-    assert_array_almost_equal(clf.intercept_, [0])
-    assert_array_almost_equal(clf.predict(X), [1, 2])
+    assert_array_almost_equal(reg.coef_, [1])
+    assert_array_almost_equal(reg.intercept_, [0])
+    assert_array_almost_equal(reg.predict(X), [1, 2])

     # test it also for degenerate input
     X = [[1]]
     Y = [0]

-    clf = LinearRegression()
-    clf.fit(X, Y)
-    assert_array_almost_equal(clf.coef_, [0])
-    assert_array_almost_equal(clf.intercept_, [0])
-    assert_array_almost_equal(clf.predict(X), [0])
+    reg = LinearRegression()
+    reg.fit(X, Y)
+    assert_array_almost_equal(reg.coef_, [0])
+    assert_array_almost_equal(reg.intercept_, [0])
+    assert_array_almost_equal(reg.predict(X), [0])


 def test_linear_regression_sample_weights():
+    # TODO: loop over sparse data as well
+
     rng = np.random.RandomState(0)

-    for n_samples, n_features in ((6, 5), (5, 10)):
+    # It would not work with under-determined systems
+    for n_samples, n_features in ((6, 5), ):

         y = rng.randn(n_samples)
         X = rng.randn(n_samples, n_features)
         sample_weight = 1.0 + rng.rand(n_samples)

-        clf = LinearRegression()
-        clf.fit(X, y, sample_weight)
-        coefs1 = clf.coef_
+        for intercept in (True, False):

-        assert_equal(clf.coef_.shape, (X.shape[1], ))
-        assert_greater(clf.score(X, y), 0.9)
-        assert_array_almost_equal(clf.predict(X), y)
+            # LinearRegression with explicit sample_weight
+            reg = LinearRegression(fit_intercept=intercept)
+            reg.fit(X, y, sample_weight=sample_weight)
+            coefs1 = reg.coef_
+            inter1 = reg.intercept_

-        # Sample weight can be implemented via a simple rescaling
-        # for the square loss.
-        scaled_y = y * np.sqrt(sample_weight)
-        scaled_X = X * np.sqrt(sample_weight)[:, np.newaxis]
-        clf.fit(X, y)
-        coefs2 = clf.coef_
+            assert_equal(reg.coef_.shape, (X.shape[1], ))  # sanity checks
+            assert_greater(reg.score(X, y), 0.5)

-        assert_array_almost_equal(coefs1, coefs2)
+            # Closed form of the weighted least square
+            # theta = (X^T W X)^(-1) * X^T W y
+            W = np.diag(sample_weight)
+            if intercept is False:
+                X_aug = X
+            else:
+                dummy_column = np.ones(shape=(n_samples, 1))
+                X_aug = np.concatenate((dummy_column, X), axis=1)
+
+            coefs2 = linalg.solve(X_aug.T.dot(W).dot(X_aug),
+                                  X_aug.T.dot(W).dot(y))
+
+            if intercept is False:
+                assert_array_almost_equal(coefs1, coefs2)
+            else:
+                assert_array_almost_equal(coefs1, coefs2[1:])
+                assert_almost_equal(inter1, coefs2[0])


 def test_raises_value_error_if_sample_weights_greater_than_1d():
@@ -82,12 +102,12 @@ def test_raises_value_error_if_sample_weights_greater_than_1d():
     sample_weights_OK_1 = 1.
     sample_weights_OK_2 = 2.

-    clf = LinearRegression()
+    reg = LinearRegression()

     # make sure the "OK" sample weights actually work
-    clf.fit(X, y, sample_weights_OK)
-    clf.fit(X, y, sample_weights_OK_1)
-    clf.fit(X, y, sample_weights_OK_2)
+    reg.fit(X, y, sample_weights_OK)
+    reg.fit(X, y, sample_weights_OK_1)
+    reg.fit(X, y, sample_weights_OK_2)


 def test_fit_intercept():
@@ -135,12 +155,12 @@ def test_linear_regression_multiple_outcome(random_state=0):
     Y = np.vstack((y, y)).T
     n_features = X.shape[1]

-    clf = LinearRegression(fit_intercept=True)
-    clf.fit((X), Y)
-    assert_equal(clf.coef_.shape, (2, n_features))
-    Y_pred = clf.predict(X)
-    clf.fit(X, y)
-    y_pred = clf.predict(X)
+    reg = LinearRegression(fit_intercept=True)
+    reg.fit((X), Y)
+    assert_equal(reg.coef_.shape, (2, n_features))
+    Y_pred = reg.predict(X)
+    reg.fit(X, y)
+    y_pred = reg.predict(X)
     assert_array_almost_equal(np.vstack((y_pred, y_pred)).T, Y_pred, decimal=3)


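Note on the test_base.py change above: the new test validates LinearRegression's sample_weight against the closed form of weighted least squares, theta = (X^T W X)^(-1) X^T W y, augmenting X with an all-ones column when an intercept is fit. The same algebra explains the rescaling comment in the old code: minimizing sum_i w_i (y_i - x_i.theta)^2 is ordinary least squares on X and y rescaled by sqrt(w). A minimal standalone sketch of that equivalence (all variable names here are illustrative, not taken from the diff):

import numpy as np
from scipy import linalg

rng = np.random.RandomState(0)
n_samples, n_features = 6, 5  # overdetermined, as in the test
X = rng.randn(n_samples, n_features)
y = rng.randn(n_samples)
w = 1.0 + rng.rand(n_samples)  # strictly positive sample weights

# Closed form: theta = (X^T W X)^(-1) X^T W y
W = np.diag(w)
theta_closed = linalg.solve(X.T.dot(W).dot(X), X.T.dot(W).dot(y))

# Ordinary least squares on sqrt(w)-rescaled data gives the same theta
sqrt_w = np.sqrt(w)
theta_scaled = linalg.lstsq(X * sqrt_w[:, np.newaxis], y * sqrt_w)[0]

np.testing.assert_array_almost_equal(theta_closed, theta_scaled)

This is also why the loop now keeps n_samples > n_features ("It would not work with under-determined systems"): when n_features > n_samples, X^T W X is singular and the closed-form comparison is ill-posed.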
76 changes: 47 additions & 29 deletions sklearn/linear_model/tests/test_ridge.py
@@ -1,6 +1,7 @@
 import numpy as np
 import scipy.sparse as sp
 from scipy import linalg
+from itertools import product

 from sklearn.utils.testing import assert_true
 from sklearn.utils.testing import assert_almost_equal
@@ -112,20 +113,21 @@ def test_ridge_singular():
     assert_greater(ridge.score(X, y), 0.9)


-def test_ridge_sample_weights():
+def test_ridge_regression_sample_weights():
     rng = np.random.RandomState(0)

     for solver in ("cholesky", ):
         for n_samples, n_features in ((6, 5), (5, 10)):
             for alpha in (1.0, 1e-2):
                 y = rng.randn(n_samples)
                 X = rng.randn(n_samples, n_features)
-                sample_weight = 1 + rng.rand(n_samples)
+                sample_weight = 1.0 + rng.rand(n_samples)

                 coefs = ridge_regression(X, y,
                                          alpha=alpha,
                                          sample_weight=sample_weight,
                                          solver=solver)
+
                 # Sample weight can be implemented via a simple rescaling
                 # for the square loss.
                 coefs2 = ridge_regression(
@@ -134,32 +136,48 @@ def test_ridge_sample_weights():
                     alpha=alpha, solver=solver)
                 assert_array_almost_equal(coefs, coefs2)

-                # Test for fit_intercept = True
-                est = Ridge(alpha=alpha, solver=solver)
-                est.fit(X, y, sample_weight=sample_weight)
-
-                # Check using Newton's Method
-                # Quadratic function should be solved in a single step.
-                # Initialize
-                sample_weight = np.sqrt(sample_weight)
-                X_weighted = sample_weight[:, np.newaxis] * (
-                    np.column_stack((np.ones(n_samples), X)))
-                y_weighted = y * sample_weight
-
-                # Gradient is (X*coef-y)*X + alpha*coef_[1:]
-                # Remove coef since it is initialized to zero.
-                grad = -np.dot(y_weighted, X_weighted)
-
-                # Hessian is (X.T*X) + alpha*I except that the first
-                # diagonal element should be zero, since there is no
-                # penalization of intercept.
-                diag = alpha * np.ones(n_features + 1)
-                diag[0] = 0.
-                hess = np.dot(X_weighted.T, X_weighted)
-                hess.flat[::n_features + 2] += diag
-                coef_ = - np.dot(linalg.inv(hess), grad)
-                assert_almost_equal(coef_[0], est.intercept_)
-                assert_array_almost_equal(coef_[1:], est.coef_)
+
+def test_ridge_sample_weights():
+    # TODO: loop over sparse data as well
+
+    rng = np.random.RandomState(0)
+    param_grid = product((1.0, 1e-2), (True, False),
+                         ('svd', 'cholesky', 'lsqr', 'sparse_cg'))
+
+    for n_samples, n_features in ((6, 5), (5, 10)):
+
+        y = rng.randn(n_samples)
+        X = rng.randn(n_samples, n_features)
+        sample_weight = 1.0 + rng.rand(n_samples)
+
+        for (alpha, intercept, solver) in param_grid:
+
+            # Ridge with explicit sample_weight
+            est = Ridge(alpha=alpha, fit_intercept=intercept, solver=solver)
+            est.fit(X, y, sample_weight=sample_weight)
+            coefs = est.coef_
+            inter = est.intercept_
+
+            # Closed form of the weighted regularized least square
+            # theta = (X^T W X + alpha I)^(-1) * X^T W y
+            W = np.diag(sample_weight)
+            if intercept is False:
+                X_aug = X
+                I = np.eye(n_features)
+            else:
+                dummy_column = np.ones(shape=(n_samples, 1))
+                X_aug = np.concatenate((dummy_column, X), axis=1)
+                I = np.eye(n_features + 1)
+                I[0, 0] = 0
+
+            cf_coefs = linalg.solve(X_aug.T.dot(W).dot(X_aug) + alpha * I,
+                                    X_aug.T.dot(W).dot(y))
+
+            if intercept is False:
+                assert_array_almost_equal(coefs, cf_coefs)
+            else:
+                assert_array_almost_equal(coefs, cf_coefs[1:])
+                assert_almost_equal(inter, cf_coefs[0])


 def test_ridge_shapes():
@@ -570,7 +588,7 @@ def test_ridgecv_sample_weight():
     for n_samples, n_features in ((6, 5), (5, 10)):
         y = rng.randn(n_samples)
         X = rng.randn(n_samples, n_features)
-        sample_weight = 1 + rng.rand(n_samples)
+        sample_weight = 1.0 + rng.rand(n_samples)

         cv = KFold(5)
         ridgecv = RidgeCV(alphas=alphas, cv=cv)
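Note on the test_ridge.py change above: the rewritten test drops the old one-step Newton check (for a quadratic objective, a single Newton step from a zero initialization, -H^(-1) grad, already lands on the exact minimizer, so both checks agree) in favor of the explicit closed form theta = (X^T W X + alpha*I)^(-1) X^T W y, with the first diagonal entry of I zeroed so the intercept column is not penalized. A standalone sketch of that reference computation, checked against Ridge itself (variable names are illustrative; the equality is exactly what the new test asserts):

import numpy as np
from scipy import linalg
from sklearn.linear_model import Ridge

rng = np.random.RandomState(0)
n_samples, n_features, alpha = 6, 5, 1.0
X = rng.randn(n_samples, n_features)
y = rng.randn(n_samples)
w = 1.0 + rng.rand(n_samples)

# Augment X with a dummy all-ones column; its coefficient is the intercept.
X_aug = np.hstack([np.ones((n_samples, 1)), X])
W = np.diag(w)
I = np.eye(n_features + 1)
I[0, 0] = 0.0  # no penalty on the intercept

theta = linalg.solve(X_aug.T.dot(W).dot(X_aug) + alpha * I,
                     X_aug.T.dot(W).dot(y))

est = Ridge(alpha=alpha, fit_intercept=True, solver='cholesky')
est.fit(X, y, sample_weight=w)
np.testing.assert_almost_equal(est.intercept_, theta[0])
np.testing.assert_array_almost_equal(est.coef_, theta[1:])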
