Skip to content

Commit 843af94

Browse files
committed
Add random number generation utilities and refactor sampling methods
- Introduced `random_utils.hpp` and `random_utils.cpp` for random number generation.
- Updated `Makevars` to include the new `random_utils` source file.
- Refactored sampling methods in `cpf.cpp` and `rpf.cpp` to use the new `RandomGenerator` class.
- Modified `helper.cpp` to replace direct calls to `rand()` with `RandomGenerator::random_index()`.
1 parent 30196b7 commit 843af94

File tree

19 files changed

+1018
-76
lines changed

19 files changed

+1018
-76
lines changed

src/Makevars

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
SOURCES=lib/cpf.cpp lib/grid.cpp lib/helper.cpp lib/rpf.cpp lib/trees.cpp lib/rcpp_interface.cpp randomPlantedForest.cpp RcppExports.cpp
1+
SOURCES=lib/cpf.cpp lib/grid.cpp lib/helper.cpp lib/rpf.cpp lib/trees.cpp lib/rcpp_interface.cpp lib/random_utils.cpp randomPlantedForest.cpp RcppExports.cpp
22

33
OBJECTS = $(SOURCES:.cpp=.o)
44

src/include/helper.hpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <utility>
1212
#include <thread>
1313
#include <assert.h>
14+
#include "random_utils.hpp"
1415

1516
#ifndef UTILS_H
1617
#define UTILS_H
@@ -100,12 +101,7 @@ namespace utils
100101
template <typename Iter>
101102
void shuffle_vector(Iter first, Iter last)
102103
{
103-
int n = std::distance(first, last);
104-
while (n > 1)
105-
{
106-
int k = random_index(n--);
107-
std::swap(*(first + n), *(first + k));
108-
}
104+
RandomGenerator::shuffle(first, last);
109105
};
110106

111107
std::vector<int> to_std_vec(std::vector<int> rv);

