Skip to content

Commit

Permalink
Update optim.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tiendung authored Oct 15, 2024
1 parent 5e7a3a6 commit 444af45
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion kim/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

class Optimizer:
def __init__(self, params):
# Ghi nhớ lại những params cần được update
self.params = [x for x in params if x.update_params]
print(">>> optim params", len(self.params))

Expand All @@ -25,14 +26,15 @@ def __init__(self, params, lr=0.01, momentum=0.0, weight_decay=0.0, device=None)
self.weight_decay = weight_decay
self.u = {}
for w in self.params:
# Với mỗi param (weight w) cần khởi tạo chỉ số phụ u bằng chính kích thước của w => x2 vram
self.u[w] = device.zeros(*w.shape)

def step(self):
for w in self.params:
grad = w.grad.cached_data + w.cached_data * self.weight_decay
self.u[w] = self.momentum*self.u[w] + (1 - self.momentum)*grad
w.cached_data -= self.lr * self.u[w]

# u của w được update sau mỗi step, và cần được lưu lại, nếu offload thì phải offload cả u

class Adam(Optimizer):
def __init__(
Expand Down

0 comments on commit 444af45

Please sign in to comment.