File tree 2 files changed +8
-29
lines changed
2 files changed +8
-29
lines changed Original file line number Diff line number Diff line change @@ -92,28 +92,6 @@ class VectorT : public BaseVector<T> {
92
92
const T* getData () const { return this ->data_ ; }
93
93
T* getData () { return this ->data_ ; }
94
94
95
- #ifdef PADDLE_USE_MKLDNN
96
- /* *
97
- * sgd update with openmp to speedup
98
- */
99
- void sgdUpdateWithOMP (VectorT& gradVec,
100
- VectorT& momVec,
101
- T learningRate,
102
- T momentum,
103
- T decayRate) {
104
- size_t size = this ->getSize ();
105
- T* val = this ->getData ();
106
- T* grd = gradVec.getData ();
107
- T* mom = momVec.getData ();
108
- decayRate *= learningRate;
109
- #pragma omp parallel for
110
- for (size_t i = 0 ; i < size; ++i) {
111
- mom[i] = momentum * mom[i] - learningRate * grd[i] - decayRate * val[i];
112
- val[i] += mom[i];
113
- }
114
- }
115
- #endif
116
-
117
95
virtual void zeroMem () = 0;
118
96
// set all elements to value
119
97
virtual void reset (const T& value) = 0;
Original file line number Diff line number Diff line change @@ -15,6 +15,7 @@ limitations under the License. */
15
15
#pragma once
16
16
17
17
#include " ParameterOptimizer.h"
18
+ #include " ParameterUpdateFunctions.h"
18
19
#include " Regularizer.h"
19
20
20
21
namespace paddle {
@@ -38,13 +39,13 @@ class SgdOptimizer : public ParameterOptimizer {
38
39
? 1.0 - paraConfig.momentum ()
39
40
: 1.0 ;
40
41
#ifdef PADDLE_USE_MKLDNN
41
- vecs[PARAMETER_VALUE]-> sgdUpdateWithOMP (
42
- *vecs[PARAMETER_GRADIENT] ,
43
- *vecs[PARAMETER_MOMENTUM] ,
44
- learningRate_ * paraConfig.learning_rate () *
45
- (firstTime_ ? 1.0 : torch_learningRate ),
46
- paraConfig. momentum (),
47
- applyDecay_ ? paraConfig. decay_rate () : 0 );
42
+ sgdUpdate (learningRate_ * paraConfig. learning_rate () *
43
+ (firstTime_ ? 1.0 : torch_learningRate) ,
44
+ paraConfig. momentum () ,
45
+ applyDecay_ ? paraConfig.decay_rate () : 0 ,
46
+ vecs[PARAMETER_VALUE]. get ( ),
47
+ vecs[PARAMETER_GRADIENT]. get (),
48
+ vecs[PARAMETER_MOMENTUM]. get () );
48
49
#else
49
50
vecs[PARAMETER_VALUE]->sgdUpdate (
50
51
*vecs[PARAMETER_GRADIENT],
You can’t perform that action at this time.
0 commit comments