-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathOptimizer.h
44 lines (36 loc) · 924 Bytes
/
Optimizer.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
// Oliver - machine learning library.
// written by cubeflix - https://github.com/cubeflix/oliver
//
// Optimizer.h
// Base optimizer and settings class.
#pragma once
#include "Matrix.h"
namespace Oliver {
// Abstract interface implemented by every optimizer.
//
// The destructor is declared pure virtual, so the class cannot be
// instantiated directly. A definition of Optimizer::~Optimizer is still
// required for derived classes to link; it is not visible in this header.
class Optimizer
{
public:
    // Virtual so that destroying an optimizer through an Optimizer* runs the
    // derived class's destructor.
    virtual ~Optimizer() = 0;

    // Performs one update; must be implemented by each concrete optimizer.
    // NOTE(review): `device` presumably selects the compute device to run
    // on -- confirm against the implementations.
    virtual void update(int device) = 0;
};
// Abstract factory interface for optimizers: a concrete settings class
// implements create() to build an optimizer for a given matrix/gradient pair.
class OptimizerSettings {
public:
    // Virtual destructor so that destroying a concrete settings object through
    // an OptimizerSettings* invokes the derived destructor; without it such a
    // delete is undefined behaviour.
    virtual ~OptimizerSettings() = default;

    // Builds an optimizer for the matrix `x` and its gradient `xGrad`.
    // NOTE(review): ownership of the returned Optimizer* and the required
    // lifetime of `x` / `xGrad` are not visible in this header -- confirm
    // against the implementations and callers.
    virtual Optimizer* create(Matrix* x, Matrix* xGrad) = 0;
};
// Settings object for the SGD optimizer. Holds a learning rate and implements
// the OptimizerSettings factory interface.
class SGDOptimizerSettings : public OptimizerSettings {
public:
    // Constructs the settings from a learning rate (presumably stored in
    // m_learningRate; the definition is not in this header).
    SGDOptimizerSettings(const float learningRate);

    // Implements OptimizerSettings::create. Marked `override` so the compiler
    // rejects this declaration if it ever stops matching the base signature.
    // NOTE(review): presumably returns an SGDOptimizer configured with
    // m_learningRate -- confirm in the definition.
    Optimizer* create(Matrix* x, Matrix* xGrad) override;

private:
    // Learning rate supplied at construction; const, so fixed for the
    // lifetime of the settings object.
    const float m_learningRate;
};
// Stochastic gradient descent (SGD) optimizer.
class SGDOptimizer : public virtual Optimizer {
public:
SGDOptimizer(Matrix* x, Matrix* xGrad, const float learningRate);
void update(int device);
private:
Matrix* m_x;
Matrix* m_xGrad;
const float m_learningRate;
float m_currentRate;
};
}