Skip to content

Commit 7046d76

Browse files
committed
Utils: Added simple Adam gradient descent implementation.
This is a simple implementation of Adam. It tries to be as straightforward as possible, and extremely cheap to call and update (although not to construct, given the internal Vector allocations required for the moment estimates). It currently doesn't have much use in the library, but it should become more useful once Gaussian Processes (GPs) are added. In the meantime it cannot hurt to have it around in Utils.
1 parent 10ec7da commit 7046d76

File tree

5 files changed

+214
-0
lines changed

5 files changed

+214
-0
lines changed

include/AIToolbox/Utils/Adam.hpp

+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
#ifndef AI_TOOLBOX_ADAM_HEADER_FILE
#define AI_TOOLBOX_ADAM_HEADER_FILE

#include <AIToolbox/Types.hpp>

namespace AIToolbox {
    /**
     * @brief This class implements the ADAM gradient descent algorithm.
     *
     * The implementation is deliberately minimal and cheap to call. It
     * operates through two pointers to equally-sized vectors: one holds the
     * point currently being optimized, the other holds the gradient of the
     * objective at that point.
     *
     * The user is responsible for recomputing the gradient between calls;
     * each step() then moves the point along the gradient following the
     * Adam update rule.
     *
     * Pointers (rather than references) are used so that they can be
     * re-targeted later via reset(), keeping the same instance around. The
     * replacement vectors must have the same size as the originals, which
     * lets us avoid reallocating the internal moment-estimate vectors.
     */
    class Adam {
        public:
            /**
             * @brief Basic constructor.
             *
             * Both pointers must be non-null and point to preallocated
             * vectors of equal size.
             *
             * The point vector must already contain the starting point of
             * the descent, and the gradient vector the gradient evaluated
             * there.
             *
             * @param point A pointer to preallocated space where to write the point.
             * @param gradient A pointer to preallocated space containing the current gradient.
             * @param alpha Adam's step size/learning rate.
             * @param beta1 Adam's exponential decay rate for first moment estimates.
             * @param beta2 Adam's exponential decay rate for second moment estimates.
             * @param epsilon Additive parameter to prevent division by zero.
             */
            Adam(AIToolbox::Vector * point, const AIToolbox::Vector * gradient, double alpha = 0.001, double beta1 = 0.9, double beta2 = 0.999, double epsilon = 1e-8);

            /**
             * @brief This function updates the point using the currently set gradient.
             *
             * The vector behind the `point` pointer is overwritten in place,
             * following the currently set gradient.
             *
             * The gradient must have been recomputed for the current point
             * by the user before calling this function.
             */
            void step();

            /**
             * @brief This function resets the gradient descent process.
             *
             * All internal state (moment estimates, step counter) is
             * cleared so that the descent can restart from scratch.
             *
             * The point vector is not modified.
             */
            void reset();

            /**
             * @brief This function resets the gradient descent process.
             *
             * All internal state (moment estimates, step counter) is
             * cleared so that the descent can restart from scratch.
             *
             * The point and gradient pointers are updated with the new inputs.
             */
            void reset(AIToolbox::Vector * point, const AIToolbox::Vector * gradient);

            /**
             * @brief This function sets the current learning rate.
             */
            void setAlpha(double alpha);

            /**
             * @brief This function sets the current exponential decay rate for first moment estimates.
             */
            void setBeta1(double beta1);

            /**
             * @brief This function sets the current exponential decay rate for second moment estimates.
             */
            void setBeta2(double beta2);

            /**
             * @brief This function sets the current additive division parameter.
             */
            void setEpsilon(double epsilon);

            /**
             * @brief This function returns the current learning rate.
             */
            double getAlpha() const;

            /**
             * @brief This function returns the current exponential decay rate for first moment estimates.
             */
            double getBeta1() const;

            /**
             * @brief This function returns the current exponential decay rate for second moment estimates.
             */
            double getBeta2() const;

            /**
             * @brief This function returns the current additive division parameter.
             */
            double getEpsilon() const;

        private:
            // Non-owning; the user keeps the pointees alive and sized equally.
            AIToolbox::Vector * point_;
            const AIToolbox::Vector * gradient_;
            // First (m_) and second (v_) moment estimates, sized like *point_.
            AIToolbox::Vector m_, v_;

            double beta1_, beta2_, alpha_, epsilon_;
            // 1-based step counter used for the bias-correction terms.
            unsigned step_;
    };
}

#endif

