-
Notifications
You must be signed in to change notification settings - Fork 114
Expand file tree
/
Copy pathoptimizer.ts
More file actions
119 lines (101 loc) · 3.09 KB
/
optimizer.ts
File metadata and controls
119 lines (101 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import { Tensor, native } from "./tensor.js";
import { Parameter } from "./module.js";
/**
 * Alias of `Tensor`.
 * NOTE(review): presumably names the value type stored inside a `Parameter` —
 * confirm against module.js; nothing in the visible code references it.
 */
export type ParameterValue = Tensor;
/**
 * Base class shared by the optimizers in this file.
 *
 * It only stores the parameter list; concrete update rules live in subclasses.
 */
export class Optimizer {
  /**
   * @param parameters Parameters managed by this optimizer. The array is kept
   *   by reference (no copy is made) and exposed as the public `parameters`
   *   property.
   */
  constructor(public parameters: Parameter<Tensor>[]) {}
}
/**
 * Plain gradient-descent optimizer: every step replaces each parameter value
 * with `value - lr * grad`.
 */
export class SGD extends Optimizer {
  /** Step size multiplied into each gradient. */
  lr: number;

  /**
   * @param parameters Parameters to update.
   * @param lr Learning rate; defaults to 1.0.
   */
  constructor(parameters: Parameter<Tensor>[], lr: number = 1.0) {
    super(parameters);
    this.lr = lr;
  }

  /** Passes the `_id` of every managed parameter value to `native.zeroGrad`. */
  zeroGrad() {
    const tensorIds = this.parameters.map(({ value }) => value._id);
    native.zeroGrad(tensorIds);
  }

  /**
   * Applies one update to every parameter that currently has a gradient.
   * Parameters whose `value.grad` is falsy are left untouched.
   */
  step() {
    for (const param of this.parameters) {
      const { grad } = param.value;
      if (grad) {
        // next = value - lr * grad
        const next = param.value.sub(grad.mul(this.lr));
        // Cast kept as in the original; the declared argument type of
        // `update` is not visible from this file.
        param.update(next as any);
      }
    }
  }
}
/**
 * Adam optimizer front-end. The hyperparameters and the step counter are held
 * here; the actual update is delegated to `native.clipAndStep`.
 */
export class Adam extends Optimizer {
  /** Learning rate. */
  lr: number;
  /** First hyperparameter forwarded to the native step (after `lr`). */
  beta1: number;
  /** Second hyperparameter forwarded to the native step. */
  beta2: number;
  /** Epsilon forwarded to the native step. */
  eps: number;
  /** Weight-decay value forwarded to the native step. */
  weightDecay: number;
  /** Number of times `step` has been called; incremented before each native call. */
  t: number = 0;

  /**
   * @param parameters Parameters to update.
   * @param options Optional hyperparameters. Defaults: `lr = 6e-4`,
   *   `beta1 = 0.9`, `beta2 = 0.95`, `eps = 1e-8`, `weightDecay = 0`.
   */
  constructor(
    parameters: Parameter<Tensor>[],
    { lr = 6e-4, beta1 = 0.9, beta2 = 0.95, eps = 1e-8, weightDecay = 0 } = {}
  ) {
    super(parameters);
    this.weightDecay = weightDecay;
    this.eps = eps;
    this.beta2 = beta2;
    this.beta1 = beta1;
    this.lr = lr;
  }

  /** Passes the `_id` of every managed parameter value to `native.zeroGrad`. */
  zeroGrad() {
    const tensorIds = this.parameters.map(({ value }) => value._id);
    native.zeroGrad(tensorIds);
  }

  /**
   * Advances the step counter and runs one native optimizer step.
   *
   * @param maxGradNorm Forwarded unchanged to `native.clipAndStep`; defaults
   *   to 0. NOTE(review): presumably 0 means "no clipping" — confirm against
   *   the native implementation.
   * @returns The number returned by `native.clipAndStep`.
   *   NOTE(review): `GradScaler.unscaleAndStep` reports this value as
   *   `gradNorm`, so it is presumably the gradient norm — confirm.
   */
  step(maxGradNorm: number = 0): number {
    this.t += 1;
    const tensorIds = this.parameters.map(({ value }) => value._id);
    const { lr, beta1, beta2, eps, weightDecay, t } = this;
    return native.clipAndStep(tensorIds, lr, beta1, beta2, eps, weightDecay, t, maxGradNorm);
  }
}
/**
 * Dynamic loss scaler.
 *
 * The loss is multiplied by a scale factor before the backward pass; gradients
 * are multiplied by the inverse before the optimizer step. The scale is
 * multiplied by `backoffFactor` whenever `native.scaleGrads` reports a problem,
 * and by `growthFactor` after `growthInterval` consecutive successful steps.
 */
export class GradScaler {
  private scale: number;
  private growthFactor: number;
  private backoffFactor: number;
  private growthInterval: number;
  /** Successful steps since the scale last changed. */
  private stepsSinceGrowth: number = 0;

  /**
   * @param options Optional settings. Defaults: `initScale = 65536.0`,
   *   `growthFactor = 2.0`, `backoffFactor = 0.5`, `growthInterval = 2000`.
   */
  constructor({
    initScale = 65536.0,
    growthFactor = 2.0,
    backoffFactor = 0.5,
    growthInterval = 2000,
  } = {}) {
    this.growthInterval = growthInterval;
    this.backoffFactor = backoffFactor;
    this.growthFactor = growthFactor;
    this.scale = initScale;
  }

  /** @returns The current scale factor. */
  getScale(): number {
    return this.scale;
  }

  /**
   * @param loss Loss tensor to scale.
   * @returns `loss.mul(scale)` using the current scale factor.
   */
  scaleLoss(loss: Tensor): Tensor {
    const factor = this.scale;
    return loss.mul(factor);
  }

  /**
   * Multiplies the gradients of the optimizer's parameters by `1 / scale` via
   * `native.scaleGrads`, then either steps the optimizer or skips the step.
   *
   * - If `native.scaleGrads` returns a truthy value, the scale is multiplied
   *   by `backoffFactor`, the growth counter is reset, the optimizer's
   *   gradients are zeroed and the step is skipped.
   * - Otherwise the optimizer steps; once `growthInterval` successful steps
   *   have accumulated the scale is multiplied by `growthFactor` and the
   *   counter is reset.
   *
   * @param optimizer Optimizer whose parameters are unscaled and stepped.
   * @param maxGradNorm Forwarded to `optimizer.step`; defaults to 0.
   * @returns `gradNorm` (the value returned by `optimizer.step`, or 0 when
   *   skipped) and whether the step was skipped.
   */
  unscaleAndStep(optimizer: Adam, maxGradNorm: number = 0): { gradNorm: number; skipped: boolean } {
    const tensorIds = optimizer.parameters.map(({ value }) => value._id);
    const overflowed = native.scaleGrads(tensorIds, 1.0 / this.scale);

    if (!overflowed) {
      const gradNorm = optimizer.step(maxGradNorm);
      this.stepsSinceGrowth += 1;
      if (this.stepsSinceGrowth >= this.growthInterval) {
        // Enough clean steps in a row: raise the scale and restart the count.
        this.scale *= this.growthFactor;
        this.stepsSinceGrowth = 0;
      }
      return { gradNorm, skipped: false };
    }

    // Problem reported by native.scaleGrads: back off, discard gradients, skip.
    this.scale *= this.backoffFactor;
    this.stepsSinceGrowth = 0;
    optimizer.zeroGrad();
    return { gradNorm: 0, skipped: true };
  }
}