forked from cgpotts/cs224u
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtorch_model_base.py
55 lines (49 loc) · 1.7 KB
/
torch_model_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import torch.nn as nn
# Module attribution metadata for the CS224u course release.
__author__ = "Christopher Potts"
__version__ = "CS224u, Stanford, Spring 2019"
class TorchModelBase(object):
    """Shared base for the course's PyTorch models.

    Stores common hyperparameters as attributes, exposes sklearn-style
    `get_params`/`set_params` so instances work with sklearn tooling
    (e.g. cross-validation), and provides a readable `__repr__`.
    """

    def __init__(self,
            hidden_dim=50,
            hidden_activation=nn.Tanh(),
            batch_size=1028,
            max_iter=100,
            eta=0.01,
            optimizer=torch.optim.Adam,
            l2_strength=0,
            device=None):
        """Record hyperparameters and pick a compute device.

        Parameters
        ----------
        hidden_dim : int
            Width of the hidden layer.
        hidden_activation : nn.Module
            Activation applied to the hidden layer.
            NOTE: the default instance is shared across models created
            with the default; `nn.Tanh` is stateless, so this is safe.
        batch_size : int
            Minibatch size used during training.
        max_iter : int
            Maximum number of training epochs.
        eta : float
            Learning rate.
        optimizer : torch.optim.Optimizer class
            Optimizer class (not instance) to use for training.
        l2_strength : float
            L2 regularization strength; 0 disables it.
        device : str or None
            If None, uses "cuda" when available, else "cpu".
        """
        self.hidden_dim = hidden_dim
        self.hidden_activation = hidden_activation
        self.batch_size = batch_size
        self.max_iter = max_iter
        self.eta = eta
        self.optimizer = optimizer
        self.l2_strength = l2_strength
        # Fall back to an automatic device choice when none is given:
        chosen = device if device is not None else (
            "cuda" if torch.cuda.is_available() else "cpu")
        self.device = torch.device(chosen)
        # Names exposed through `get_params` and shown by `__repr__`
        # (`device` is deliberately excluded):
        self.params = [
            'hidden_dim',
            'hidden_activation',
            'batch_size',
            'max_iter',
            'eta',
            'optimizer',
            'l2_strength']
        # Per-epoch training losses, filled in by subclasses:
        self.errors = []
        # Dev-set predictions keyed by iteration, filled in by subclasses:
        self.dev_predictions = {}

    def get_params(self, deep=True):
        """Return a dict of hyperparameter names to current values.

        `deep` is accepted only for sklearn API compatibility.
        """
        names = list(self.params)
        # Obligatorily add `vocab` so that sklearn passes it in when
        # creating new model instances during cross-validation:
        if hasattr(self, 'vocab'):
            names.append('vocab')
        return {name: getattr(self, name) for name in names}

    def set_params(self, **params):
        """Set attributes from keyword arguments; returns `self`
        (sklearn convention)."""
        for name in params:
            setattr(self, name, params[name])
        return self

    def __repr__(self):
        # One "name=value" entry per tracked hyperparameter, each on
        # its own indented line:
        assignments = ",\n\t".join(
            "{}={}".format(name, getattr(self, name))
            for name in self.params)
        return "{}(\n\t{})".format(self.__class__.__name__, assignments)