Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions dask_ml/linear_model/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from ..metrics import r2_score
from ..utils import check_array
from .utils import lr_prob_stack

_base_doc = textwrap.dedent(
"""\
Expand Down Expand Up @@ -228,7 +229,7 @@ def decision_function(self, X):

Returns
-------
T : array-like, shape = [n_samples, n_classes]
T : array-like, shape = [n_samples,]
The confidence score of the sample for each class in the model.
"""
X_ = self._check_array(X)
Expand All @@ -246,7 +247,7 @@ def predict(self, X):
C : array, shape = [n_samples,]
Predicted class labels for each sample
"""
return self.predict_proba(X) > 0.5 # TODO: verify, multi_class broken
return self.predict_proba(X)[:, 1] > 0.5 # TODO: verify, multi_class broken

def predict_proba(self, X):
"""Probability estimates for samples in X.
Expand All @@ -260,7 +261,9 @@ def predict_proba(self, X):
T : array-like, shape = [n_samples, n_classes]
The probability of the sample for each class in the model.
"""
return sigmoid(self.decision_function(X))
# TODO: more work needed here to support multi_class
prob = sigmoid(self.decision_function(X))
return lr_prob_stack(prob)

def score(self, X, y):
"""The mean accuracy on the given data and labels
Expand Down
10 changes: 10 additions & 0 deletions dask_ml/linear_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,13 @@ def add_intercept(X): # noqa: F811
if "intercept" in columns:
raise ValueError("'intercept' column already in 'X'")
return X.assign(intercept=1)[["intercept"] + list(columns)]


@dispatch(np.ndarray) # noqa: F811
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can actually just use np.vstack now, since Dask implemented NEP 18.

def lr_prob_stack(prob):  # noqa: F811
    """Build a two-column probability array from positive-class scores.

    Parameters
    ----------
    prob : numpy.ndarray
        Probability of the positive class for each sample.

    Returns
    -------
    numpy.ndarray
        ``1 - prob`` and ``prob`` stacked as rows and then transposed; for a
        one-dimensional input of length ``n`` the result has shape ``(n, 2)``
        with the complement in column 0 and ``prob`` in column 1.
    """
    complement = 1 - prob
    stacked_rows = np.vstack((complement, prob))
    return np.transpose(stacked_rows)


@dispatch(da.Array)  # noqa: F811
def lr_prob_stack(prob):  # noqa: F811
    """Build a two-column probability array from positive-class scores.

    Dask-array overload: the stacking is done with ``da.vstack``.

    Parameters
    ----------
    prob : dask.array.Array
        Probability of the positive class for each sample.

    Returns
    -------
    dask.array.Array
        ``1 - prob`` and ``prob`` stacked as rows and then transposed, so the
        complement ends up in column 0 and ``prob`` in column 1.
    """
    complement = 1 - prob
    stacked_rows = da.vstack((complement, prob))
    return stacked_rows.T
8 changes: 8 additions & 0 deletions tests/linear_model/test_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,11 @@ def test_dataframe_warns_about_chunks(fit_intercept):
clf.fit(X.values, y.values)
clf.fit(X.to_dask_array(), y.to_dask_array())
clf.fit(X.to_dask_array(lengths=True), y.to_dask_array(lengths=True))


def test_logistic_predict_proba_shape():
    """``predict_proba`` yields one row per sample and two columns."""
    features, labels = make_classification(n_samples=100, n_features=5, chunks=50)
    estimator = LogisticRegression()
    estimator.fit(features, labels)
    probabilities = estimator.predict_proba(features)
    expected_shape = (100, 2)
    assert probabilities.shape == expected_shape