Mm/pure gradient descent #419

Draft · wants to merge 4 commits into development
1 change: 1 addition & 0 deletions cpp/sopt/CMakeLists.txt
@@ -2,6 +2,7 @@
set(headers
bisection_method.h chained_operators.h credible_region.h
imaging_padmm.h logging.disabled.h
gradient_descent.h
forward_backward.h imaging_forward_backward.h
g_proximal.h l1_g_proximal.h joint_map.h
imaging_primal_dual.h primal_dual.h
128 changes: 128 additions & 0 deletions cpp/sopt/gradient_descent.h
@@ -0,0 +1,128 @@
#ifndef SOPT_GRADIENT_DESCENT_H
#define SOPT_GRADIENT_DESCENT_H

#include <cmath>
#include <functional>
#include "sopt/linear_transform.h"
#include "sopt/types.h"

namespace sopt::algorithm {

//! Values indicating how the algorithm ran
template <typename SCALAR>
struct AlgorithmResults {
//! Number of iterations
t_uint niters;
//! Whether convergence was achieved
bool good;
//! The residual from the last iteration
Vector<SCALAR> residual;
//! The final iterate
Vector<SCALAR> result;
};

//! \brief Pure gradient descent algorithm
//! \details Requires \f$\nabla f\f$ and \f$\nabla g\f$ to be analytically defined.
//! Each iteration applies the update
//! \f$x_{n+1} = x_n - \alpha\left(\Re[\nabla f(x_n, y)] + \lambda \nabla g(\mu x_n)\right)\f$,
//! accelerated with a FISTA-style momentum term (see iteration_step).
//! \param f_gradient: Gradient function for f, where f is usually a likelihood. Takes two arguments (x, y).
//! \param g_gradient: Gradient function for g, where g is usually a prior / regulariser. Takes one argument (x).
//! \param lambda: Multiplier for the g gradient function
//! \param Lipschitz_f: Lipschitz constant of function f (used to calculate alpha)
//! \param Lipschitz_g: Lipschitz constant of function g (used to calculate alpha)
//! \param mu: Scaling parameter for the vector inside g. Also used to calculate alpha
template <typename SCALAR>
class GradientDescent
{
public:
using F_Gradient =
typename std::function<Vector<SCALAR>(const Vector<SCALAR> &, const Vector<SCALAR> &)>;
using G_Gradient = typename std::function<Vector<SCALAR>(const Vector<SCALAR> &)>;
using REAL = typename real_type<SCALAR>::type;

GradientDescent(F_Gradient const &f_gradient,
G_Gradient const &g_gradient,
Vector<SCALAR> const &target,
REAL const threshold,
REAL const Lipschitz_f = 1,
REAL const Lipschitz_g = 1,
REAL const mu = 1,
REAL const lambda = 1)
: Phi(linear_transform_identity<SCALAR>()),
f_gradient(f_gradient),
g_gradient(g_gradient),
lambda(lambda),
mu(mu),
Lipschitz_f(Lipschitz_f),
Lipschitz_g(Lipschitz_g),
target(target),
threshold_delta(threshold)
{
// Step size just below 1 / L, where L = Lipschitz_f + mu * lambda * Lipschitz_g
// is the Lipschitz constant of the combined gradient
alpha = 0.98 / (Lipschitz_f + mu * lambda * Lipschitz_g);
}

AlgorithmResults<SCALAR> operator()(Vector<SCALAR> &x)
{
Vector<SCALAR> z = x;
bool converged = false;
t_uint iterations = 0;
while ((!converged) && (iterations < max_iterations))
{
iteration_step(x, z);

converged = is_converged(x);

++iterations;
}

if(converged)
{
// TODO: Log some relevant stuff about the convergence.
}

AlgorithmResults<SCALAR> results;
results.good = converged;
results.niters = iterations;
results.residual = (Phi * x) - target;
results.result = z;

return results;
}

protected:
LinearTransform<Vector<SCALAR>> Phi;
F_Gradient f_gradient;
G_Gradient g_gradient;
REAL alpha;
REAL lambda = 1;
REAL mu = 1;
REAL Lipschitz_f = 1;
REAL Lipschitz_g = 1;
Vector<SCALAR> target;
REAL threshold_delta;
Vector<SCALAR> delta_x;
REAL theta_now = 1;
REAL theta_next = 1;
Vector<SCALAR> x_prev;
t_uint max_iterations = 200;

void iteration_step(Vector<SCALAR> &x, Vector<SCALAR> &z)
{
// Should be able to make this better to avoid copies
x_prev = x;

// Gradient step evaluated at the extrapolated point z
delta_x = f_gradient(z, target).real();
delta_x += lambda * g_gradient(mu * z);
delta_x *= alpha;

// FISTA-style momentum coefficient
theta_next = 0.5 * (1 + std::sqrt(1 + 4 * theta_now * theta_now));

x = z - delta_x;
z = x + (theta_now - 1) / theta_next * (x - x_prev);

theta_now = theta_next;
}

bool is_converged(Vector<SCALAR> &x)
{
return (delta_x.norm() / x.norm()) < threshold_delta;
}

};

} // namespace sopt::algorithm

#endif // SOPT_GRADIENT_DESCENT_H
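
For reference, a minimal usage sketch of the new class outside the test suite; the gradient lambdas, vector size, and tolerance below are illustrative, while the types and constructor signature are the ones introduced in this header:

#include "sopt/gradient_descent.h"
#include "sopt/types.h"

int main()
{
  using sopt::Vector;
  constexpr int N = 16;

  // Quadratic data-fidelity gradient: grad f(x, y) = x - y (Lipschitz constant 1)
  auto const grad_likelihood = [](Vector<double> const &x, Vector<double> const &y) {
    return Vector<double>(x - y);
  };
  // Flat prior: grad g(x) = 0
  auto const grad_prior = [](Vector<double> const &x) {
    return Vector<double>(Vector<double>::Zero(x.size()));
  };

  Vector<double> const target = Vector<double>::Random(N);
  Vector<double> x = Vector<double>::Zero(N);

  // threshold = 1e-4; Lipschitz_f, Lipschitz_g, mu and lambda keep their defaults of 1
  sopt::algorithm::GradientDescent<double> gd(grad_likelihood, grad_prior, target, 1e-4);
  sopt::algorithm::AlgorithmResults<double> const results = gd(x);

  return results.good ? 0 : 1;
}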
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
@@ -24,6 +24,7 @@ add_catch_test(chained_operators LIBRARIES sopt SEED ${RAND_SEED})
add_catch_test(conjugate_gradient LIBRARIES sopt SEED ${RAND_SEED})
add_catch_test(credible_region LIBRARIES sopt SEED ${RAND_SEED})
add_catch_test(forward_backward LIBRARIES sopt tools_for_tests SEED ${RAND_SEED})
add_catch_test(gradient_descent LIBRARIES sopt tools_for_tests SEED ${RAND_SEED})
add_catch_test(gradient_operator LIBRARIES sopt tools_for_tests SEED ${RAND_SEED})
add_catch_test(inpainting LIBRARIES sopt tools_for_tests SEED ${RAND_SEED})
add_catch_test(linear_transform LIBRARIES sopt SEED ${RAND_SEED})
79 changes: 79 additions & 0 deletions cpp/tests/gradient_descent.cc
@@ -0,0 +1,79 @@
#include <catch2/catch_all.hpp>
#include <cmath>
#include <random>
#include "sopt/gradient_descent.h"

uint constexpr N = 10;

TEST_CASE("Gradient Descent with flat prior", "[GradDescent]")
{
using namespace sopt;

const Vector<float> target = Vector<float>::Random(N);
float const sigma = 0.5;
float const gamma = 0.1;
uint const max_iterations = 100;

auto const grad_likelihood = [](const Vector<float>&x, const Vector<float>&y){return (x-y);};
auto const grad_prior = [](const Vector<float> &x){return 0*x;};

Vector<float> init_guess = Vector<float>::Random(N);

auto Phi = linear_transform_identity<float>();

algorithm::GradientDescent<float> gd(grad_likelihood, grad_prior, target, 1e-4);

algorithm::AlgorithmResults<float> results = gd(init_guess);

CHECK(results.good);
CHECK(results.result.isApprox(target, 0.1));
}

TEST_CASE("Gradient Descent with smoothness prior", "[GradDescent]")
{
using namespace sopt;
std::mt19937_64 rng;
std::uniform_real_distribution<float> noise(0, 0.2);

Vector<float> data(N);
for(uint i = 0; i < N; i++)
{
data(i) = std::sin((M_PI/(N-1))*i) + noise(rng);
}

Vector<float> perfect(N);
for(uint i = 0; i < N; i++)
{
perfect(i) = std::sin((M_PI/(N-1))*i);
}

float const sigma = 0.5;
float const gamma = 0.1;
uint const max_iterations = 100;

auto const grad_likelihood = [](const Vector<float>&x, const Vector<float>&y){return (x-y);};
auto const grad_prior = [](const Vector<float> &x)
{
Vector<float> grad(x.size());
grad(0) = x(0);
grad(x.size()-1) = x(x.size()-1);
for(uint i = 1; i < x.size()-1; i++)
{
// Push values to be roughly in line with neighbours
// Hand wavey kind of smoothness prior
grad(i) = x(i) - 0.5*(x(i-1) + x(i+1));
}
return grad;
};

Vector<float> init_guess = Vector<float>::Random(N);

auto Phi = linear_transform_identity<float>();

algorithm::GradientDescent<float> gd(grad_likelihood, grad_prior, data, 1e-4);

algorithm::AlgorithmResults<float> results = gd(init_guess);

CHECK(results.good);
CHECK(results.result.isApprox(perfect, 0.1));
CHECK((results.result - perfect).squaredNorm() < (data-perfect).squaredNorm());
}