src/include/random_utils.hpp

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
#ifndef RANDOM_UTILS_H
2+
#define RANDOM_UTILS_H
3+
4+
#include <random>
5+
#include <vector>
6+
#include <algorithm>
7+
#include <type_traits>
8+
#include <chrono>
9+
#include <Rcpp.h>
10+
11+
namespace utils
12+
{
13+
14+
// Abstract base class for random number generators
15+
// Interface every random-number backend implements. RandomGenerator holds a
// pointer to one RNGBackend and routes all draws through it.
class RNGBackend
16+
{
17+
public:
18+
// Virtual so a backend can be destroyed through a base-class pointer.
virtual ~RNGBackend() = default;
19+
// Hook invoked by RandomGenerator::RNGScope's constructor; no-op unless overridden.
virtual void begin_rng() {} // Called before a sequence of random numbers
20+
// Hook invoked by RandomGenerator::RNGScope's destructor; no-op unless overridden.
virtual void end_rng() {} // Called after a sequence of random numbers
21+
// Draw an integer index; the backends in this file produce values in [0, n).
virtual int random_int(int n) = 0;
22+
// Draw a floating-point value; the backends in this file sample the unit interval.
virtual double random_double() = 0;
23+
};
24+
25+
// Standard C++ random number generator backend
26+
class StdRNGBackend : public RNGBackend
27+
{
28+
private:
29+
static thread_local std::mt19937 generator;
30+
static bool seeded;
31+
32+
public:
33+
void seed(uint32_t seed)
34+
{
35+
generator.seed(seed);
36+
seeded = true;
37+
}
38+
39+
void initialize()
40+
{
41+
if (!seeded)
42+
{
43+
auto now = std::chrono::high_resolution_clock::now();
44+
auto nanos = std::chrono::duration_cast<std::chrono::nanoseconds>(
45+
now.time_since_epoch())
46+
.count();
47+
generator.seed(static_cast<uint32_t>(nanos));
48+
seeded = true;
49+
}
50+
}
51+
52+
int random_int(int n) override
53+
{
54+
initialize();
55+
std::uniform_int_distribution<int> dist(0, n - 1);
56+
return dist(generator);
57+
}
58+
59+
double random_double() override
60+
{
61+
initialize();
62+
std::uniform_real_distribution<double> dist(0.0, 1.0);
63+
return dist(generator);
64+
}
65+
};
66+
67+
// R random number generator backend
68+
class RcppRNGBackend : public RNGBackend
69+
{
70+
private:
71+
std::unique_ptr<Rcpp::RNGScope> rng_scope;
72+
73+
public:
74+
void begin_rng() override
75+
{
76+
if (!rng_scope)
77+
{
78+
rng_scope = std::make_unique<Rcpp::RNGScope>();
79+
}
80+
}
81+
82+
void end_rng() override
83+
{
84+
rng_scope.reset();
85+
}
86+
87+
int random_int(int n) override
88+
{
89+
return static_cast<int>(unif_rand() * n);
90+
}
91+
92+
double random_double() override
93+
{
94+
return unif_rand();
95+
}
96+
};
97+
98+
class RandomGenerator
99+
{
100+
private:
101+
static RNGBackend *backend;
102+
static StdRNGBackend std_backend;
103+
static RcppRNGBackend rcpp_backend;
104+
105+
public:
106+
// RAII class to handle RNG state
107+
class RNGScope
108+
{
109+
private:
110+
RNGBackend *backend;
111+
112+
public:
113+
RNGScope() : backend(RandomGenerator::backend)
114+
{
115+
backend->begin_rng();
116+
}
117+
~RNGScope()
118+
{
119+
backend->end_rng();
120+
}
121+
};
122+
123+
// Switch to using R's RNG
124+
static void use_r_random()
125+
{
126+
backend = &rcpp_backend;
127+
}
128+
129+
// Switch to using C++ standard RNG
130+
static void use_std_random()
131+
{
132+
backend = &std_backend;
133+
}
134+
135+
// Initialize the standard generator with a seed
136+
static void seed(uint32_t seed)
137+
{
138+
std_backend.seed(seed);
139+
}
140+
141+
// Generate random integer in range [0, n)
142+
static int random_index(int n)
143+
{
144+
RNGScope scope;
145+
return backend->random_int(n);
146+
}
147+
148+
// Generate random double in range [0, 1)
149+
static double random_double()
150+
{
151+
RNGScope scope;
152+
return backend->random_double();
153+
}
154+
155+
// Shuffle a range of elements
156+
template <typename Iter>
157+
static void shuffle(Iter first, Iter last)
158+
{
159+
RNGScope scope;
160+
auto n = std::distance(first, last);
161+
for (auto i = n - 1; i > 0; --i)
162+
{
163+
std::swap(*(first + i), *(first + backend->random_int(i + 1)));
164+
}
165+
}
166+
167+
// Sample n elements with replacement
168+
template <typename T>
169+
static std::vector<T> sample_with_replacement(const std::vector<T> &population, size_t n)
170+
{
171+
RNGScope scope;
172+
std::vector<T> result;
173+
result.reserve(n);
174+
for (size_t i = 0; i < n; ++i)
175+
{
176+
result.push_back(population[backend->random_int(population.size())]);
177+
}
178+
return result;
179+
}
180+
};
181+
182+
} // namespace utils
183+
184+
#endif // RANDOM_UTILS_H

