Skip to content

Commit

Permalink
scikit_wrappers.py bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
zionsteiner authored Dec 6, 2020
1 parent 86c4987 commit 4fd18e3
Showing 1 changed file with 33 additions and 33 deletions.
66 changes: 33 additions & 33 deletions scikit_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,40 +172,40 @@ def fit_classifier(self, features, y):
if train_size // nb_classes < 5 or train_size < 50 or self.penalty is not None:
return self.classifier.fit(features, y)
else:
grid_search = sklearn.model_selection.GridSearchCV(
self.classifier, {
'C': [
0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 10000,
numpy.inf
],
'kernel': ['rbf'],
'degree': [3],
'gamma': ['scale'],
'coef0': [0],
'shrinking': [True],
'probability': [False],
'tol': [0.001],
'cache_size': [200],
'class_weight': [None],
'verbose': [False],
'max_iter': [10000000],
'decision_function_shape': ['ovr'],
'random_state': [None]
},
cv=5, iid=False, n_jobs=5
grid_search = sklearn.model_selection.GridSearchCV(
self.classifier, {
'C': [
0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 10000,
numpy.inf
],
'kernel': ['rbf'],
'degree': [3],
'gamma': ['scale'],
'coef0': [0],
'shrinking': [True],
'probability': [False],
'tol': [0.001],
'cache_size': [200],
'class_weight': [None],
'verbose': [False],
'max_iter': [10000000],
'decision_function_shape': ['ovr'],
'random_state': [None]
},
cv=5, iid=False, n_jobs=5
)
if train_size <= 10000:
grid_search.fit(features, y)
else:
# If the training set is too large, subsample 10000 train
# examples
split = sklearn.model_selection.train_test_split(
features, y,
train_size=10000, random_state=0, stratify=y
)
if train_size <= 10000:
grid_search.fit(features, y)
else:
# If the training set is too large, subsample 10000 train
# examples
split = sklearn.model_selection.train_test_split(
features, y,
train_size=10000, random_state=0, stratify=y
)
grid_search.fit(split[0], split[2])
self.classifier = grid_search.best_estimator_
return self.classifier
grid_search.fit(split[0], split[2])
self.classifier = grid_search.best_estimator_
return self.classifier

def fit_encoder(self, X, y=None, save_memory=False, verbose=False):
"""
Expand Down

0 comments on commit 4fd18e3

Please sign in to comment.