Skip to content

Commit 8592b6f

Browse files
authored
Support kernel='rbf' in SVM benchmarks (#16)
* Support both binary and multiclass SVM in one native benchmark * native SVM: determine sv_len in binary case * native SVM: Factor out accuracy_score function * Reduce long lines * Remove common_svm.hpp and two_class_svm_bench.cpp * Run svm_bench.cpp through clang-format * Update CLI11 to 1.8.0 * Native SVM: add support for RBF kernel * Update Makefiles * Run svm_bench.cpp through clang-format again * Add support for kernel='rbf' in sklearn/daal4py SVM benches * native SVM: Add C, tol, tau, maxiter command line arguments
1 parent 77fcc63 commit 8592b6f

File tree

10 files changed

+3610
-1610
lines changed

10 files changed

+3610
-1610
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ NATIVE_distances = distances
5353
NATIVE_ridge = ridge
5454
NATIVE_linear = linear
5555
NATIVE_kmeans = kmeans
56-
NATIVE_svm2 = two_class_svm
57-
NATIVE_svm5 = multi_class_svm
56+
NATIVE_svm2 = svm
57+
NATIVE_svm5 = svm
5858
NATIVE_logreg2 = log_reg_lbfgs
5959
NATIVE_logreg5 = log_reg_lbfgs
6060
NATIVE_dfclf2 = decision_forest_clsf

daal4py/svm.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from bench import parse_args, time_mean_min, print_header, print_row, \
77
size_str, accuracy_score
88
import numpy as np
9-
from daal4py import svm_training, svm_prediction, kernel_function_linear, \
9+
from daal4py import svm_training, svm_prediction, \
10+
kernel_function_linear, kernel_function_rbf, \
1011
multi_class_classifier_training, \
1112
multi_class_classifier_prediction
1213
from daal4py.sklearn.utils import getFPType
@@ -194,10 +195,19 @@ def construct_dual_coefs(model, num_classes, X, y):
194195
return support_
195196

196197

198+
def daal_kernel(name, fptype, gamma=1.0):
199+
200+
if name == 'linear':
201+
return kernel_function_linear(fptype=fptype)
202+
else:
203+
sigma = np.sqrt(0.5 / gamma)
204+
return kernel_function_rbf(fptype=fptype, sigma=sigma)
205+
206+
197207
def test_fit(X, y, params):
198208

199209
fptype = getFPType(X)
200-
kf = kernel_function_linear(fptype=fptype)
210+
kf = daal_kernel(params.kernel, fptype, gamma=params.gamma)
201211

202212
if params.n_classes == 2:
203213
y[y == 0] = -1
@@ -246,7 +256,7 @@ def test_fit(X, y, params):
246256
def test_predict(X, training_result, params):
247257

248258
fptype = getFPType(X)
249-
kf = kernel_function_linear(fptype=fptype)
259+
kf = daal_kernel(params.kernel, fptype, gamma=params.gamma)
250260

251261
svm_predict = svm_prediction(
252262
fptype=fptype,
@@ -287,8 +297,10 @@ def main():
287297
help='Input file with labels, in NPY format')
288298
parser.add_argument('-C', dest='C', type=float, default=0.01,
289299
help='SVM slack parameter')
290-
parser.add_argument('--kernel', choices=('linear',), default='linear',
291-
help='SVM kernel function')
300+
parser.add_argument('--kernel', choices=('linear', 'rbf'),
301+
default='linear', help='SVM kernel function')
302+
parser.add_argument('--gamma', type=float, default=None,
303+
help="Parameter for kernel='rbf'")
292304
parser.add_argument('--maxiter', type=int, default=2000,
293305
help='Maximum iterations for the iterative solver. '
294306
'-1 means no limit.')
@@ -301,12 +313,16 @@ def main():
301313
parser.add_argument('--no-shrinking', action='store_false', default=True,
302314
dest='shrinking',
303315
help="Don't use shrinking heuristic")
304-
params = parse_args(parser, loop_types=('fit', 'predict'), prefix='daal4py')
316+
params = parse_args(parser, loop_types=('fit', 'predict'),
317+
prefix='daal4py')
305318

306319
# Load data and cast to float64
307320
X_train = np.load(params.filex.name).astype('f8')
308321
y_train = np.load(params.filey.name).astype('f8')
309322

323+
if params.gamma is None:
324+
params.gamma = 1 / X_train.shape[1]
325+
310326
cache_size_bytes = get_optimal_cache_size(X_train.shape[0],
311327
max_cache=params.max_cache_size)
312328
params.cache_size_mb = cache_size_bytes / 2**20

native/Makefile

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
#
33
# SPDX-License-Identifier: MIT
44

5-
BENCHMARKS += distances kmeans linear ridge pca \
6-
two_class_svm multi_class_svm log_reg_lbfgs \
5+
BENCHMARKS += distances kmeans linear ridge pca svm log_reg_lbfgs \
76
decision_forest_regr decision_forest_clsf
87
FOBJ = $(addprefix lbfgsb/,lbfgsb.o linpack.o timer.o)
98
CXXSRCS = $(addsuffix _bench.cpp,$(BENCHMARKS))

native/common.hpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,61 @@ copy_submatrix(dm::NumericTablePtr src,
204204
}
205205

206206

207+
int count_classes(dm::NumericTablePtr y) {
208+
209+
/* compute min and max labels with DAAL */
210+
da::low_order_moments::Batch<double> algorithm;
211+
algorithm.input.set(da::low_order_moments::data, y);
212+
algorithm.compute();
213+
da::low_order_moments::ResultPtr res = algorithm.getResult();
214+
dm::NumericTablePtr min_nt = res->get(da::low_order_moments::minimum);
215+
dm::NumericTablePtr max_nt = res->get(da::low_order_moments::maximum);
216+
217+
int min, max;
218+
dm::BlockDescriptor<double> block;
219+
min_nt->getBlockOfRows(0, 1, dm::readOnly, block);
220+
min = block.getBlockPtr()[0];
221+
max_nt->getBlockOfRows(0, 1, dm::readOnly, block);
222+
max = block.getBlockPtr()[0];
223+
return 1 + max - min;
224+
225+
}
226+
227+
228+
size_t count_same_labels(dm::NumericTablePtr y1, dm::NumericTablePtr y2,
229+
double tol = 1e-6) {
230+
231+
size_t equal_counter = 0;
232+
size_t n_rows = std::min(y1->getNumberOfRows(), y2->getNumberOfRows());
233+
dm::BlockDescriptor<double> block_y1, block_y2;
234+
y1->getBlockOfRows(0, n_rows, dm::readOnly, block_y1);
235+
y2->getBlockOfRows(0, n_rows, dm::readOnly, block_y2);
236+
237+
double *ptr1 = block_y1.getBlockPtr();
238+
double *ptr2 = block_y2.getBlockPtr();
239+
240+
for (size_t i = 0; i < n_rows; i++) {
241+
if (abs(ptr1[i] - ptr2[i]) < tol) {
242+
equal_counter++;
243+
}
244+
}
245+
y1->releaseBlockOfRows(block_y1);
246+
y2->releaseBlockOfRows(block_y2);
247+
248+
return equal_counter;
249+
250+
}
251+
252+
253+
double accuracy_score(dm::NumericTablePtr y1, dm::NumericTablePtr y2,
254+
double tol = 1e-6) {
255+
256+
double n_rows = std::min(y1->getNumberOfRows(), y2->getNumberOfRows());
257+
return (double) count_same_labels(y1, y2) / n_rows;
258+
259+
}
260+
261+
207262
/*
208263
* Generate an array of random numbers.
209264
*/

native/common_svm.hpp

Lines changed: 0 additions & 27 deletions
This file was deleted.

0 commit comments

Comments
 (0)