-
Notifications
You must be signed in to change notification settings - Fork 5
/
net.py
46 lines (33 loc) · 1.21 KB
/
net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch, numpy as np
from torch.nn import *
from config import config
class Net(Module):
    """Base network whose settings come from the project-wide ``config``.

    Stores the target device, learning rate and ``alpha_h`` taken from
    ``config``. Provides checkpoint save/load, soft (blended) parameter
    copying, learning-rate updates and no-op state hooks.

    Subclasses define the actual layers. Before calling :meth:`set_lr`, a
    subclass must also assign an optimizer to ``self.opt``; it is not
    created here.
    """

    def __init__(self):
        super().__init__()
        self.device = torch.device(config.device)
        self.lr = config.opt_lr
        # Only stored here; how it is used is up to subclasses (not visible in this class).
        self.alpha_h = config.alpha_h

    def save(self, file='model.pt'):
        """Write this module's ``state_dict()`` to *file* with ``torch.save``."""
        torch.save(self.state_dict(), file)

    def load(self, file='model.pt'):
        """Load a ``state_dict`` from *file* into this module.

        Tensors are mapped onto ``self.device`` while loading.

        NOTE(review): ``torch.load`` unpickles the file, so only load
        checkpoints from trusted sources.
        """
        self.load_state_dict(torch.load(file, map_location=self.device))

    def copy_weights(self, other, rho):
        """Blend *other*'s parameters into this module's parameters in place.

        Every parameter becomes ``rho * other + (1 - rho) * self``.
        ``rho=1`` is a hard copy and ``rho=0`` leaves ``self`` unchanged.
        Only ``parameters()`` are blended; buffers are not touched.

        Args:
            other: module whose ``parameters()`` match ``self`` in count,
                order and shape.
            rho: weight given to *other*'s values.

        Raises:
            ValueError: if the parameter counts or any parameter shape
                differ. The check runs before anything is modified, so
                ``self`` is left untouched on failure.
        """
        params_self = list(self.parameters())
        params_other = list(other.parameters())
        if len(params_self) != len(params_other):
            raise ValueError(
                f"copy_weights: parameter count mismatch "
                f"({len(params_self)} in self vs {len(params_other)} in other)"
            )
        # Check every shape first, so a mismatch cannot leave self half-updated.
        for idx, (p_self, p_other) in enumerate(zip(params_self, params_other)):
            if p_self.shape != p_other.shape:
                raise ValueError(
                    f"copy_weights: shape mismatch at parameter {idx} "
                    f"({tuple(p_self.shape)} vs {tuple(p_other.shape)})"
                )
        # no_grad lets us update leaf parameters in place without recording
        # autograd history; this replaces the discouraged `.data` access.
        with torch.no_grad():
            for p_self, p_other in zip(params_self, params_other):
                p_self.copy_(rho * p_other + (1 - rho) * p_self)

    def set_lr(self, lr):
        """Set the learning rate on ``self.lr`` and on every group of ``self.opt``.

        ``self.lr`` is updated first. If no subclass has assigned
        ``self.opt``, an ``AttributeError`` is raised after that update.
        """
        self.lr = lr
        for param_group in self.opt.param_groups:
            param_group['lr'] = lr

    def set_alpha_h(self, alpha_h):
        """Replace the stored ``alpha_h`` value."""
        self.alpha_h = alpha_h

    def get_param_count(self):
        """Return the total number of elements across all parameters.

        Frozen parameters (``requires_grad=False``) are counted too.
        """
        return sum(p.numel() for p in self.parameters())

    def reset_state(self, batch_mask=None):
        """Hook for resetting internal state; a no-op in this base class.

        Presumably meant for stateful (e.g. recurrent) subclasses.
        *batch_mask* is ignored here; TODO confirm its intended meaning in
        subclasses.
        """
        pass

    def clone_state(self, other):
        """Hook for copying internal state from *other*; a no-op in this base class."""
        pass