Commit 5ab8ee6

[7.7][ML] Multinomial logistic regression (#1053)
Backport #1037.
1 parent 8e43ba6 commit 5ab8ee6

21 files changed: +1633 −551 lines changed

docs/CHANGELOG.asciidoc

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@
 
 * Add instrumentation to report statistics related to data frame analytics jobs, i.e.
   progress, memory usage, etc. (See {ml-pull}906[#906].)
+* Multiclass classification. (See {ml-pull}1037[#1037].)
 
 === Enhancements

include/maths/CBoostedTreeLoss.h

Lines changed: 131 additions & 11 deletions
@@ -8,8 +8,10 @@
 #define INCLUDED_ml_maths_CBoostedTreeLoss_h
 
 #include <maths/CBasicStatistics.h>
+#include <maths/CKMeansOnline.h>
 #include <maths/CLinearAlgebra.h>
 #include <maths/CLinearAlgebraEigen.h>
+#include <maths/CPRNG.h>
 #include <maths/ImportExport.h>
 #include <maths/MathsTypes.h>

@@ -66,9 +68,26 @@ class MATHS_EXPORT CArgMinMseImpl final : public CArgMinLossImpl {
 
 //! \brief Finds the value to add to a set of predicted log-odds which minimises
 //! regularised cross entropy loss w.r.t. the actual categories.
-class MATHS_EXPORT CArgMinLogisticImpl final : public CArgMinLossImpl {
+//!
+//! DESCRIPTION:\n
+//! We want to find the weight which minimizes the log-loss, i.e. which satisfies
+//! <pre class="fragment">
+//! \f$\displaystyle arg\min_w{ \lambda w^2 -\sum_i{ a_i \log(S(p_i + w)) + (1 - a_i) \log(1 - S(p_i + w)) } }\f$
+//! </pre>
+//!
+//! Rather than working with this function directly we bucket the predictions `p_i`
+//! in a first pass over the data and compute the weight which minimizes the
+//! approximate function
+//! <pre class="fragment">
+//! \f$\displaystyle arg\min_w{ \lambda w^2 -\sum_{B}{ c_{1,B} \log(S(\bar{p}_B + w)) + c_{0,B} \log(1 - S(\bar{p}_B + w)) } }\f$
+//! </pre>
+//!
+//! Here, \f$B\f$ ranges over the buckets, \f$\bar{p}_B\f$ denotes the B'th bucket
+//! centre and \f$c_{0,B}\f$ and \f$c_{1,B}\f$ denote the counts of actual classes
+//! 0 and 1, respectively, in the bucket \f$B\f$.
+class MATHS_EXPORT CArgMinBinomialLogisticLossImpl final : public CArgMinLossImpl {
 public:
-    CArgMinLogisticImpl(double lambda);
+    CArgMinBinomialLogisticLossImpl(double lambda);
     std::unique_ptr<CArgMinLossImpl> clone() const override;
     bool nextPass() override;
     void add(const TMemoryMappedFloatVector& prediction, double actual, double weight = 1.0) override;
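
The bucketing scheme in this doc comment turns a minimisation over every prediction into one over a fixed number of bucket statistics. Below is a minimal standalone sketch of the approximate objective it describes, assuming a fixed array of 128 buckets to mirror NUMBER_BUCKETS; SBucket, logistic and bucketedObjective are illustrative names, not the library's:

#include <array>
#include <cmath>

// Illustrative sketch only: each bucket stores its centre \bar{p}_B and the
// counts c_{0,B} and c_{1,B} of actual classes 0 and 1 falling in it.
struct SBucket {
    double centre = 0.0;
    double count0 = 0.0;
    double count1 = 0.0;
};

double logistic(double x) {
    return 1.0 / (1.0 + std::exp(-x));
}

// Evaluates lambda * w^2 - sum_B [ c_{1,B} log(S(p_B + w)) + c_{0,B} log(1 - S(p_B + w)) ],
// the bucketed regularised log-loss for a candidate weight w.
double bucketedObjective(double lambda, double w, const std::array<SBucket, 128>& buckets) {
    double loss{lambda * w * w};
    for (const auto& b : buckets) {
        double p{logistic(b.centre + w)};
        loss -= b.count1 * std::log(p) + b.count0 * std::log(1.0 - p);
    }
    return loss;
}

Minimising this one-dimensional function needs only the bucket statistics gathered in the first pass over the data, which is what keeps the second pass cheap.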
@@ -80,11 +99,13 @@ class MATHS_EXPORT CArgMinLogisticImpl final : public CArgMinLossImpl {
     using TDoubleVector2x1 = CVectorNx1<double, 2>;
     using TDoubleVector2x1Vec = std::vector<TDoubleVector2x1>;
 
+private:
+    static constexpr std::size_t NUMBER_BUCKETS = 128;
+
 private:
     std::size_t bucket(double prediction) const {
         double bucket{(prediction - m_PredictionMinMax.min()) / this->bucketWidth()};
-        return std::min(static_cast<std::size_t>(bucket),
-                        m_BucketCategoryCounts.size() - 1);
+        return std::min(static_cast<std::size_t>(bucket), m_BucketsClassCounts.size() - 1);
     }
 
     double bucketCentre(std::size_t bucket) const {
@@ -95,15 +116,74 @@ class MATHS_EXPORT CArgMinLogisticImpl final : public CArgMinLossImpl {
     double bucketWidth() const {
         return m_PredictionMinMax.initialized()
                    ? m_PredictionMinMax.range() /
-                         static_cast<double>(m_BucketCategoryCounts.size())
+                         static_cast<double>(m_BucketsClassCounts.size())
                    : 0.0;
     }
 
 private:
     std::size_t m_CurrentPass = 0;
     TMinMaxAccumulator m_PredictionMinMax;
-    TDoubleVector2x1 m_CategoryCounts;
-    TDoubleVector2x1Vec m_BucketCategoryCounts;
+    TDoubleVector2x1 m_ClassCounts;
+    TDoubleVector2x1Vec m_BucketsClassCounts;
+};
+
+//! \brief Finds the value to add to a set of predicted multinomial logits which
+//! minimises regularised cross entropy loss w.r.t. the actual classes.
+//!
+//! DESCRIPTION:\n
+//! We want to find the weight which minimizes the log-loss, i.e. which satisfies
+//! <pre class="fragment">
+//! \f$\displaystyle arg\min_w{ \lambda \|w\|^2 -\sum_i{ \log([softmax(p_i + w)]_{a_i}) } }\f$
+//! </pre>
+//!
+//! Here, \f$a_i\f$ is the index of the i'th example's true class. Rather than
+//! working with this function directly we approximate it by the means and counts
+//! of predictions in a partition of the original data, i.e. we compute the
+//! weight which satisfies
+//! <pre class="fragment">
+//! \f$\displaystyle arg\min_w{ \lambda \|w\|^2 -\sum_P{ \sum_j{ c_{j,P} \log([softmax(\bar{p}_P + w)]_j) } } }\f$
+//! </pre>
+//!
+//! Here, \f$P\f$ ranges over the subsets of the partition, \f$\bar{p}_P\f$ denotes
+//! the mean of the predictions in the P'th subset and \f$c_{j,P}\f$ denotes the
+//! count of class \f$j\f$ in the subset \f$P\f$. We compute this partition by
+//! k-means.
+class MATHS_EXPORT CArgMinMultinomialLogisticLossImpl final : public CArgMinLossImpl {
+public:
+    using TObjective = std::function<double(const TDoubleVector&)>;
+    using TObjectiveGradient = std::function<TDoubleVector(const TDoubleVector&)>;
+
+public:
+    CArgMinMultinomialLogisticLossImpl(std::size_t numberClasses,
+                                       double lambda,
+                                       const CPRNG::CXorOShiro128Plus& rng);
+    std::unique_ptr<CArgMinLossImpl> clone() const override;
+    bool nextPass() override;
+    void add(const TMemoryMappedFloatVector& prediction, double actual, double weight = 1.0) override;
+    void merge(const CArgMinLossImpl& other) override;
+    TDoubleVector value() const override;
+
+    // Exposed for unit testing.
+    TObjective objective() const;
+    TObjectiveGradient objectiveGradient() const;
+
+private:
+    using TDoubleVectorVec = std::vector<TDoubleVector>;
+    using TKMeans = CKMeansOnline<TDoubleVector>;
+
+private:
+    static constexpr std::size_t NUMBER_CENTRES = 128;
+    static constexpr std::size_t NUMBER_RESTARTS = 5;
+
+private:
+    std::size_t m_NumberClasses = 0;
+    std::size_t m_CurrentPass = 0;
+    mutable CPRNG::CXorOShiro128Plus m_Rng;
+    TDoubleVector m_ClassCounts;
+    TDoubleVector m_DoublePrediction;
+    TKMeans m_PredictionSketch;
+    TDoubleVectorVec m_Centres;
+    TDoubleVectorVec m_CentresClassCounts;
 };
 }
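
Unlike the binomial case there is no cheap one-dimensional search here, which is presumably why the class exposes objective() and objectiveGradient(). A sketch of the k-means approximated objective from the doc comment, assuming plain std::vector<double> stands in for TDoubleVector; logSoftmax and objective are illustrative helpers, not the library's code:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

using TDoubleVec = std::vector<double>;
using TDoubleVecVec = std::vector<TDoubleVec>;

// Numerically stable log-softmax:
// log([softmax(z)]_j) = z_j - max(z) - log(sum_k exp(z_k - max(z))).
TDoubleVec logSoftmax(TDoubleVec z) {
    double zmax{*std::max_element(z.begin(), z.end())};
    double logZ{0.0};
    for (double zj : z) {
        logZ += std::exp(zj - zmax);
    }
    logZ = std::log(logZ);
    for (double& zj : z) {
        zj -= zmax + logZ;
    }
    return z;
}

// lambda * ||w||^2 - sum_P sum_j c_{j,P} log([softmax(p_P + w)]_j), evaluated
// over the k-means centres and their per-class counts.
double objective(double lambda,
                 const TDoubleVec& w,
                 const TDoubleVecVec& centres,
                 const TDoubleVecVec& centresClassCounts) {
    double loss{0.0};
    for (double wj : w) {
        loss += lambda * wj * wj;
    }
    for (std::size_t i = 0; i < centres.size(); ++i) {
        TDoubleVec shifted = centres[i];
        for (std::size_t j = 0; j < w.size(); ++j) {
            shifted[j] += w[j];
        }
        TDoubleVec logP = logSoftmax(std::move(shifted));
        for (std::size_t j = 0; j < logP.size(); ++j) {
            loss -= centresClassCounts[i][j] * logP[j];
        }
    }
    return loss;
}

The cost of one evaluation is proportional to NUMBER_CENTRES rather than to the number of training rows, which is the point of sketching the predictions with online k-means first.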

@@ -185,7 +265,8 @@ class MATHS_EXPORT CLoss {
     //! Transforms a prediction from the forest to the target space.
     virtual TDoubleVector transform(const TMemoryMappedFloatVector& prediction) const = 0;
     //! Get an object which computes the leaf value that minimises loss.
-    virtual CArgMinLoss minimizer(double lambda) const = 0;
+    virtual CArgMinLoss minimizer(double lambda,
+                                  const CPRNG::CXorOShiro128Plus& rng) const = 0;
     //! Get the name of the loss function
     virtual const std::string& name() const = 0;

@@ -214,7 +295,7 @@ class MATHS_EXPORT CMse final : public CLoss {
                  double weight = 1.0) const override;
     bool isCurvatureConstant() const override;
     TDoubleVector transform(const TMemoryMappedFloatVector& prediction) const override;
-    CArgMinLoss minimizer(double lambda) const override;
+    CArgMinLoss minimizer(double lambda, const CPRNG::CXorOShiro128Plus& rng) const override;
     const std::string& name() const override;
 };

@@ -227,11 +308,47 @@ class MATHS_EXPORT CMse final : public CLoss {
 //! </pre>
 //! where \f$a_i\f$ denotes the actual class of the i'th example, \f$p\f$ is the
 //! prediction and \f$S(\cdot)\f$ denotes the logistic function.
-class MATHS_EXPORT CBinomialLogistic final : public CLoss {
+class MATHS_EXPORT CBinomialLogisticLoss final : public CLoss {
+public:
+    static const std::string NAME;
+
+public:
+    std::unique_ptr<CLoss> clone() const override;
+    std::size_t numberParameters() const override;
+    double value(const TMemoryMappedFloatVector& prediction,
+                 double actual,
+                 double weight = 1.0) const override;
+    void gradient(const TMemoryMappedFloatVector& prediction,
+                  double actual,
+                  TWriter writer,
+                  double weight = 1.0) const override;
+    void curvature(const TMemoryMappedFloatVector& prediction,
+                   double actual,
+                   TWriter writer,
+                   double weight = 1.0) const override;
+    bool isCurvatureConstant() const override;
+    TDoubleVector transform(const TMemoryMappedFloatVector& prediction) const override;
+    CArgMinLoss minimizer(double lambda, const CPRNG::CXorOShiro128Plus& rng) const override;
+    const std::string& name() const override;
+};
+
+//! \brief Implements loss for multinomial logistic regression.
+//!
+//! DESCRIPTION:\n
+//! This targets the cross-entropy loss using the forest to predict the class
+//! probabilities via the softmax function:
+//! <pre class="fragment">
+//! \f$\displaystyle l_i(p) = -\log([\sigma(p)]_{a_i})\f$
+//! </pre>
+//! where \f$a_i\f$ denotes the actual class of the i'th example, \f$p\f$ denotes
+//! the vector valued prediction and \f$\sigma(p)\f$ is the softmax function, i.e.
+//! \f$[\sigma(p)]_j = \frac{e^{p_j}}{\sum_k e^{p_k}}\f$.
+class MATHS_EXPORT CMultinomialLogisticLoss final : public CLoss {
 public:
     static const std::string NAME;
 
 public:
+    CMultinomialLogisticLoss(std::size_t numberClasses);
     std::unique_ptr<CLoss> clone() const override;
     std::size_t numberParameters() const override;
     double value(const TMemoryMappedFloatVector& prediction,
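
Per example, the loss above is just the negative log of the softmax probability assigned to the actual class. A sketch of that value computed with the log-sum-exp trick; crossEntropyValue is a hypothetical name, and the library's value() additionally applies the example weight and reads the memory-mapped vector type:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// -log([softmax(p)]_a) for actual class a, computed via log-sum-exp so that
// large logits do not overflow exp().
double crossEntropyValue(const std::vector<double>& prediction, std::size_t actual) {
    double pmax{*std::max_element(prediction.begin(), prediction.end())};
    double logZ{0.0};
    for (double p : prediction) {
        logZ += std::exp(p - pmax);
    }
    logZ = pmax + std::log(logZ);
    // -log([softmax(p)]_actual) = log(sum_k exp(p_k)) - p_actual.
    return logZ - prediction[actual];
}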
@@ -247,8 +364,11 @@ class MATHS_EXPORT CBinomialLogistic final : public CLoss {
                  double weight = 1.0) const override;
     bool isCurvatureConstant() const override;
     TDoubleVector transform(const TMemoryMappedFloatVector& prediction) const override;
-    CArgMinLoss minimizer(double lambda) const override;
+    CArgMinLoss minimizer(double lambda, const CPRNG::CXorOShiro128Plus& rng) const override;
     const std::string& name() const override;
+
+private:
+    std::size_t m_NumberClasses;
 };
 }
 }

include/maths/CKMeans.h

Lines changed: 12 additions & 10 deletions
@@ -21,6 +21,7 @@
 #include <boost/iterator/counting_iterator.hpp>
 
 #include <cstddef>
+#include <cstdint>
 #include <sstream>
 #include <utility>
 #include <vector>
@@ -125,15 +126,15 @@ class CKMeans {
         const TPointVec& points() const { return m_Points; }
 
         //! Get the cluster checksum.
-        uint64_t checksum() const { return m_Checksum; }
+        std::uint64_t checksum() const { return m_Checksum; }
 
     private:
         //! The centroid of the points in this cluster.
         POINT m_Centre;
         //! The points in the cluster.
         TPointVec m_Points;
         //! A checksum for the points in the cluster.
-        uint64_t m_Checksum;
+        std::uint64_t m_Checksum;
     };
 
     using TClusterVec = std::vector<CCluster>;
@@ -183,8 +184,9 @@
         if (m_Centres.empty()) {
             return true;
         }
+        TMeanAccumulatorVec newCentres;
         for (std::size_t i = 0u; i < maxIterations; ++i) {
-            if (!this->updateCentres()) {
+            if (!this->updateCentres(newCentres)) {
                 return true;
             }
         }
@@ -481,19 +483,19 @@
 
 private:
     //! Single iteration of Lloyd's algorithm to update \p centres.
-    bool updateCentres() {
-        const TCoordinate precision = TCoordinate(5) *
-                                      std::numeric_limits<TCoordinate>::epsilon();
-        TMeanAccumulatorVec newCentres(m_Centres.size(),
-                                       TMeanAccumulator(las::zero(m_Centres[0])));
+    bool updateCentres(TMeanAccumulatorVec& newCentres) {
+        const TCoordinate precision{TCoordinate(5) *
+                                    std::numeric_limits<TCoordinate>::epsilon()};
+        newCentres.assign(m_Centres.size(), TMeanAccumulator(las::zero(m_Centres[0])));
         CCentroidComputer computer(m_Centres, newCentres);
         m_Points.preorderDepthFirst(computer);
         bool changed = false;
+        POINT newCentre;
        for (std::size_t i = 0u; i < newCentres.size(); ++i) {
-            POINT newCentre(CBasicStatistics::mean(newCentres[i]));
+            newCentre = CBasicStatistics::mean(newCentres[i]);
             if (las::distance(m_Centres[i], newCentre) >
                 precision * las::norm(m_Centres[i])) {
-                m_Centres[i] = newCentre;
+                las::swap(m_Centres[i], newCentre);
                 changed = true;
             }
         }
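
This hunk is an allocation optimisation: the accumulator buffer is now owned by the caller, reused across Lloyd's iterations via assign(), and an updated centre is swapped in rather than copied. The same pattern on plain vectors, as an illustrative sketch rather than the CKMeans implementation (which assigns points via a k-d tree instead of a linear scan):

#include <cstddef>
#include <limits>
#include <vector>

using TPoint = std::vector<double>;
using TPointVec = std::vector<TPoint>;

double squaredDistance(const TPoint& x, const TPoint& y) {
    double d2{0.0};
    for (std::size_t i = 0; i < x.size(); ++i) {
        d2 += (x[i] - y[i]) * (x[i] - y[i]);
    }
    return d2;
}

// One Lloyd's iteration; returns true if any centre moved. The scratch buffer
// newCentres is passed in so repeated calls reuse its storage.
bool updateCentres(TPointVec& centres, TPointVec& newCentres, const TPointVec& points) {
    if (centres.empty()) {
        return false;
    }
    // assign() reuses existing capacity on second and subsequent calls.
    newCentres.assign(centres.size(), TPoint(centres[0].size(), 0.0));
    std::vector<double> counts(centres.size(), 0.0);
    for (const auto& point : points) {
        std::size_t closest{0};
        double closestD2{std::numeric_limits<double>::max()};
        for (std::size_t i = 0; i < centres.size(); ++i) {
            double d2{squaredDistance(point, centres[i])};
            if (d2 < closestD2) {
                closest = i;
                closestD2 = d2;
            }
        }
        counts[closest] += 1.0;
        for (std::size_t d = 0; d < point.size(); ++d) {
            newCentres[closest][d] += point[d];
        }
    }
    bool changed{false};
    for (std::size_t i = 0; i < centres.size(); ++i) {
        if (counts[i] == 0.0) {
            continue; // keep an empty cluster's old centre
        }
        for (std::size_t d = 0; d < centres[i].size(); ++d) {
            newCentres[i][d] /= counts[i];
        }
        // The library uses a relative precision test here; exact inequality
        // keeps the sketch short.
        if (squaredDistance(centres[i], newCentres[i]) > 0.0) {
            centres[i].swap(newCentres[i]); // swap, don't copy
            changed = true;
        }
    }
    return changed;
}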
