Skip to content

Commit 8101692

Browse files
committed
wip: multitask api
1 parent 4da9b41 commit 8101692

File tree

3 files changed

+26
-4
lines changed

3 files changed

+26
-4
lines changed

examples/multitask_example.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import numpy as np
2+
from sklearn.datasets import make_classification
3+
from stlearn import MultiTaskEstimator
4+
from sklearn.linear_model import (MultiTaskLassoCV,
5+
MultiTaskElasticNet,
6+
MultiTaskElasticNetCV,
7+
LassoCV, LogisticRegression)
8+
from sklearn.metrics import accuracy_score
9+
10+
X, y = make_classification(random_state=25)
11+
n = 10
12+
13+
Y = np.array(n*[y]).T
14+
15+
mt = MultiTaskEstimator(
16+
# estimator=MultiTaskLassoCV(alphas=np.logspace(-3, 3, 7)),
17+
estimator=MultiTaskElasticNetCV(),
18+
output_types=n*['binary'])
19+
20+
mt.fit(X, Y)
21+
print(mt.score(X, Y))
File renamed without changes.

stlearn/multitask.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from sklearn.base import BaseEstimator, TransformerMixin
1111
from sklearn.metrics import accuracy_score, r2_score
1212
from sklearn.externals.joblib import Memory, Parallel, delayed
13-
from sklearn.linear_model import MultiTaskLassoCV, MultiTaskElasticNetCV
13+
from sklearn.linear_model import (
14+
MultiTaskLassoCV, MultiTaskElasticNetCV, LogisticRegression)
1415

1516

1617
class MultiTaskEstimator(BaseEstimator, TransformerMixin):
@@ -44,12 +45,12 @@ def _decision_function(self, X):
4445

4546
def predict(self, X):
4647
# predict multiple outputs
47-
Ypred = self.estimator._decision_function(X)
48+
Ypred = self._decision_function(X)
4849
for i in range(self.n_outputs):
4950
if self.output_types[i] == 'binary':
50-
# binarize
51+
# binarize classification results
5152
labels = np.zeros(Ypred[:, i].shape)
52-
labels[Ypred[:, i] > 0.] = 1
53+
labels[Ypred[:, i] >= 0.5] = 1
5354
Ypred[:, i] = labels
5455
return Ypred
5556

0 commit comments

Comments
 (0)