[FEA] Support multiple classes in multi-node-multi-gpu logistic regression, from C++, Cython, to Dask Python class #5565

Merged — 13 commits, Sep 29, 2023
12 changes: 12 additions & 0 deletions cpp/include/cuml/linear_model/qn_mg.hpp
@@ -21,12 +21,24 @@

#include <cumlprims/opg/matrix/data.hpp>
#include <cumlprims/opg/matrix/part_descriptor.hpp>
#include <vector>
using namespace MLCommon;

namespace ML {
namespace GLM {
namespace opg {

/**
* @brief Calculate unique class labels across multiple GPUs in a multi-node environment.
* @param[in] handle: the internal cuml handle object
* @param[in] input_desc: PartDescriptor object for the input
* @param[in] labels: labels data
* @returns a host vector that stores the distinct labels
*/
std::vector<float> getUniquelabelsMG(const raft::handle_t& handle,
Matrix::PartDescriptor& input_desc,
std::vector<Matrix::Data<float>*>& labels);

/**
* @brief performs MNMG fit operation for the logistic regression using quasi newton methods
* @param[in] handle: the internal cuml handle object
6 changes: 6 additions & 0 deletions cpp/src/glm/qn/mg/qn_mg.cuh
@@ -103,6 +103,12 @@ inline void qn_fit_x_mg(const raft::handle_t& handle,
ML::GLM::opg::qn_fit_mg<T, decltype(loss)>(
handle, pams, loss, X, y, Z, w0_data, f, num_iters, n_samples, rank, n_ranks);
} break;
case QN_LOSS_SOFTMAX: {
ASSERT(C > 2, "qn_mg.cuh: softmax invalid C");
ML::GLM::detail::Softmax<T> loss(handle, D, C, pams.fit_intercept);
ML::GLM::opg::qn_fit_mg<T, decltype(loss)>(
handle, pams, loss, X, y, Z, w0_data, f, num_iters, n_samples, rank, n_ranks);
} break;
default: {
ASSERT(false, "qn_mg.cuh: unknown loss function type (id = %d).", pams.loss);
}
66 changes: 55 additions & 11 deletions cpp/src/glm/qn_mg.cu
@@ -21,15 +21,59 @@
#include <cuml/linear_model/qn.h>
#include <cuml/linear_model/qn_mg.hpp>
#include <raft/core/comms.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/error.hpp>
#include <raft/core/handle.hpp>
#include <raft/label/classlabels.cuh>
#include <raft/util/cudart_utils.hpp>
#include <vector>
using namespace MLCommon;

namespace ML {
namespace GLM {
namespace opg {

template <typename T>
std::vector<T> distinct_mg(const raft::handle_t& handle, T* y, size_t n)
{
cudaStream_t stream = handle.get_stream();
raft::comms::comms_t const& comm = raft::resource::get_comms(handle);
int rank = comm.get_rank();
int n_ranks = comm.get_size();

rmm::device_uvector<T> unique_y(0, stream);
raft::label::getUniquelabels(unique_y, y, n, stream);

rmm::device_uvector<size_t> recv_counts(n_ranks, stream);
auto send_count = raft::make_device_scalar<size_t>(handle, unique_y.size());
comm.allgather(send_count.data_handle(), recv_counts.data(), 1, stream);
comm.sync_stream(stream);

std::vector<size_t> recv_counts_host(n_ranks);
raft::copy(recv_counts_host.data(), recv_counts.data(), n_ranks, stream);

std::vector<size_t> displs(n_ranks);
size_t pos = 0;
for (int i = 0; i < n_ranks; ++i) {
displs[i] = pos;
pos += recv_counts_host[i];
}

rmm::device_uvector<T> recv_buff(displs.back() + recv_counts_host.back(), stream);
comm.allgatherv(
unique_y.data(), recv_buff.data(), recv_counts_host.data(), displs.data(), stream);
comm.sync_stream(stream);

rmm::device_uvector<T> global_unique_y(0, stream);
int n_distinct =
raft::label::getUniquelabels(global_unique_y, recv_buff.data(), recv_buff.size(), stream);

std::vector<T> global_unique_y_host(global_unique_y.size());
raft::copy(global_unique_y_host.data(), global_unique_y.data(), global_unique_y.size(), stream);

return global_unique_y_host;
}

template <typename T>
void qnFit_impl(const raft::handle_t& handle,
const qn_params& pams,
@@ -46,17 +90,6 @@ void qnFit_impl(const raft::handle_t& handle,
int rank,
int n_ranks)
{
switch (pams.loss) {
case QN_LOSS_LOGISTIC: {
RAFT_EXPECTS(
C == 2,
"qn_mg.cu: only the LOGISTIC loss is supported currently. The number of classes must be 2");
} break;
default: {
RAFT_EXPECTS(false, "qn_mg.cu: unknown loss function type (id = %d).", pams.loss);
}
}

auto X_simple = SimpleDenseMat<T>(X, N, D, X_col_major ? COL_MAJOR : ROW_MAJOR);

ML::GLM::opg::qn_fit_x_mg(handle,
@@ -113,6 +146,17 @@ void qnFit_impl(raft::handle_t& handle,
input_desc.uniqueRanks().size());
}

std::vector<float> getUniquelabelsMG(const raft::handle_t& handle,
Matrix::PartDescriptor& input_desc,
std::vector<Matrix::Data<float>*>& labels)
{
RAFT_EXPECTS(labels.size() == 1,
"getUniqueLabelsMG currently does not accept more than one data chunk");
Matrix::Data<float>* data_y = labels[0];
int n_rows = input_desc.totalElementsOwnedBy(input_desc.rank);
return distinct_mg<float>(handle, data_y->ptr, n_rows);
}

void qnFit(raft::handle_t& handle,
std::vector<Matrix::Data<float>*>& input_data,
Matrix::PartDescriptor& input_desc,
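
The new `distinct_mg` helper above follows a four-step pattern: compute the per-rank unique labels, allgather the per-rank counts, allgatherv the label values themselves, and deduplicate the gathered buffer. Below is a minimal sketch that simulates those steps on plain NumPy arrays instead of device buffers and RAFT comms; `distinct_mg_simulated` and `per_rank_labels` are illustrative names, not part of the cuML API.

```python
import numpy as np

def distinct_mg_simulated(per_rank_labels):
    # Step 1: each rank computes its local unique labels (getUniquelabels).
    local_unique = [np.unique(y) for y in per_rank_labels]

    # Step 2: "allgather" of the local unique counts, so every rank knows
    # how many values each peer will contribute.
    recv_counts = [len(u) for u in local_unique]

    # Step 3: "allgatherv" of the local unique values; np.concatenate stands
    # in for placing each chunk at its displacement (prefix sum of counts).
    recv_buff = np.concatenate(local_unique)

    # Step 4: a final unique pass over the gathered buffer yields the global
    # label set, identical on every rank.
    return np.unique(recv_buff), recv_counts

per_rank_labels = [np.array([1.0, 0.0, 1.0]), np.array([2.0, 0.0])]
global_labels, counts = distinct_mg_simulated(per_rank_labels)
print(global_labels)  # [0. 1. 2.]
print(counts)         # [2, 2]
```
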
9 changes: 8 additions & 1 deletion python/cuml/dask/linear_model/logistic_regression.py
@@ -174,4 +174,11 @@ def _create_model(sessionId, datatype, **kwargs):
def _func_fit(f, data, n_rows, n_cols, partsToSizes, rank):
inp_X = concatenate([X for X, _ in data])
inp_y = concatenate([y for _, y in data])
return f.fit([(inp_X, inp_y)], n_rows, n_cols, partsToSizes, rank)
n_ranks = max([p[0] for p in partsToSizes]) + 1
aggregated_partsToSizes = [[i, 0] for i in range(n_ranks)]
for p in partsToSizes:
aggregated_partsToSizes[p[0]][1] += p[1]

return f.fit(
[(inp_X, inp_y)], n_rows, n_cols, aggregated_partsToSizes, rank
)
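
In `_func_fit`, `partsToSizes` arrives as one `(rank, size)` pair per partition, and the new code collapses it to a single `[rank, total_size]` entry per rank before calling the MG fit. A small worked example with made-up partition sizes:

```python
# Illustrative input: three partitions spread over two ranks.
partsToSizes = [(0, 100), (1, 80), (0, 50)]

n_ranks = max(p[0] for p in partsToSizes) + 1          # 2
aggregated_partsToSizes = [[i, 0] for i in range(n_ranks)]
for p in partsToSizes:
    aggregated_partsToSizes[p[0]][1] += p[1]

print(aggregated_partsToSizes)  # [[0, 150], [1, 80]]
```
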
28 changes: 23 additions & 5 deletions python/cuml/linear_model/logistic_regression_mg.pyx
@@ -79,11 +79,18 @@ cdef extern from "cuml/linear_model/qn_mg.hpp" namespace "ML::GLM::opg" nogil:
float *f,
int *num_iters) except +

cdef vector[float] getUniquelabelsMG(
const handle_t& handle,
PartDescriptor &input_desc,
vector[floatData_t*] labels) except+


class LogisticRegressionMG(MGFitMixin, LogisticRegression):

def __init__(self, **kwargs):
super(LogisticRegressionMG, self).__init__(**kwargs)
if self.penalty != "l2" and self.penalty != "none":
assert False, "Currently only support 'l2' and 'none' penalty"

@property
@cuml.internals.api_base_return_array_skipall
@@ -102,8 +109,8 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):

self.solver_model.coef_ = value

def prepare_for_fit(self, n_classes):
self.solver_model.qnparams = QNParams(
def create_qnparams(self):
return QNParams(
loss=self.loss,
penalty_l1=self.l1_strength,
penalty_l2=self.l2_strength,
Expand All @@ -118,8 +125,11 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
penalty_normalized=self.penalty_normalized
)

def prepare_for_fit(self, n_classes):
self.solver_model.qnparams = self.create_qnparams()

# modified
qnpams = self.qnparams.params
qnpams = self.solver_model.qnparams.params

# modified qnp
solves_classification = qnpams['loss'] in {
@@ -174,8 +184,14 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
cdef float objective32
cdef int num_iters

# TODO: calculate _num_classes at runtime
self._num_classes = 2
cdef vector[float] c_classes_
c_classes_ = getUniquelabelsMG(
handle_[0],
deref(<PartDescriptor*><uintptr_t>input_desc),
deref(<vector[floatData_t*]*><uintptr_t>y))
self.classes_ = np.sort(list(c_classes_)).astype('float32')

self._num_classes = len(self.classes_)
self.loss = "sigmoid" if self._num_classes <= 2 else "softmax"
self.prepare_for_fit(self._num_classes)
cdef uintptr_t mat_coef_ptr = self.coef_.ptr
Expand All @@ -194,6 +210,8 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
self._num_classes,
<float*> &objective32,
<int*> &num_iters)
else:
assert False, "dtypes other than float32 are currently not supported yet. See issue: https://github.com/rapidsai/cuml/issues/5589"

self.solver_model._calc_intercept()

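
On the Cython side, fit now derives `classes_` by sorting the labels gathered from all ranks and chooses the loss from the class count (sigmoid for two classes, softmax otherwise). A minimal sketch of that bookkeeping, using a hypothetical gathered label vector:

```python
import numpy as np

# Hypothetical labels returned across all ranks by getUniquelabelsMG.
c_classes_ = [50.0, 10.0, 20.0]

classes_ = np.sort(list(c_classes_)).astype("float32")  # [10., 20., 50.]
num_classes = len(classes_)                              # 3
loss = "sigmoid" if num_classes <= 2 else "softmax"      # "softmax"
```
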
102 changes: 97 additions & 5 deletions python/cuml/tests/dask/test_dask_logistic_regression.py
@@ -47,9 +47,13 @@ def _prep_training_data(c, X_train, y_train, partitions_per_worker):
return X_train_df, y_train_df


def make_classification_dataset(datatype, nrows, ncols, n_info):
def make_classification_dataset(datatype, nrows, ncols, n_info, n_classes=2):
X, y = make_classification(
n_samples=nrows, n_features=ncols, n_informative=n_info, random_state=0
n_samples=nrows,
n_features=ncols,
n_informative=n_info,
n_classes=n_classes,
random_state=0,
)
X = X.astype(datatype)
y = y.astype(datatype)
@@ -176,6 +180,16 @@ def imp():

assert_array_equal(preds, y, strict=True)

# assert error on float64
X = X.astype(np.float64)
y = y.astype(np.float64)
X_df, y_df = _prep_training_data(client, X, y, n_parts)
with pytest.raises(
RuntimeError,
match="dtypes other than float32 are currently not supported yet. See issue: https://github.com/rapidsai/cuml/issues/5589",
):
lr.fit(X_df, y_df)


def test_lbfgs_init(client):
def imp():
@@ -267,6 +281,7 @@ def test_lbfgs(
delayed,
client,
penalty="l2",
n_classes=2,
):
tolerance = 0.005

@@ -283,7 +298,9 @@ def imp():
n_info = 5
nrows = int(nrows)
ncols = int(ncols)
X, y = make_classification_dataset(datatype, nrows, ncols, n_info)
X, y = make_classification_dataset(
datatype, nrows, ncols, n_info, n_classes=n_classes
)

X_df, y_df = _prep_training_data(client, X, y, n_parts)

@@ -303,12 +320,13 @@ def imp():
assert lr_intercept == pytest.approx(sk_intercept, abs=tolerance)

# test predict
cu_preds = lr.predict(X_df, delayed=delayed)
accuracy_cuml = accuracy_score(y, cu_preds.compute().to_numpy())
cu_preds = lr.predict(X_df, delayed=delayed).compute().to_numpy()
accuracy_cuml = accuracy_score(y, cu_preds)

sk_preds = sk_model.predict(X)
accuracy_sk = accuracy_score(y, sk_preds)

assert len(cu_preds) == len(sk_preds)
assert (accuracy_cuml >= accuracy_sk) | (
np.abs(accuracy_cuml - accuracy_sk) < 1e-3
)
@@ -336,3 +354,77 @@ def test_noreg(fit_intercept, client):
l1_strength, l2_strength = lr._get_qn_params()
assert l1_strength == 0.0
assert l2_strength == 0.0


def test_n_classes_small(client):
def assert_small(X, y, n_classes):
X_df, y_df = _prep_training_data(client, X, y, partitions_per_worker=1)
from cuml.dask.linear_model import LogisticRegression as cumlLBFGS_dask

lr = cumlLBFGS_dask()
lr.fit(X_df, y_df)
assert lr._num_classes == n_classes
return lr

X = np.array([(1, 2), (1, 3)], np.float32)
y = np.array([1.0, 0.0], np.float32)
lr = assert_small(X=X, y=y, n_classes=2)
assert np.array_equal(
lr.classes_.to_numpy(), np.array([0.0, 1.0], np.float32)
)

X = np.array([(1, 2), (1, 3), (1, 2.5)], np.float32)
y = np.array([1.0, 0.0, 1.0], np.float32)
lr = assert_small(X=X, y=y, n_classes=2)
assert np.array_equal(
lr.classes_.to_numpy(), np.array([0.0, 1.0], np.float32)
)

X = np.array([(1, 2), (1, 2.5), (1, 3)], np.float32)
y = np.array([1.0, 1.0, 0.0], np.float32)
lr = assert_small(X=X, y=y, n_classes=2)
assert np.array_equal(
lr.classes_.to_numpy(), np.array([0.0, 1.0], np.float32)
)

X = np.array([(1, 2), (1, 3), (1, 2.5)], np.float32)
y = np.array([10.0, 50.0, 20.0], np.float32)
lr = assert_small(X=X, y=y, n_classes=3)
assert np.array_equal(
lr.classes_.to_numpy(), np.array([10.0, 20.0, 50.0], np.float32)
)


@pytest.mark.parametrize("n_parts", [2, 23])
@pytest.mark.parametrize("fit_intercept", [False, True])
@pytest.mark.parametrize("n_classes", [8])
def test_n_classes(n_parts, fit_intercept, n_classes, client):
lr = test_lbfgs(
nrows=1e5,
ncols=20,
n_parts=n_parts,
fit_intercept=fit_intercept,
datatype=np.float32,
delayed=True,
client=client,
penalty="l2",
n_classes=n_classes,
)

assert lr._num_classes == n_classes


@pytest.mark.parametrize("penalty", ["l1", "elasticnet"])
@pytest.mark.parametrize("l1_ratio", [0.1])
def test_l1_and_elasticnet(penalty, l1_ratio, client):
X = np.array([(1, 2), (1, 3), (2, 1), (3, 1)], np.float32)
y = np.array([1.0, 1.0, 0.0, 0.0], np.float32)
X_df, y_df = _prep_training_data(client, X, y, partitions_per_worker=1)

from cuml.dask.linear_model import LogisticRegression

lr = LogisticRegression(penalty=penalty, l1_ratio=l1_ratio)
with pytest.raises(
RuntimeError, match="Currently only support 'l2' and 'none' penalty"
):
lr.fit(X_df, y_df)
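
Putting the pieces together, here is a minimal multi-class fit sketch that reuses the test helpers defined above (`make_classification_dataset` and `_prep_training_data`); it assumes a running Dask `client` backed by GPU workers and is illustrative rather than part of the test suite:

```python
import numpy as np
from cuml.dask.linear_model import LogisticRegression

# 4-class dataset; float32 only, since other dtypes currently raise (see the
# float64 test above).
X, y = make_classification_dataset(np.float32, 10000, 20, 5, n_classes=4)
X_df, y_df = _prep_training_data(client, X, y, partitions_per_worker=1)

lr = LogisticRegression(penalty="l2")
lr.fit(X_df, y_df)

print(lr._num_classes)          # expected: 4
print(lr.classes_.to_numpy())   # sorted unique labels as float32
preds = lr.predict(X_df).compute().to_numpy()
```
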