Skip to content

Commit 78cb478

Browse files
committed
wip: multitask api
1 parent 2e28a48 commit 78cb478

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

examples/multitask_example.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,25 @@
44
from sklearn.linear_model import (MultiTaskLassoCV,
55
MultiTaskElasticNet,
66
MultiTaskElasticNetCV,
7+
ElasticNetCV,
78
LassoCV, LogisticRegression)
89
from sklearn.metrics import accuracy_score
10+
from sklearn.model_selection import ShuffleSplit
911

10-
X, y = make_classification(random_state=25)
12+
X, y = make_classification(n_samples=500, random_state=42)
1113
n = 10
12-
1314
Y = np.array(n*[y]).T
1415

1516
mt = MultiTaskEstimator(
16-
estimator=MultiTaskLassoCV(alphas=np.logspace(-3, 3, 7)),
17+
estimator=MultiTaskElasticNetCV(alphas=np.logspace(-3, 3, 7)),
1718
output_types=n//2*['binary']+n//2*['continuous'])
1819

19-
mt.fit(X, Y)
20-
print(mt.score(X, Y))
20+
ls = LogisticRegression()
21+
22+
ss = ShuffleSplit(n_splits=10, test_size=0.2, random_state=42)
23+
24+
for train, test in ss.split(X):
25+
mt.fit(X[train], Y[train])
26+
ls.fit(X[train], y[train])
27+
print(ls.score(X[test], y[test]))
28+
print(mt.score(X[test], Y[test]))

stlearn/multitask.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,26 @@
99
import numpy as np
1010
from sklearn.base import BaseEstimator, TransformerMixin
1111
from sklearn.metrics import accuracy_score, r2_score
12-
from sklearn.externals.joblib import Memory, Parallel, delayed
12+
from sklearn.externals.joblib import Parallel, delayed
1313
from sklearn.linear_model import (
1414
MultiTaskLassoCV, MultiTaskElasticNetCV, LogisticRegression)
1515

1616

1717
class MultiTaskEstimator(BaseEstimator, TransformerMixin):
1818
"""MultiTask estimator for multiple (continuous / discrete) outputs.
19+
20+
Parameters
21+
----------
22+
estimator : Multitask scikit-learn estimator, can be
23+
{"MultiTaskLasso", "MultiTaskLassoCV",
24+
"MultiTaskElasticNet", "MultiTaskElasticNetCV"}
25+
26+
output_types : shape = (n_outputs,) type of each output, can be
27+
{"binary", "continuous"}
1928
"""
2029

21-
def __init__(self, estimator=None,
22-
memory=Memory(cachedir=None), memory_level=0,
23-
n_jobs=1, output_types=None):
30+
def __init__(self, estimator=None, output_types=None):
2431
self.estimator = estimator
25-
self.memory = memory
26-
self.memory_level = memory_level
27-
self.n_jobs = n_jobs
2832
# check if output types are okay
2933
for output in output_types:
3034
if output not in ['binary', 'continuous']:

0 commit comments

Comments
 (0)