src/lib/cpf.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -769,10 +769,10 @@ Split ClassificationRPF::calcOptimalSplit(const std::vector<std::vector<double>>
769769
std::iota(samples.begin(), samples.end(), 1);
770770
}
771771
else
772-
{ // randomly picked samples otherwise
772+
{ // randomly picked samples using RandomGenerator
773773
samples = std::vector<int>(split_try);
774774
for (size_t i = 0; i < samples.size(); ++i)
775-
samples[i] = rand() % (int)(unique_samples.size() - leaf_size) + leaf_size;
775+
samples[i] = utils::RandomGenerator::random_index((int)(unique_samples.size() - leaf_size)) + leaf_size;
776776
std::sort(samples.begin(), samples.end());
777777
}
778778

src/lib/helper.cpp

Lines changed: 63 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,81 @@
11
#include "helper.hpp"
2+
#include "random_utils.hpp"
23
#include <vector>
34
#include <iostream>
45
#include <set>
56

6-
namespace utils {
7+
namespace utils
8+
{
79

8-
// Helper function to generate random number using standard C++ libraries
9-
int random_index(const int n) { return static_cast<int>(static_cast<double>(rand()) / RAND_MAX * n); }
10+
// Helper function to generate random number using our RandomGenerator
11+
// Forwards to RandomGenerator::random_index(n), so the value is drawn from
// whichever backend RandomGenerator currently has selected.
int random_index(const int n) { return RandomGenerator::random_index(n); }
1012

11-
// ----------------- functions for converting Cpp types -----------------
13+
// ----------------- functions for converting Cpp types -----------------
1214

15+
/**
16+
* \brief Convert the std container set of type int into a std::vector<int>.
17+
*
18+
* \param v the set that is converted.
19+
*/
20+
// Copies the elements of the set into a new vector, in the set's iteration order.
std::vector<int> from_std_set(std::set<int> v)
21+
{
22+
// Range constructor: one vector element per set element.
return std::vector<int>(v.begin(), v.end());
23+
}
1324

14-
/**
15-
* \brief Convert the std container set of type int into a std::vector<int>.
16-
*
17-
* \param v the set that is converted.
18-
*/
19-
std::vector<int> from_std_set(std::set<int> v) {
20-
return std::vector<int>(v.begin(), v.end());
21-
}
25+
/**
26+
* \brief Convert the std container vector of type int into a std::vector<int>.
27+
*
28+
* \param v the vector that is converted.
29+
*/
30+
// Identity conversion: the argument is taken by value and returned unchanged.
std::vector<int> from_std_vec(std::vector<int> v)
31+
{
32+
return v;
33+
}
2234

23-
/**
24-
* \brief Convert the std container vector of type int into a std::vector<int>.
25-
*
26-
* \param v the vector that is converted.
27-
*/
28-
std::vector<int> from_std_vec(std::vector<int> v) {
29-
return v;
30-
}
35+
/**
36+
* \brief Convert the std container vector of type double into a std::vector<double>.
37+
*
38+
* \param v the vector that is converted.
39+
*/
40+
// Identity conversion: the argument is taken by value and returned unchanged.
std::vector<double> from_std_vec(std::vector<double> v)
41+
{
42+
return v;
43+
}
3144

32-
/**
33-
* \brief Convert the std container vector of type double into a std::vector<double>.
34-
*
35-
* \param v the vector that is converted.
36-
*/
37-
std::vector<double> from_std_vec(std::vector<double> v) {
38-
return v;
39-
}
40-
41-
/**
42-
* \brief Convert the nested std container vector containing a vector itself
43-
* of type double into a std::vector<std::vector<double>>.
44-
*
45-
* \param v the vector of vectors that is converted.
46-
*/
47-
std::vector<std::vector<double>> from_std_vec(std::vector<std::vector<double>> v) {
48-
return v;
49-
}
45+
/**
46+
* \brief Convert the nested std container vector containing a vector itself
47+
* of type double into a std::vector<std::vector<double>>.
48+
*
49+
* \param v the vector of vectors that is converted.
50+
*/
51+
// Identity conversion: the nested vector is taken by value and returned unchanged.
std::vector<std::vector<double>> from_std_vec(std::vector<std::vector<double>> v)
52+
{
53+
return v;
54+
}
5055

51-
std::vector<int> to_std_vec(std::vector<int> rv) {
52-
return rv;
53-
}
56+
// Identity conversion for a vector of int.
//
// \param rv the vector to convert; taken by value.
// \return rv, unchanged.
std::vector<int> to_std_vec(std::vector<int> rv)
57+
{
58+
return rv;
59+
}
5460

55-
std::vector<double> to_std_vec(std::vector<double> rv) {
56-
return rv;
57-
}
61+
// Identity conversion for a vector of double.
//
// \param rv the vector to convert; taken by value.
// \return rv, unchanged.
std::vector<double> to_std_vec(std::vector<double> rv)
62+
{
63+
return rv;
64+
}
5865

59-
std::vector<std::vector<double>> to_std_vec(std::vector<std::vector<double>> rv) {
60-
return rv;
61-
}
66+
// Identity conversion for a vector of vectors of double.
//
// \param rv the nested vector to convert; taken by value.
// \return rv, unchanged.
std::vector<std::vector<double>> to_std_vec(std::vector<std::vector<double>> rv)
67+
{
68+
return rv;
69+
}
6270

63-
std::set<int> to_std_set(std::vector<int> rv) {
64-
return std::set<int>(rv.begin(), rv.end());
65-
}
71+
// Builds a set from the elements of a vector of int.
//
// \param rv the vector to convert; taken by value.
// \return a set holding each distinct value of rv once.
std::set<int> to_std_set(std::vector<int> rv)
72+
{
73+
return std::set<int>(rv.begin(), rv.end());
74+
}
6675

67-
std::set<int> to_std_set(std::vector<double> rv) {
68-
return std::set<int>(rv.begin(), rv.end());
69-
}
76+
// Builds a set of int from the elements of a vector of double.
//
// \param rv the vector to convert; taken by value.
// \return a set holding each distinct converted value once.
std::set<int> to_std_set(std::vector<double> rv)
77+
{
78+
// Each double is implicitly converted to int on insertion, so any
// fractional part is dropped and values that convert to the same int collapse.
return std::set<int>(rv.begin(), rv.end());
79+
}
7080

7181
}

src/lib/random_utils.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#include "random_utils.hpp"
2+
3+
// Out-of-class definitions for the static data members declared in random_utils.hpp.
namespace utils
4+
{
5+
// One engine per thread, default-constructed here; StdRNGBackend::seed() and
// StdRNGBackend::initialize() reseed the calling thread's instance.
thread_local std::mt19937 StdRNGBackend::generator;
6+
// A single flag for the whole process (not thread_local, unlike the engine).
bool StdRNGBackend::seeded = false;
7+
// The standard library backend is the default; the Rcpp-facing constructors in
// rcpp_interface.cpp switch to the R backend through use_r_random().
RNGBackend *RandomGenerator::backend = &RandomGenerator::std_backend;
8+
StdRNGBackend RandomGenerator::std_backend;
9+
RcppRNGBackend RandomGenerator::rcpp_backend;
10+
}

src/lib/rcpp_interface.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "rcpp_interface.hpp"
2+
#include "random_utils.hpp"
23

34
// Helper to convert NumericMatrix to std::vector<std::vector<double>>
45
static std::vector<std::vector<double>> toStd2D(const Rcpp::NumericMatrix &mat)
@@ -25,6 +26,7 @@ RcppRPF::RcppRPF(const NumericMatrix &samples_Y, const NumericMatrix &samples_X,
2526
: RandomPlantedForest(toStd2D(samples_Y), toStd2D(samples_X),
2627
toStd1D(parameters))
2728
{
29+
utils::RandomGenerator::use_r_random();
2830
}
2931

3032
NumericMatrix RcppRPF::predict_matrix(const NumericMatrix &X, const NumericVector components)
@@ -105,9 +107,9 @@ bool RcppRPF::is_purified()
105107

106108
RcppCPF::RcppCPF(const NumericMatrix &samples_Y, const NumericMatrix &samples_X,
107109
const std::string loss, const NumericVector parameters)
108-
: ClassificationRPF(toStd2D(samples_Y), toStd2D(samples_X), loss,
109-
toStd1D(parameters))
110+
: ClassificationRPF(toStd2D(samples_Y), toStd2D(samples_X), loss, toStd1D(parameters))
110111
{
112+
utils::RandomGenerator::use_r_random();
111113
}
112114

113115
void RcppCPF::set_parameters(StringVector keys, NumericVector values)

0 commit comments

Comments
 (0)