src/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ if (MAKE_MDP)
44
add_library(AIToolboxMDP
55
Impl/Seeder.cpp
66
Impl/CassandraParser.cpp
7+
Utils/Adam.cpp
78
Utils/Combinatorics.cpp
89
Utils/Probability.cpp
910
Utils/Polytope.cpp

src/Utils/Adam.cpp

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#include <AIToolbox/Utils/Adam.hpp>

// Fixed: these were previously pulled in only transitively; step() uses
// assert, std::sqrt and std::pow directly.
#include <cassert>
#include <cmath>

namespace AIToolbox {
    Adam::Adam(AIToolbox::Vector * point, const AIToolbox::Vector * gradient, const double alpha, const double beta1, const double beta2, const double epsilon) :
            point_(point), gradient_(gradient),
            m_(point_->size()), v_(point_->size()),
            beta1_(beta1), beta2_(beta2), alpha_(alpha), epsilon_(epsilon),
            step_(1)
    {
        // Documented precondition: both pointers must be valid and the
        // vectors equally sized (the moment estimates mirror their size).
        assert(point_);
        assert(gradient_);
        assert(point_->size() == gradient_->size());
        reset();
    }

    void Adam::step() {
        // Fixed: the asserts referenced undeclared names `point`/`gradient`
        // (the members are point_/gradient_), which broke debug builds.
        assert(point_);
        assert(gradient_);

        // Update biased first and second raw moment estimates.
        m_ = beta1_ * m_ + (1.0 - beta1_) * (*gradient_);
        v_ = beta2_ * v_ + (1.0 - beta2_) * (*gradient_).array().square().matrix();

        // Fold both bias corrections into the step size; this is the standard
        // efficient reformulation of Adam (avoids creating mHat/vHat vectors).
        const double alphaHat = alpha_ * std::sqrt(1.0 - std::pow(beta2_, step_)) / (1.0 - std::pow(beta1_, step_));

        (*point_).array() -= alphaHat * m_.array() / (v_.array().sqrt() + epsilon_);

        ++step_;
    }

    void Adam::reset() {
        // Clear moment estimates and restart the bias-correction schedule.
        m_.fill(0.0);
        v_.fill(0.0);
        step_ = 1;
    }

    void Adam::reset(AIToolbox::Vector * point, const AIToolbox::Vector * gradient) {
        // Re-target the descent onto new (equally sized) vectors.
        assert(point);
        assert(gradient);
        point_ = point;
        gradient_ = gradient;
        reset();
    }

    void Adam::setBeta1(double beta1) { beta1_ = beta1; }
    void Adam::setBeta2(double beta2) { beta2_ = beta2; }
    void Adam::setAlpha(double alpha) { alpha_ = alpha; }
    void Adam::setEpsilon(double epsilon) { epsilon_ = epsilon; }

    double Adam::getBeta1() const { return beta1_; }
    double Adam::getBeta2() const { return beta2_; }
    double Adam::getAlpha() const { return alpha_; }
    double Adam::getEpsilon() const { return epsilon_; }
}

test/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ cmake_minimum_required(VERSION 3.9) # CMP0069 NEW
33
set(GlobalFileDependencies
44
${PROJECT_SOURCE_DIR}/src/Impl/Seeder.cpp
55
${PROJECT_SOURCE_DIR}/src/Tools/Statistics.cpp
6+
${PROJECT_SOURCE_DIR}/src/Utils/Adam.cpp
67
${PROJECT_SOURCE_DIR}/src/Utils/Combinatorics.cpp
78
${PROJECT_SOURCE_DIR}/src/Utils/Probability.cpp
89
${PROJECT_SOURCE_DIR}/src/Utils/LP/LpSolveWrapper.cpp
@@ -39,6 +40,7 @@ function (AddTestPython type name)
3940
endfunction (AddTestPython)
4041

4142
if (MAKE_MDP)
43+
AddTestGlobal(UtilsAdam)
4244
AddTestGlobal(UtilsCore)
4345
AddTestGlobal(UtilsProbability)
4446
AddTestGlobal(UtilsPrune)

test/UtilsAdamTests.cpp

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#define BOOST_TEST_MODULE UtilsAdam
#define BOOST_TEST_DYN_LINK
#define BOOST_TEST_MAIN
#include <boost/test/unit_test.hpp>

#include <AIToolbox/Types.hpp>
#include <AIToolbox/Utils/Adam.hpp>

namespace ai = AIToolbox;

// Convex objective f(p) = ||p||^2, minimized at the origin.
double objective(const ai::Vector & p) {
    return p.squaredNorm();
}

// Analytic gradient of the objective: df/dp_i = 2 * p_i.
// Generalized to any vector size (behavior unchanged for the 2D test below).
void derivative(const ai::Vector & p, ai::Vector & grad) {
    grad = 2.0 * p;
}

BOOST_AUTO_TEST_CASE( simple_gradient_descent ) {
    // Removed redundant `using namespace AIToolbox;` — the `ai` alias is used throughout.
    ai::Vector point(2);
    point << -0.21, 0.47;

    ai::Vector gradient(2);
    derivative(point, gradient);

    ai::Adam adam(&point, &gradient, 0.02);

    // Each iteration: take an Adam step, then recompute the gradient at the
    // new point, as the Adam class requires.
    for (auto i = 0; i < 100; ++i) {
        adam.step();
        derivative(point, gradient);
    }

    // 100 steps with alpha = 0.02 should land very close to the minimum.
    const double val = objective(point);
    BOOST_TEST_INFO(val);
    BOOST_CHECK(val < 1e-5);
}

0 commit comments

Comments
 (0)