// gd.go (forked from huichen/mlf)
package optimizer

import (
	"log"
	"math"

	"github.com/huichen/mlf/data"
	"github.com/huichen/mlf/util"
)
// Gradient descent optimizer.
type gdOptimizer struct {
	options OptimizerOptions
}
// NewGdOptimizer initializes the optimizer struct.
func NewGdOptimizer(options OptimizerOptions) Optimizer {
	opt := new(gdOptimizer)
	opt.options = options
	return opt
}
// Clear removes data saved in the struct so that the struct can be reused.
// The gradient descent optimizer keeps no such state, so this is a no-op.
func (opt *gdOptimizer) Clear() {
}
// GetDeltaX takes x_k and g_k and returns the increment by which x should
// be updated: d_k = -g_k.
func (opt *gdOptimizer) GetDeltaX(x, g *util.Matrix) *util.Matrix {
	return g.Opposite()
}
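
// Illustration (added note, not used by the code): combined with a learning
// rate alpha_k, the increment above yields the classic steepest descent step
//
//	x_{k+1} = x_k + alpha_k*d_k = x_k - alpha_k*g_k
//
// which is what OptimizeWeights performs below via weights.Increment.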
// OptimizeWeights runs (mini-batch) gradient descent over the dataset,
// updating weights in place until convergence or MaxIterations is reached.
func (opt *gdOptimizer) OptimizeWeights(
	weights *util.Matrix, derivative_func ComputeInstanceDerivativeFunc, set data.Dataset) {
	// Vector of partial derivatives.
	derivative := weights.Populate()

	// Learning rate computer.
	learningRate := NewLearningRate(opt.options)

	// Optimization loop state.
	iterator := set.CreateIterator()
	step := 0
	var learning_rate float64
	convergingSteps := 0
	oldWeights := weights.Populate()
	weightsDelta := weights.Populate()
	instanceDerivative := weights.Populate()

	log.Print("Starting gradient descent optimization")
	for {
		if opt.options.MaxIterations > 0 && step >= opt.options.MaxIterations {
			break
		}
		step++

		// Zero the derivative vector before each pass over the instances.
		derivative.Clear()

		// Iterate over all instances, computing and accumulating the
		// partial derivative vectors.
		iterator.Start()
		instancesProcessed := 0
		for !iterator.End() {
			instance := iterator.GetInstance()
			derivative_func(weights, instance, instanceDerivative)
			derivative.Increment(instanceDerivative, 1.0/float64(set.NumInstances()))
			iterator.Next()
			instancesProcessed++

			if opt.options.GDBatchSize > 0 && instancesProcessed >= opt.options.GDBatchSize {
				// Add the regularization term.
				derivative.Increment(ComputeRegularization(weights, opt.options),
					float64(instancesProcessed)/(float64(set.NumInstances())*float64(set.NumInstances())))

				// Compute the increment for the feature weights.
				delta := opt.GetDeltaX(weights, derivative)

				// Update the weights using the computed learning rate.
				learning_rate = learningRate.ComputeLearningRate(delta)
				weights.Increment(delta, learning_rate)

				// Reset for the next mini-batch.
				derivative.Clear()
				instancesProcessed = 0
			}
		}
		if instancesProcessed > 0 {
			// Flush the remaining instances that did not fill a whole batch.
			derivative.Increment(ComputeRegularization(weights, opt.options),
				float64(instancesProcessed)/(float64(set.NumInstances())*float64(set.NumInstances())))
			delta := opt.GetDeltaX(weights, derivative)
			learning_rate = learningRate.ComputeLearningRate(delta)
			weights.Increment(delta, learning_rate)
		}
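
		// Note on the regularization scale used above (added note): each
		// instance contributes its derivative with weight 1/N, where
		// N = set.NumInstances(), while each batch adds the regularization
		// gradient with weight batchSize/(N*N). Summed over a full pass the
		// batches contribute N/(N*N) = 1/N, so the regularizer appears to be
		// weighted like a single extra instance relative to the data term.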
		// Measure the change in weights since the previous iteration.
		weightsDelta.WeightedSum(weights, oldWeights, 1, -1)
		oldWeights.DeepCopy(weights)
		weightsNorm := weights.Norm()
		weightsDeltaNorm := weightsDelta.Norm()
		log.Printf("#%d |w|=%1.3g |dw|/|w|=%1.3g lr=%1.3g",
			step, weightsNorm, weightsDeltaNorm/weightsNorm, learning_rate)

		// Check for numeric overflow.
		if math.IsNaN(weightsNorm) {
			log.Fatal("Optimization failed: not converging")
		}

		// Check for convergence: the relative weight change must stay below
		// ConvergingDeltaWeight for more than ConvergingSteps iterations.
		if weightsDeltaNorm/weightsNorm < opt.options.ConvergingDeltaWeight {
			convergingSteps++
			if convergingSteps > opt.options.ConvergingSteps {
				log.Printf("Converged")
				break
			}
		}
	}
}
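
// Example usage (a minimal sketch added for illustration; it assumes only
// the OptimizerOptions fields referenced in this file, plus a weights matrix,
// derivative function, and dataset created elsewhere; the concrete values
// below are arbitrary):
//
//	options := OptimizerOptions{
//		MaxIterations:         100,
//		GDBatchSize:           32, // 0 would mean one full-batch update per iteration
//		ConvergingSteps:       3,
//		ConvergingDeltaWeight: 1e-6,
//	}
//	opt := NewGdOptimizer(options)
//	opt.OptimizeWeights(weights, derivativeFunc, trainingSet)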