
Commit a57ac95

Add files via upload
1 parent fc4593a commit a57ac95

File tree

1 file changed: +106 −0 lines changed


rbm.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class RBM(nn.Module):

    def __init__(self, vis_dim, hid_dim, k, learning_rate=0.1, use_cuda=True):

        super().__init__()

        # small random weights and zero biases
        self.W = nn.Parameter(torch.randn(vis_dim, hid_dim) * 0.01)
        self.v_bias = nn.Parameter(torch.zeros(vis_dim))
        self.h_bias = nn.Parameter(torch.zeros(hid_dim))

        self.k = k  # number of Gibbs steps per CD-k update
        self.learning_rate = learning_rate
        self.use_cuda = use_cuda

        self.optimizer = optim.SGD(self.parameters(), lr=learning_rate)

        if torch.cuda.is_available() and self.use_cuda:
            self.cuda()
    def sample_h_given_v(self, v_s):
        # p(h = 1 | v); F.linear(x, A, b) computes x A^T + b, so this is v W + h_bias
        h_p = torch.sigmoid(F.linear(v_s, self.W.t(), self.h_bias))
        h_s = torch.bernoulli(h_p)
        return [h_p, h_s]

    def sample_v_given_h(self, h_s):
        # p(v = 1 | h) = sigmoid(h W^T + v_bias)
        v_p = torch.sigmoid(F.linear(h_s, self.W, self.v_bias))
        v_s = torch.bernoulli(v_p)
        return [v_p, v_s]
    def gibbs_hvh(self, h_s):
        # one Gibbs step starting from the hidden units: h -> v -> h
        v_p, v_s = self.sample_v_given_h(h_s)
        h_p, h_s = self.sample_h_given_v(v_s)
        return [v_p, v_s, h_p, h_s]

    def gibbs_vhv(self, v_s):
        # one Gibbs step starting from the visible units: v -> h -> v
        h_p, h_s = self.sample_h_given_v(v_s)
        v_p, v_s = self.sample_v_given_h(h_s)
        return [h_p, h_s, v_p, v_s]
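    # Closed-form free energy of a binary RBM:
    #   F(v) = -v . v_bias - sum_j log(1 + exp((v W + h_bias)_j))
    # Differentiating the mean free-energy gap between data and model samples
    # (see fit below) reproduces the contrastive divergence update.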
    def free_energy(self, v):
        v_bias_term = torch.mv(v, self.v_bias)
        wx_b = F.linear(v, self.W.t(), self.h_bias)
        # softplus(x) = log(1 + exp(x)), the numerically stable form of the hidden term
        hidden_term = torch.sum(F.softplus(wx_b), dim=1)
        return -v_bias_term - hidden_term
    def fit(self, x):

        if torch.cuda.is_available() and self.use_cuda:
            x = x.cuda()
        v_s = x

        # calculate positive part :: 'p' stands for positive
        ph_p, ph_s = self.sample_h_given_v(v_s)

        # calculate negative part :: 'n' stands for negative
        nv_p, nv_s, nh_p, nh_s = None, None, None, ph_s
        for _ in range(self.k):
            nv_p, nv_s, nh_p, nh_s = self.gibbs_hvh(nh_s)

        # CD-k surrogate loss: the negative samples are treated as constants
        nv_s = nv_s.detach()
        cost = torch.mean(self.free_energy(v_s)) - torch.mean(self.free_energy(nv_s))

        # calculate gradient & update parameters
        self.optimizer.zero_grad()
        cost.backward()
        self.optimizer.step()

        # calculate cross entropy, for monitoring only
        loss = self.cal_cross_entropy(v_s, nv_p)

        return cost.item(), loss.item()
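    # Cross-entropy between the data and its reconstruction probabilities;
    # it tracks reconstruction quality but is not the quantity optimized above.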
    @staticmethod
    def cal_cross_entropy(p, p_):
        p_ = p_.clamp(1e-7, 1 - 1e-7)  # keep log() finite when p_ saturates at 0 or 1
        return -torch.mean(torch.sum(p * torch.log(p_) + (1 - p) * torch.log(1 - p_), dim=1))
    def reconstruct(self, x):
        # one full v -> h -> v pass; returns the reconstructed visible units

        if torch.cuda.is_available() and self.use_cuda:
            x = x.cuda()
        v_s = x

        h_p, h_s = self.sample_h_given_v(v_s)
        v_p, v_s = self.sample_v_given_h(h_s)

        return v_s
    def compress(self, x):
        # hidden activation probabilities serve as the compressed code for x

        if torch.cuda.is_available() and self.use_cuda:
            x = x.cuda()
        v = x

        h_p, _ = self.sample_h_given_v(v)

        return h_p
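
A minimal usage sketch, not part of the commit: it assumes binarized inputs of shape (batch, vis_dim); the sizes (784 visible, 128 hidden) and the random stand-in batches are placeholders for a real data loader.

import torch
from rbm import RBM  # the file added in this commit

rbm = RBM(vis_dim=784, hid_dim=128, k=1, learning_rate=0.1, use_cuda=False)

for epoch in range(5):
    for _ in range(100):                     # stand-in for a real data loader
        x = torch.rand(64, 784).bernoulli()  # fake binary batch (batch, vis_dim)
        cost, xent = rbm.fit(x)
    print(f"epoch {epoch}: free-energy gap {cost:.4f}, cross-entropy {xent:.4f}")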
