Skip to content

Commit cc6cf80

Browse files
committed
use vector
1 parent ad9d1ce commit cc6cf80

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

cpp/layer.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ FullyConnect::~FullyConnect() {
3030
this->W.clear();
3131
this->W.shrink_to_fit();
3232
if (this->phase == TRAIN) {
33-
free(this->delta_buf);
33+
this->delta_buf.clear();
34+
this->delta_buf.shrink_to_fit();
3435
}
3536
}
3637

@@ -40,7 +41,7 @@ int FullyConnect::configure(int batch, float learning_rate, float v_param, Layer
4041
this->W.resize(this->input_shape*this->units);
4142
if (this->phase == TRAIN) {
4243
this->E.resize(this->batch*this->input_shape);
43-
this->delta_buf = (float*)malloc(sizeof(float)*this->batch*this->units);
44+
this->delta_buf.resize(this->batch*this->units);
4445
}
4546
std::random_device rd;
4647
std::mt19937 mt(rd());
@@ -136,7 +137,8 @@ Conv2D::~Conv2D() {
136137
this->F.clear();
137138
this->F.shrink_to_fit();
138139
if (this->phase == TRAIN) {
139-
free(this->delta_buf);
140+
this->delta_buf.clear();
141+
this->delta_buf.shrink_to_fit();
140142
}
141143
}
142144

@@ -146,7 +148,7 @@ int Conv2D::configure(int batch, float learning_rate, float v_param, Layer* prev
146148
//this->X = (float*)malloc(sizeof(float)*this->batch*this->channel*this->kernel_size*this->kernel_size*this->units*this->units);
147149
this->Y.resize(this->batch*this->filter*this->units);
148150
this->E.resize(this->batch*this->channel*this->input_shape);
149-
this->delta_buf = (float*)malloc(sizeof(float)*this->batch*this->filter*this->units);
151+
this->delta_buf.resize(this->batch*this->filter*this->units);
150152
this->F.resize(this->filter*this->kernel_size*this->kernel_size);
151153

152154
std::random_device rd;

cpp/layer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class Layer {
2424
phase_t phase;
2525
vector<float> E;
2626
// for Momentum
27-
float* delta_buf;
27+
vector<float> delta_buf;
2828
float momentum_a;
2929
vector<float> Y;
3030
vector<float> *X;

0 commit comments

Comments
 (0)