Merge pull request BVLC#190 from sguada/new_lr_policies
New lr policies, MultiStep and StepEarly
sguada committed Oct 16, 2014
2 parents 7effdca + e064fe7 commit 3aa2a6d
Showing 5 changed files with 99 additions and 8 deletions.
examples/lenet/lenet_multistep_solver.prototxt (33 additions, 0 deletions)
@@ -0,0 +1,33 @@
# The training protocol buffer definition
train_net: "lenet_train.prototxt"
# The testing protocol buffer definition
test_net: "lenet_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
lr_policy: "multistep"
gamma: 0.9
stepvalue: 1000
stepvalue: 2000
stepvalue: 2500
stepvalue: 3000
stepvalue: 3500
stepvalue: 4000
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 10000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "lenet"
# solver mode: 0 for CPU and 1 for GPU
solver_mode: 1
device_id: 1
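
For orientation, here is a minimal standalone sketch (not part of this commit) that replays the schedule this solver file requests by mirroring the "multistep" branch added to SGDSolver<Dtype>::GetLearningRate() in src/caffe/solver.cpp below: the rate is base_lr * gamma^step, where step counts how many stepvalue entries the current iteration has passed.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Settings taken from lenet_multistep_solver.prototxt above.
  const double base_lr = 0.01;
  const double gamma = 0.9;
  const std::vector<int> stepvalue = {1000, 2000, 2500, 3000, 3500, 4000};
  int current_step = 0;
  for (int iter = 0; iter <= 4500; ++iter) {
    // Advance one step when the next stepvalue is reached, as the solver does
    // once per training iteration.
    if (current_step < static_cast<int>(stepvalue.size()) &&
        iter >= stepvalue[current_step]) {
      ++current_step;
    }
    const double rate = base_lr * std::pow(gamma, current_step);
    if (iter % 500 == 0) {
      std::printf("iter %4d  step %d  lr %.6f\n", iter, current_step, rate);
    }
  }
  return 0;
}

With these settings the learning rate falls from 0.01 to about 0.0053 by iteration 4000 and then stays constant, since gamma is applied once per stepvalue rather than once per iteration.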
examples/lenet/lenet_stepearly_solver.prototxt (28 additions, 0 deletions)
@@ -0,0 +1,28 @@
# The training protocol buffer definition
train_net: "lenet_train.prototxt"
# The testing protocol buffer definition
test_net: "lenet_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
lr_policy: "stepearly"
gamma: 0.9
stepearly: 1
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 10000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "lenet"
# solver mode: 0 for CPU and 1 for GPU
solver_mode: 1
device_id: 1
include/caffe/solver.hpp (1 addition, 0 deletions)
@@ -57,6 +57,7 @@ class Solver {

SolverParameter param_;
int iter_;
int current_step_;
shared_ptr<Net<Dtype> > net_;
vector<shared_ptr<Net<Dtype> > > test_nets_;

src/caffe/proto/caffe.proto (6 additions, 2 deletions)
Expand Up @@ -65,7 +65,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
-// SolverParameter next available ID: 34 (last added: average_loss)
+// SolverParameter next available ID: 35 (last added: stepvalue)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
@@ -126,7 +126,10 @@ message SolverParameter {
// regularization types supported: L1 and L2
// controlled by weight_decay
optional string regularization_type = 29 [default = "L2"];
-optional int32 stepsize = 13; // the stepsize for learning rate policy "step"
+// the stepsize for learning rate policy "step"
+optional int32 stepsize = 13;
+// the stepsize for learning rate policy "multistep"
+repeated int32 stepvalue = 34;
optional int32 snapshot = 14 [default = 0]; // The snapshot interval
optional string snapshot_prefix = 15; // The prefix for the snapshot.
// whether to snapshot diff in the results or not. Snapshotting diff will help
@@ -168,6 +171,7 @@ message SolverState {
optional int32 iter = 1; // The current iteration
optional string learned_net = 2; // The file that stores the learned net.
repeated BlobProto history = 3; // The history for sgd solvers
optional int32 current_step = 4 [default = 0]; // The current step for learning rate
}

enum Phase {
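Because stepvalue is a repeated field and current_step defaults to 0, the values are reached through the ordinary protoc-generated C++ accessors (stepvalue_size(), stepvalue(i), current_step()), the same calls solver.cpp uses below. A short illustrative sketch, assuming the generated header path caffe/proto/caffe.pb.h and a hypothetical function name:

#include <iostream>

#include "caffe/proto/caffe.pb.h"  // generated from src/caffe/proto/caffe.proto

// Print the multistep schedule held in a SolverParameter and the schedule
// position persisted in a SolverState snapshot.
void DescribeSchedule(const caffe::SolverParameter& param,
                      const caffe::SolverState& state) {
  for (int i = 0; i < param.stepvalue_size(); ++i) {
    std::cout << "gamma applied at iteration " << param.stepvalue(i) << "\n";
  }
  std::cout << "current step restored from snapshot: "
            << state.current_step() << "\n";
}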
src/caffe/solver.cpp (31 additions, 6 deletions)
@@ -158,6 +158,7 @@ template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
Caffe::set_phase(Caffe::TRAIN);
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
PreSolve();

iter_ = 0;
@@ -257,7 +258,6 @@ void Solver<Dtype>::TestAll() {
}
}


template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) {
LOG(INFO) << "Iteration " << iter_
@@ -336,6 +336,7 @@ void Solver<Dtype>::Snapshot() {
SnapshotSolverState(&state);
state.set_iter(iter_);
state.set_learned_net(model_filename);
state.set_current_step(current_step_);
snapshot_filename = filename + ".solverstate";
LOG(INFO) << "Snapshotting solver state to " << snapshot_filename;
WriteProtoToBinaryFile(state, snapshot_filename.c_str());
@@ -351,6 +352,7 @@ void Solver<Dtype>::Restore(const char* state_file) {
net_->CopyTrainedLayersFrom(net_param);
}
iter_ = state.iter();
current_step_ = state.current_step();
RestoreSolverState(state);
}

@@ -361,31 +363,54 @@ void Solver<Dtype>::Restore(const char* state_file) {
// - step: return base_lr * gamma ^ (floor(iter / step))
// - exp: return base_lr * gamma ^ iter
// - inv: return base_lr * (1 + gamma * iter) ^ (- power)
-// where base_lr, gamma, step and power are defined in the solver parameter
-// protocol buffer, and iter is the current iteration.
+// - multistep: similar to step but it allows non-uniform steps defined by
+//   stepvalue
+// - poly: the effective learning rate follows a polynomial decay, to be
+//   zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
+// - sigmoid: the effective learning rate follows a sigmoid decay
+//   return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
+//
+// where base_lr, max_iter, gamma, step, stepvalue and power are defined
+// in the solver parameter protocol buffer, and iter is the current iteration.
template <typename Dtype>
Dtype SGDSolver<Dtype>::GetLearningRate() {
Dtype rate;
const string& lr_policy = this->param_.lr_policy();
if (lr_policy == "fixed") {
rate = this->param_.base_lr();
} else if (lr_policy == "step") {
-int current_step = this->iter_ / this->param_.stepsize();
+this->current_step_ = this->iter_ / this->param_.stepsize();
rate = this->param_.base_lr() *
-    pow(this->param_.gamma(), current_step);
+    pow(this->param_.gamma(), this->current_step_);
} else if (lr_policy == "exp") {
rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
} else if (lr_policy == "inv") {
rate = this->param_.base_lr() *
pow(Dtype(1) + this->param_.gamma() * this->iter_,
- this->param_.power());
} else if (lr_policy == "multistep") {
if (this->current_step_ < this->param_.stepvalue_size() &&
this->iter_ >= this->param_.stepvalue(this->current_step_)) {
this->current_step_++;
LOG(INFO) << "MultiStep Status: Iteration " <<
this->iter_ << ", step = " << this->current_step_;
}
rate = this->param_.base_lr() *
pow(this->param_.gamma(), this->current_step_);
} else if (lr_policy == "poly") {
rate = this->param_.base_lr() * pow(Dtype(1.) -
(Dtype(this->iter_) / Dtype(this->param_.max_iter())),
this->param_.power());
} else if (lr_policy == "sigmoid") {
rate = this->param_.base_lr() * (Dtype(1.) /
(Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
Dtype(this->param_.stepsize())))));
} else {
LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
}
return rate;
}


template <typename Dtype>
void SGDSolver<Dtype>::PreSolve() {
// Initialize the history
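Storing current_step_ in SolverState (written in Snapshot() and read back in Restore() above) keeps a resumed multistep run at the right point in the schedule. The quantity GetLearningRate() accumulates incrementally is simply the number of stepvalue entries the current iteration has already passed; a small sketch of that relationship, not part of the commit and with a hypothetical function name:

#include <vector>

// Count how many stepvalue entries a given iteration has passed; this is the
// value that GetLearningRate() builds up in current_step_ one increment at a
// time as training moves through the increasing stepvalues.
int StepForIteration(int iter, const std::vector<int>& stepvalue) {
  int step = 0;
  while (step < static_cast<int>(stepvalue.size()) && iter >= stepvalue[step]) {
    ++step;
  }
  return step;
}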
