focal_loss_layer.cpp
#include <algorithm>
#include <cfloat>
#include <vector>

#include "caffe/layers/focal_loss_layer.hpp"
#include "caffe/util/math_functions.hpp"
namespace caffe {

template <typename Dtype>
void FocalLossLayer<Dtype>::LayerSetUp(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  LossLayer<Dtype>::LayerSetUp(bottom, top);
  // Set up an internal Softmax layer that maps the raw scores in bottom[0]
  // to the class probabilities stored in prob_.
  LayerParameter softmax_param(this->layer_param_);
  softmax_param.set_type("Softmax");
  softmax_layer_ = LayerRegistry<Dtype>::CreateLayer(softmax_param);
  softmax_bottom_vec_.clear();
  softmax_bottom_vec_.push_back(bottom[0]);
  softmax_top_vec_.clear();
  softmax_top_vec_.push_back(&prob_);
  softmax_layer_->SetUp(softmax_bottom_vec_, softmax_top_vec_);

  // Focal loss hyper-parameters: alpha scales the loss, gamma controls how
  // strongly well-classified examples are down-weighted.
  alpha_ = this->layer_param_.focal_loss_param().alpha();
  gamma_ = this->layer_param_.focal_loss_param().gamma();

  has_ignore_label_ =
      this->layer_param_.loss_param().has_ignore_label();
  if (has_ignore_label_) {
    ignore_label_ = this->layer_param_.loss_param().ignore_label();
  }
  // Resolve the normalization mode, honoring the deprecated "normalize" flag
  // when the newer "normalization" field is not set.
  if (!this->layer_param_.loss_param().has_normalization() &&
      this->layer_param_.loss_param().has_normalize()) {
    normalization_ = this->layer_param_.loss_param().normalize() ?
                     LossParameter_NormalizationMode_VALID :
                     LossParameter_NormalizationMode_BATCH_SIZE;
  } else {
    normalization_ = this->layer_param_.loss_param().normalization();
  }
}
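
// A hypothetical prototxt usage sketch. The layer type and the alpha/gamma
// fields follow the accessors used in this file (REGISTER_LAYER_CLASS(FocalLoss)
// and focal_loss_param()); the example values are illustrative only, not
// defaults taken from this implementation.
//
//   layer {
//     name: "loss"
//     type: "FocalLoss"
//     bottom: "score"
//     bottom: "label"
//     top: "loss"
//     focal_loss_param { alpha: 0.25 gamma: 2 }
//     loss_param { ignore_label: 255 normalization: VALID }
//   }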

template <typename Dtype>
void FocalLossLayer<Dtype>::Reshape(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  LossLayer<Dtype>::Reshape(bottom, top);
  softmax_layer_->Reshape(softmax_bottom_vec_, softmax_top_vec_);
  softmax_axis_ =
      bottom[0]->CanonicalAxisIndex(this->layer_param_.focal_loss_param().axis());
  outer_num_ = bottom[0]->count(0, softmax_axis_);   // e.g. N for (N, C, H, W)
  inner_num_ = bottom[0]->count(softmax_axis_ + 1);  // e.g. H * W
  CHECK_EQ(outer_num_ * inner_num_, bottom[1]->count())
      << "Number of labels must match number of predictions; "
      << "e.g., if softmax axis == 1 and prediction shape is (N, C, H, W), "
      << "label count (number of labels) must be N*H*W, "
      << "with integer values in {0, 1, ..., C-1}.";
  if (top.size() >= 2) {
    // Optional second top: the softmax output itself.
    top[1]->ReshapeLike(*bottom[0]);
  }
}

template <typename Dtype>
void FocalLossLayer<Dtype>::Forward_cpu(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  // The forward pass first computes the softmax probabilities.
  softmax_layer_->Forward(softmax_bottom_vec_, softmax_top_vec_);
  const Dtype* prob_data = prob_.cpu_data();
  const Dtype* label = bottom[1]->cpu_data();
  int dim = prob_.count() / outer_num_;  // C * H * W
  int count = 0;
  Dtype loss = 0;
  Dtype pt = 0;  // probability assigned to the ground-truth class
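  // Accumulate -alpha * (1 - pt)^gamma * log(pt) over every spatial position,
  // where pt is the predicted probability of the ground-truth class. Positions
  // carrying the ignore label contribute nothing and are excluded from the
  // normalization count.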
  for (int i = 0; i < outer_num_; ++i) {
    for (int j = 0; j < inner_num_; j++) {
      const int label_value = static_cast<int>(label[i * inner_num_ + j]);
      if (has_ignore_label_ && label_value == ignore_label_) {
        continue;
      }
      DCHECK_GE(label_value, 0);
      DCHECK_LT(label_value, prob_.shape(softmax_axis_));
      pt = prob_data[i * dim + label_value * inner_num_ + j];
      // Plain softmax cross-entropy would be  loss -= log(max(pt, FLT_MIN));
      // the focal term adds the alpha * (1 - pt)^gamma modulation.
      loss -= alpha_ * pow(1.0 - pt, gamma_) * log(std::max(pt, Dtype(FLT_MIN)));
      ++count;
    }
  }
  Dtype normalizer = LossLayer<Dtype>::GetNormalizer(
      normalization_, outer_num_, inner_num_, count);
  top[0]->mutable_cpu_data()[0] = loss / normalizer;
  if (top.size() == 2) {
    top[1]->ShareData(prob_);
  }
}

template <typename Dtype>
void FocalLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  if (propagate_down[1]) {
    LOG(FATAL) << this->type()
               << " Layer cannot backpropagate to label inputs.";
  }
  if (propagate_down[0]) {
    Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
    const Dtype* prob_data = prob_.cpu_data();
    const Dtype* label = bottom[1]->cpu_data();
    int dim = prob_.count() / outer_num_;
    int count = 0;
    Dtype focal_diff = 0;
    Dtype pt = 0;  // probability of the ground-truth class
    Dtype pc = 0;  // probability of the class currently being differentiated
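    // Gradient of FL = -alpha * (1 - pt)^gamma * log(pt) with respect to the
    // softmax input (logit) of class c, where pt is the probability of the
    // ground-truth class and pc the probability of class c:
    //
    //   c == label:  alpha * (1 - pt)^gamma * (gamma * pt * log(pt) + pt - 1)
    //   c != label:  alpha * (-gamma * (1 - pt)^(gamma-1) * pt * log(pt) * pc
    //                         + (1 - pt)^gamma * pc)
    //
    // With gamma = 0 and alpha = 1 this reduces to the familiar softmax
    // cross-entropy gradient (pt - 1 for the target class, pc otherwise).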
    for (int i = 0; i < outer_num_; ++i) {
      for (int j = 0; j < inner_num_; ++j) {
        const int label_value = static_cast<int>(label[i * inner_num_ + j]);
        if (has_ignore_label_ && label_value == ignore_label_) {
          // Ignored positions receive zero gradient for every class.
          for (int c = 0; c < bottom[0]->shape(softmax_axis_); ++c) {
            bottom_diff[i * dim + c * inner_num_ + j] = 0;
          }
        } else {
          pt = prob_data[i * dim + label_value * inner_num_ + j];
          for (int c = 0; c < bottom[0]->shape(softmax_axis_); ++c) {
            pc = prob_data[i * dim + c * inner_num_ + j];
            if (c == label_value) {
              focal_diff = alpha_ *
                  pow(1 - pt, gamma_) *
                  (gamma_ * pt * log(std::max(pt, Dtype(FLT_MIN))) + pt - 1);
            } else {
              focal_diff = alpha_ *
                  (pow(1 - pt, gamma_ - 1) *
                       (-gamma_ * log(std::max(pt, Dtype(FLT_MIN))) * pt * pc) +
                   pow(1 - pt, gamma_) * pc);
            }
            bottom_diff[i * dim + c * inner_num_ + j] = focal_diff;
          }
          ++count;
        }
      }
    }
    // Scale the gradient by the loss weight and the normalizer.
    Dtype normalizer = LossLayer<Dtype>::GetNormalizer(
        normalization_, outer_num_, inner_num_, count);
    Dtype loss_weight = top[0]->cpu_diff()[0] / normalizer;
    caffe_scal(prob_.count(), loss_weight, bottom_diff);
  }
}

#ifdef CPU_ONLY
STUB_GPU(FocalLossLayer);
#endif

INSTANTIATE_CLASS(FocalLossLayer);
REGISTER_LAYER_CLASS(FocalLoss);

}  // namespace caffe