Skip to content

Commit 9271d12

Browse files
authored
prediction (#69)
1 parent 3be1c18 commit 9271d12

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

report_generator/report_generator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def create_list(res_entry, props_list):
176176

177177
stages_splitter = {
178178
'training': ['training', 'computation'],
179-
'inference': ['prediction', 'transformation', 'search']
179+
'inference': ['prediction', 'transformation', 'search', 'predict_proba']
180180
}
181181

182182
for stage_key in stages_splitter.keys():

sklearn_bench/svm.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def main():
2424
from sklearn.svm import SVC
2525

2626
X_train, X_test, y_train, y_test = bench.load_data(params)
27+
y_train = np.asfortranarray(y_train).ravel()
2728

2829
if params.gamma is None:
2930
params.gamma = 1.0 / X_train.shape[1]
@@ -46,7 +47,7 @@ def main():
4647
def metric_call(x, y): return bench.log_loss(x, y)
4748
clf_predict = clf.predict_proba
4849
else:
49-
state_predict = 'predict'
50+
state_predict = 'prediction'
5051
accuracy_type = 'accuracy[%]'
5152
def metric_call(x, y): return bench.accuracy_score(x, y)
5253
clf_predict = clf.predict

0 commit comments

Comments
 (0)