Skip to content

Commit

Permalink
fix to a minor bug with intercept
Browse files Browse the repository at this point in the history
  • Loading branch information
dsullivan7 committed Oct 13, 2014
1 parent 20a5952 commit 04f6d12
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 20 deletions.
3 changes: 2 additions & 1 deletion sklearn/linear_model/stochastic_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,6 @@ def _partial_fit(self, X, y, alpha, C,
raise ValueError("The number of class labels must be "
"greater than one.")


return self

def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None,
Expand Down Expand Up @@ -452,6 +451,7 @@ def _fit_binary(self, X, y, alpha, C, sample_weight,
self.intercept_ = self.average_intercept_
else:
self.coef_ = self.standard_coef_.reshape(1, -1)
self.standard_intercept_ = np.atleast_1d(intercept)
self.intercept_ = self.standard_intercept_
else:
self.coef_ = coef.reshape(1, -1)
Expand Down Expand Up @@ -484,6 +484,7 @@ def _fit_multiclass(self, X, y, alpha, C, learning_rate,
self.intercept_ = self.average_intercept_
else:
self.coef_ = self.standard_coef_
self.standard_intercept_ = np.atleast_1d(intercept)
self.intercept_ = self.standard_intercept_

def partial_fit(self, X, y, classes=None, sample_weight=None):
Expand Down
31 changes: 12 additions & 19 deletions sklearn/linear_model/tests/test_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,25 +208,18 @@ def test_plain_has_no_average_attr(self):
assert_false(hasattr(clf, 'standard_coef_'))

def test_late_onset_averaging_not_reached(self):
eta0 = .001
clf1 = self.factory(average=12, learning_rate="constant",
eta0=eta0, n_iter=2)
clf2 = self.factory(learning_rate="constant", eta0=eta0, n_iter=2)

clf1.fit(X, Y)
clf2.fit(X, Y)

assert_array_almost_equal(clf1.coef_, clf2.coef_)
assert_almost_equal(clf1.intercept_, clf1.intercept_)

clf1 = self.factory(average=13, learning_rate="constant",
eta0=eta0, n_iter=2)
clf1.fit(X, Y)

assert_array_almost_equal(clf1.coef_, clf2.coef_,
decimal=16)
assert_almost_equal(clf1.intercept_, clf2.intercept_,
decimal=16)
clf1 = self.factory(average=600)
clf2 = self.factory()
for _ in range(100):
if isinstance(clf1, SGDClassifier):
clf1.partial_fit(X, Y, classes=np.unique(Y))
clf2.partial_fit(X, Y, classes=np.unique(Y))
else:
clf1.partial_fit(X, Y)
clf2.partial_fit(X, Y)

assert_array_almost_equal(clf1.coef_, clf2.coef_, decimal=10)
assert_almost_equal(clf1.intercept_, clf1.intercept_, decimal=10)

def test_late_onset_averaging_reached(self):
eta0 = .001
Expand Down

0 comments on commit 04f6d12

Please sign in to comment.