
support parameter 'class_weight' and method 'decision_function' in LinearSVC #5364

Merged on May 10, 2023 (20 commits)
Commits
a35f31b  add class_weight and decision_function to LinearSVC (mfoerste4, Apr 17, 2023)
7a5562a  review suggestions, added python test (mfoerste4, Apr 18, 2023)
90cc392  Merge branch 'branch-23.06' into linear_svm_features (mfoerste4, Apr 18, 2023)
5dc4bba  add test comments (mfoerste4, Apr 19, 2023)
3cea6ca  Merge branch 'linear_svm_features' of github.com:mfoerste4/cuml into … (mfoerste4, Apr 19, 2023)
bfce7c7  review suggestions (mfoerste4, Apr 20, 2023)
b390301  Merge branch 'rapidsai:branch-23.06' into linear_svm_features (mfoerste4, Apr 25, 2023)
1570b46  review suggestions (mfoerste4, Apr 26, 2023)
6d7ab43  Merge branch 'branch-23.06' into linear_svm_features (mfoerste4, Apr 26, 2023)
b900776  make ci happy (mfoerste4, Apr 26, 2023)
015828f  Merge branch 'linear_svm_features' of github.com:mfoerste4/cuml into … (mfoerste4, Apr 26, 2023)
21c901d  pre-commit 2nd pass (mfoerste4, Apr 26, 2023)
4318607  Merge branch 'branch-23.06' into linear_svm_features (tfeher, May 2, 2023)
265f82f  Merge branch 'branch-23.06' into linear_svm_features (tfeher, May 3, 2023)
d987b43  Merge branch 'branch-23.06' into linear_svm_features (mfoerste4, May 5, 2023)
30012f7  Merge branch 'branch-23.06' into linear_svm_features (mfoerste4, May 8, 2023)
d308c4e  Merge branch 'branch-23.06' into linear_svm_features (dantegd, May 8, 2023)
6f34263  Merge branch 'branch-23.06' into linear_svm_features (tfeher, May 9, 2023)
62516f1  Merge branch 'branch-23.06' into linear_svm_features (mfoerste4, May 9, 2023)
509b1c7  Merge branch 'branch-23.06' into linear_svm_features (tfeher, May 10, 2023)
21 changes: 20 additions & 1 deletion cpp/include/cuml/svm/linear.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -162,6 +162,25 @@ struct LinearSVMModel {
const std::size_t nCols,
T* out);

/**
* @brief Calculate decision function value for samples in input.
* @param [in] handle the cuML handle.
* @param [in] params the model parameters.
* @param [in] model the trained model.
* @param [in] X the input data matrix of size (nRows, nCols) in column-major format.
* @param [in] nRows number of vectors
* @param [in] nCols number of features
* @param [out] out the decision function value of size (nRows, n_classes <= 2 ? 1 : n_classes) in
* row-major format.
*/
static void decisionFunction(const raft::handle_t& handle,
const LinearSVMParams& params,
const LinearSVMModel<T>& model,
const T* X,
const std::size_t nRows,
const std::size_t nCols,
T* out);

/**
* @brief For SVC, predict the probabilities for each outcome.
*
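The value decisionFunction writes to `out` is the pre-thresholded linear score w·x + b. A minimal NumPy sketch of the shape convention documented above (`W` and `b` stand in for the weight and bias parts of `model.w`; this is an illustration, not code from this PR):

import numpy as np

def decision_function(X, W, b, n_classes):
    scores = X @ W + b          # (nRows, coefCols) raw scores
    if n_classes <= 2:
        return scores.ravel()   # binary case collapses to a single column
    return scores               # (nRows, n_classes) otherwise

X = np.random.rand(5, 3)
W = np.random.rand(3, 1)        # binary model: coefCols == 1
print(decision_function(X, W, 0.5, n_classes=2).shape)  # (5,)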
21 changes: 21 additions & 0 deletions cpp/src/svm/linear.cu
@@ -511,6 +511,27 @@ LinearSVMModel<T> LinearSVMModel<T>::fit(const raft::handle_t& handle,
return model;
}

template <typename T>
void LinearSVMModel<T>::decisionFunction(const raft::handle_t& handle,
const LinearSVMParams& params,
const LinearSVMModel<T>& model,
const T* X,
const std::size_t nRows,
const std::size_t nCols,
T* out)
{
ASSERT(!isRegression(params.loss), "Decision function is not available for the regression model");
predictLinear(handle,
X,
model.w,
nRows,
nCols,
model.coefCols(),
params.fit_intercept,
out,
handle.get_stream());
}

template <typename T>
void LinearSVMModel<T>::predict(const raft::handle_t& handle,
const LinearSVMParams& params,
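Since decisionFunction reuses predictLinear, the scores are the same quantity that predict thresholds. A hedged sketch of that relationship (the standard linear-SVM convention; the predict body is not shown verbatim in this diff):

import numpy as np

def predict_from_scores(scores, classes):
    # Binary: one score column, thresholded at zero.
    if scores.ndim == 1:
        return classes[(scores > 0).astype(int)]
    # Multiclass: one column per class, pick the largest score.
    return classes[np.argmax(scores, axis=1)]

classes = np.array([0, 1])
print(predict_from_scores(np.array([-1.2, 0.3, 2.0]), classes))  # [0 1 1]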
49 changes: 47 additions & 2 deletions python/cuml/svm/linear.pyx
@@ -1,4 +1,4 @@
- # Copyright (c) 2021-2022, NVIDIA CORPORATION.
+ # Copyright (c) 2021-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -115,6 +115,14 @@ cdef extern from "cuml/svm/linear.hpp" namespace "ML::SVM" nogil:
const T* X,
const size_t nRows, const size_t nCols, T* out) except +

@staticmethod
void decisionFunction(
const handle_t& handle,
const LinearSVMParams& params,
const LinearSVMModel[T]& model,
const T* X,
const size_t nRows, const size_t nCols, T* out) except +

@staticmethod
void predictProba(
const handle_t& handle,
@@ -244,7 +252,7 @@ cdef class LinearSVMWrapper:
cudaMemcpyAsync(
<void*><uintptr_t>target.ptr,
<void*><uintptr_t>source.ptr,
- <size_t>(source.nbytes),
+ <size_t>(source.size),
cudaMemcpyKind.cudaMemcpyDeviceToDevice,
stream.value())
if synchronize:
@@ -447,6 +455,35 @@ cdef class LinearSVMWrapper:
<const double*><uintptr_t>X.ptr,
X.shape[0], X.shape[1],
<double*><uintptr_t>y.ptr)
else:
raise TypeError('Input data type must be float32 or float64')

return y

def decision_function(self, X: CumlArray) -> CumlArray:
n_classes = self.classes_.shape[0]
# special handling of binary case
shape = (X.shape[0],) if n_classes <= 2 else (X.shape[0], n_classes)
y = CumlArray.empty(
shape=shape,
dtype=self.dtype, order='C')

if self.dtype == np.float32:
LinearSVMModel[float].decisionFunction(
deref(self.handle),
self.params,
self.model.float32,
<const float*><uintptr_t>X.ptr,
X.shape[0], X.shape[1],
<float*><uintptr_t>y.ptr)
elif self.dtype == np.float64:
LinearSVMModel[double].decisionFunction(
deref(self.handle),
self.params,
self.model.float64,
<const double*><uintptr_t>X.ptr,
X.shape[0], X.shape[1],
<double*><uintptr_t>y.ptr)
else:
raise TypeError('Input data type should be float32 or float64')

@@ -659,6 +696,14 @@ class LinearSVM(Base, metaclass=WithReexportedParams):
self.__sync_model()
return self.model_.predict(X_m)

def decision_function(self, X, convert_dtype=True) -> CumlArray:
convert_to_dtype = self.dtype if convert_dtype else None
X_m, n_rows, n_cols, _ = input_to_cuml_array(
X, check_dtype=self.dtype,
convert_to_dtype=convert_to_dtype)
self.__sync_model()
return self.model_.decision_function(X_m)

def predict_proba(self, X, log=False, convert_dtype=True) -> CumlArray:
convert_to_dtype = self.dtype if convert_dtype else None
X_m, n_rows, n_cols, _ = input_to_cuml_array(
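With the wrapper above in place, decision_function becomes available on the Python estimator. A usage sketch on synthetic data (shapes follow the binary special case handled above):

import cupy as cp
from cuml.svm import LinearSVC

X = cp.random.rand(100, 4, dtype=cp.float32)
y = (X[:, 0] > 0.5).astype(cp.float32)

clf = LinearSVC().fit(X, y)
scores = clf.decision_function(X)  # (100,) for binary; (n_samples, n_classes) otherwise
preds = clf.predict(X)             # for binary, corresponds to thresholding scores at 0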
19 changes: 19 additions & 0 deletions python/cuml/svm/linear_svc.py
@@ -15,6 +15,8 @@

from cuml.internals.mixins import ClassifierMixin
from cuml.svm.linear import LinearSVM, LinearSVM_defaults # noqa: F401
from cuml.svm.svc import apply_class_weight


__all__ = ["LinearSVC"]

@@ -76,6 +78,10 @@ class LinearSVC(LinearSVM, ClassifierMixin):
} (default = {LinearSVM_defaults.lbfgs_memory})
Number of vectors approximating the hessian for the underlying QN
solver (l-bfgs).
class_weight : dict or string (default=None)
Weights to modify the parameter C for class i to class_weight[i]*C. The
string 'balanced' is also accepted, in which case ``class_weight[i] =
n_samples / (n_classes * n_samples_of_class[i])``
verbose : int or boolean, default=False
Sets logging level. It must be one of `cuml.common.logger.level_*`.
See :ref:`verbosity-levels` for more info.
@@ -171,6 +177,7 @@ def get_param_names(self):
return list(
{
"handle",
"class_weight",
"verbose",
"penalty",
"loss",
@@ -186,3 +193,15 @@
"multi_class",
}.union(super().get_param_names())
)

def fit(self, X, y, sample_weight=None, convert_dtype=True) -> "LinearSVM":
sample_weight = apply_class_weight(
self.handle,
sample_weight,
self.class_weight,
y,
self.verbose,
self.output_type,
X.dtype,
)
return super(LinearSVC, self).fit(X, y, sample_weight, convert_dtype)
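The class_weight parameter documented above accepts either an explicit dict or the string 'balanced'. A small sketch of both forms, with the 'balanced' formula from the docstring worked out on a toy label vector (synthetic data, arithmetic only):

import numpy as np
from cuml.svm import LinearSVC

# Explicit weights: errors on class 1 are penalized 5x (effective C_i = class_weight[i] * C).
clf = LinearSVC(class_weight={0: 1.0, 1: 5.0})

# 'balanced' derives class_weight[i] = n_samples / (n_classes * n_samples_of_class[i]):
y = np.array([0, 0, 0, 1])                 # 3 samples of class 0, 1 sample of class 1
counts = np.bincount(y)                    # [3, 1]
weights = len(y) / (len(counts) * counts)  # [4/(2*3), 4/(2*1)] = [0.667, 2.0]
clf = LinearSVC(class_weight='balanced')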
136 changes: 79 additions & 57 deletions python/cuml/svm/svc.pyx
@@ -52,6 +52,7 @@ if has_sklearn():
from cuml.multiclass import MulticlassClassifier
from sklearn.calibration import CalibratedClassifierCV


cdef extern from "raft/distance/distance_types.hpp" \
namespace "raft::distance::kernels":
enum KernelType:
@@ -111,6 +112,83 @@ cdef extern from "cuml/svm/svc.hpp" namespace "ML::SVM" nogil:
math_t *preds, math_t buffer_size, bool predict_class) except +


def apply_class_weight(handle, sample_weight, class_weight, y, verbose, output_type, dtype) -> CumlArray:
"""
Scale the sample weights with the class weights.

Returns the modified sample weights, or None if neither class weights
nor sample weights are defined. The returned weights are defined as

sample_weight[i] = class_weight[y[i]] * sample_weight[i].

Parameters:
-----------
handle : cuml.Handle
Specifies the cuml.handle that holds internal CUDA state for
computations in this model.
sample_weight: array-like (device or host), shape = (n_samples, 1)
sample weights or None if not given
class_weight : dict or string (default=None)
Weights to modify the parameter C for class i to class_weight[i]*C. The
string 'balanced' is also accepted, in which case ``class_weight[i] =
n_samples / (n_classes * n_samples_of_class[i])``
y: array of floats or doubles, shape = (n_samples, 1)
verbose : int or boolean, default=False
Sets logging level. It must be one of `cuml.common.logger.level_*`.
See :ref:`verbosity-levels` for more info.
output_type : {{'input', 'array', 'dataframe', 'series', 'df_obj', \
'numba', 'cupy', 'numpy', 'cudf', 'pandas'}}, default=None
Return results and set estimator attributes to the indicated output
type. If None, the output type set at the module level
(`cuml.global_settings.output_type`) will be used. See
:ref:`output-data-type-configuration` for more info.
dtype : dtype for sample_weights

Returns
--------
sample_weight: device array shape = (n_samples, 1) or None
"""
if class_weight is None:
return sample_weight

if type(y) is CumlArray:
y_m = y
else:
y_m, _, _, _ = input_to_cuml_array(y, check_cols=1)

le = LabelEncoder(handle=handle,
verbose=verbose,
output_type=output_type)
labels = y_m.to_output(output_type='series')
encoded_labels = cp.asarray(le.fit_transform(labels))
n_samples = y_m.shape[0]

# Define class weights for the encoded labels
if class_weight == 'balanced':
counts = cp.asnumpy(cp.bincount(encoded_labels))
n_classes = len(counts)
weights = n_samples / (n_classes * counts)
class_weight = {i: weights[i] for i in range(n_classes)}
else:
keys = class_weight.keys()
keys_series = cudf.Series(keys)
encoded_keys = le.transform(cudf.Series(keys)).values_host
class_weight = {enc_key: class_weight[key]
for enc_key, key in zip(encoded_keys, keys)}

if sample_weight is None:
sample_weight = cp.ones(y_m.shape, dtype=dtype)
else:
sample_weight, _, _, _ = \
input_to_cupy_array(sample_weight, convert_to_dtype=dtype,
check_rows=n_samples, check_cols=1)

for label, weight in class_weight.items():
sample_weight[encoded_labels==label] *= weight

return sample_weight


class SVC(SVMBase,
ClassifierMixin):
"""
@@ -313,62 +391,6 @@ class SVC(SVMBase,
def intercept_(self, value):
self._intercept_ = value

@cuml.internals.api_base_return_array_skipall
def _apply_class_weight(self, sample_weight, y_m) -> CumlArray:
"""
Scale the sample weights with the class weights.

Returns the modified sample weights, or None if neither class weights
nor sample weights are defined. The returned weights are defined as

sample_weight[i] = class_weight[y[i]] * sample_weight[i].

Parameters:
-----------
sample_weight: array-like (device or host), shape = (n_samples, 1)
sample weights or None if not given
y_m: device array of floats or doubles, shape = (n_samples, 1)
Array of target labels already copied to the device.

Returns
--------
sample_weight: device array shape = (n_samples, 1) or None
"""
if self.class_weight is None:
return sample_weight

le = LabelEncoder(handle=self.handle,
verbose=self.verbose,
output_type=self.output_type)
labels = y_m.to_output(output_type='series')
encoded_labels = cp.asarray(le.fit_transform(labels))

# Define class weights for the encoded labels
if self.class_weight == 'balanced':
counts = cp.asnumpy(cp.bincount(encoded_labels))
n_classes = len(counts)
n_samples = y_m.shape[0]
weights = n_samples / (n_classes * counts)
class_weight = {i: weights[i] for i in range(n_classes)}
else:
keys = self.class_weight.keys()
keys_series = cudf.Series(keys)
encoded_keys = le.transform(cudf.Series(keys)).values_host
class_weight = {enc_key: self.class_weight[key]
for enc_key, key in zip(encoded_keys, keys)}

if sample_weight is None:
sample_weight = cp.ones(y_m.shape, dtype=self.dtype)
else:
sample_weight, _, _, _ = \
input_to_cupy_array(sample_weight, convert_to_dtype=self.dtype,
check_rows=self.n_rows, check_cols=1)

for label, weight in class_weight.items():
sample_weight[encoded_labels==label] *= weight

return sample_weight

def _get_num_classes(self, y):
"""
Determine the number of unique classes in y.
@@ -468,7 +490,7 @@ class SVC(SVMBase,

cdef uintptr_t y_ptr = y_m.ptr

- sample_weight = self._apply_class_weight(sample_weight, y_m)
+ sample_weight = apply_class_weight(self.handle, sample_weight, self.class_weight, y_m, self.verbose, self.output_type, self.dtype)
cdef uintptr_t sample_weight_ptr = <uintptr_t> nullptr
if sample_weight is not None:
sample_weight_m, _, _, _ = \
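apply_class_weight folds class weights into per-sample weights via sample_weight[i] = class_weight[y[i]] * sample_weight[i]. A CPU-side sketch of that logic; it assumes integer labels already encoded as 0..n_classes-1 and therefore skips the LabelEncoder step the real function performs:

import numpy as np

def scale_sample_weights(y, class_weight, sample_weight=None):
    if sample_weight is None:
        sample_weight = np.ones(len(y), dtype=np.float64)
    if class_weight == 'balanced':
        counts = np.bincount(y)
        class_weight = {i: len(y) / (len(counts) * c) for i, c in enumerate(counts)}
    for label, weight in class_weight.items():
        sample_weight[y == label] *= weight  # scale each sample by its class weight
    return sample_weight

print(scale_sample_weights(np.array([0, 0, 0, 1]), 'balanced'))
# [0.66666667 0.66666667 0.66666667 2.        ]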