Skip to content

Commit 88952fb

Browse files
committed
use existing sgd updater function
1 parent d6a27ad commit 88952fb

File tree

2 files changed

+8
-29
lines changed

2 files changed

+8
-29
lines changed

paddle/math/Vector.h

-22
Original file line numberDiff line numberDiff line change
@@ -92,28 +92,6 @@ class VectorT : public BaseVector<T> {
   const T* getData() const { return this->data_; }
   T* getData() { return this->data_; }

-#ifdef PADDLE_USE_MKLDNN
-  /**
-   * sgd update with openmp to speedup
-   */
-  void sgdUpdateWithOMP(VectorT& gradVec,
-                        VectorT& momVec,
-                        T learningRate,
-                        T momentum,
-                        T decayRate) {
-    size_t size = this->getSize();
-    T* val = this->getData();
-    T* grd = gradVec.getData();
-    T* mom = momVec.getData();
-    decayRate *= learningRate;
-#pragma omp parallel for
-    for (size_t i = 0; i < size; ++i) {
-      mom[i] = momentum * mom[i] - learningRate * grd[i] - decayRate * val[i];
-      val[i] += mom[i];
-    }
-  }
-#endif
-
   virtual void zeroMem() = 0;
   // set all elements to value
   virtual void reset(const T& value) = 0;

paddle/parameter/FirstOrderOptimizer.h

+8-7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once

 #include "ParameterOptimizer.h"
+#include "ParameterUpdateFunctions.h"
 #include "Regularizer.h"

 namespace paddle {
@@ -38,13 +39,13 @@ class SgdOptimizer : public ParameterOptimizer {
           ? 1.0 - paraConfig.momentum()
           : 1.0;
 #ifdef PADDLE_USE_MKLDNN
-    vecs[PARAMETER_VALUE]->sgdUpdateWithOMP(
-        *vecs[PARAMETER_GRADIENT],
-        *vecs[PARAMETER_MOMENTUM],
-        learningRate_ * paraConfig.learning_rate() *
-            (firstTime_ ? 1.0 : torch_learningRate),
-        paraConfig.momentum(),
-        applyDecay_ ? paraConfig.decay_rate() : 0);
+    sgdUpdate(learningRate_ * paraConfig.learning_rate() *
+                  (firstTime_ ? 1.0 : torch_learningRate),
+              paraConfig.momentum(),
+              applyDecay_ ? paraConfig.decay_rate() : 0,
+              vecs[PARAMETER_VALUE].get(),
+              vecs[PARAMETER_GRADIENT].get(),
+              vecs[PARAMETER_MOMENTUM].get());
 #else
     vecs[PARAMETER_VALUE]->sgdUpdate(
         *vecs[PARAMETER_GRADIENT],

0 commit comments

Comments
 (0)