Skip to content

Commit d3ced3d

Browse files
add batch units
1 parent 6b3a6d7 commit d3ced3d

File tree

1 file changed

+218
-0
lines changed

1 file changed

+218
-0
lines changed

rnn_class/batch_units.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
# https://deeplearningcourses.com/c/deep-learning-recurrent-neural-networks-in-python
2+
# https://udemy.com/deep-learning-recurrent-neural-networks-in-python
3+
import numpy as np
4+
import theano
5+
import theano.tensor as T
6+
7+
def init_weight(Mi, Mo):
    """Return an (Mi, Mo) weight matrix with He-style initialization.

    Draws standard-normal entries and rescales them by sqrt(2 / Mi) so the
    variance of activations is roughly preserved across layers.
    """
    scale = np.sqrt(2.0 / Mi)
    return scale * np.random.randn(Mi, Mo)
9+
10+
11+
class SimpleRecurrentLayer:
    """Elman-style recurrent layer that runs over a flattened batch.

    All sequences are concatenated into one (NT, D) matrix; a parallel
    `startPoints` indicator (1 at each sequence start) tells the recurrence
    when to reset the hidden state to the learned initial state h0.
    """

    def __init__(self, Mi, Mo, activation):
        # Mi: input dimensionality, Mo: hidden-state dimensionality.
        self.Mi = Mi
        self.Mo = Mo
        self.f = activation

        # Initialize parameters in numpy, wrap directly as theano shared vars.
        self.Wxh = theano.shared(init_weight(Mi, Mo))
        self.Whh = theano.shared(init_weight(Mo, Mo))
        self.b = theano.shared(np.zeros(Mo))
        self.h0 = theano.shared(np.zeros(Mo))
        self.params = [self.Wxh, self.Whh, self.b, self.h0]

    def get_ht(self, xWxh_t, h_t1):
        # One recurrence step: f(x_t.Wxh + h_{t-1}.Whh + b).
        # The input projection x_t.Wxh is precomputed for the whole
        # sequence in output(), so only the recurrent term is done here.
        return self.f(xWxh_t + h_t1.dot(self.Whh) + self.b)

    def recurrence(self, xWxh_t, is_start, h_t1, h0):
        # At a sequence boundary (is_start == 1) restart from the learned
        # initial state h0 rather than the previous timestep's state.
        return T.switch(
            T.eq(is_start, 1),
            self.get_ht(xWxh_t, h0),
            self.get_ht(xWxh_t, h_t1),
        )

    def output(self, Xflat, startPoints):
        # Xflat: (NT, D) — all sequences concatenated along time.
        # startPoints: length-NT indicator of sequence starts.
        # Hoist the input-to-hidden projection out of the scan loop.
        XWxh = Xflat.dot(self.Wxh)

        h, _ = theano.scan(
            fn=self.recurrence,
            sequences=[XWxh, startPoints],
            outputs_info=[self.h0],
            non_sequences=[self.h0],
            n_steps=Xflat.shape[0],
        )
        return h
56+
57+
58+
class GRU:
    """Gated recurrent unit layer operating on a flattened batch.

    Same batching scheme as SimpleRecurrentLayer: all sequences are
    concatenated into one (NT, D) matrix and `startPoints` marks where each
    sequence begins so the hidden state can be reset to h0.
    """

    def __init__(self, Mi, Mo, activation):
        self.Mi = Mi
        self.Mo = Mo
        self.f = activation

        # Reset gate (r), update gate (z), and candidate (hhat) parameters:
        # each has an input weight, a recurrent weight, and a bias.
        self.Wxr = theano.shared(init_weight(Mi, Mo))
        self.Whr = theano.shared(init_weight(Mo, Mo))
        self.br = theano.shared(np.zeros(Mo))
        self.Wxz = theano.shared(init_weight(Mi, Mo))
        self.Whz = theano.shared(init_weight(Mo, Mo))
        self.bz = theano.shared(np.zeros(Mo))
        self.Wxh = theano.shared(init_weight(Mi, Mo))
        self.Whh = theano.shared(init_weight(Mo, Mo))
        self.bh = theano.shared(np.zeros(Mo))
        self.h0 = theano.shared(np.zeros(Mo))
        self.params = [
            self.Wxr, self.Whr, self.br,
            self.Wxz, self.Whz, self.bz,
            self.Wxh, self.Whh, self.bh,
            self.h0,
        ]

    def get_ht(self, xWxr_t, xWxz_t, xWxh_t, h_t1):
        # Standard GRU step; the x.W projections arrive precomputed.
        r = T.nnet.sigmoid(xWxr_t + h_t1.dot(self.Whr) + self.br)
        z = T.nnet.sigmoid(xWxz_t + h_t1.dot(self.Whz) + self.bz)
        hhat = self.f(xWxh_t + (r * h_t1).dot(self.Whh) + self.bh)
        # Convention here: z gates the candidate, (1 - z) keeps the old state.
        return (1 - z) * h_t1 + z * hhat

    def recurrence(self, xWxr_t, xWxz_t, xWxh_t, is_start, h_t1, h0):
        # Reset to the learned initial state at each sequence boundary.
        return T.switch(
            T.eq(is_start, 1),
            self.get_ht(xWxr_t, xWxz_t, xWxh_t, h0),
            self.get_ht(xWxr_t, xWxz_t, xWxh_t, h_t1),
        )

    def output(self, Xflat, startPoints):
        # Xflat: (NT, D). Precompute all three input projections once,
        # outside the scan loop.
        XWxr = Xflat.dot(self.Wxr)
        XWxz = Xflat.dot(self.Wxz)
        XWxh = Xflat.dot(self.Wxh)

        h, _ = theano.scan(
            fn=self.recurrence,
            sequences=[XWxr, XWxz, XWxh, startPoints],
            outputs_info=[self.h0],
            non_sequences=[self.h0],
            n_steps=Xflat.shape[0],
        )
        return h
121+
122+
123+
124+
class LSTM:
    """LSTM layer (with peephole connections) over a flattened batch.

    Uses the same batching scheme as the other units: sequences are
    concatenated into one (NT, D) matrix and `startPoints` marks sequence
    starts, at which both the hidden state h and cell state c reset to
    their learned initial values.
    """

    def __init__(self, Mi, Mo, activation):
        self.Mi = Mi
        self.Mo = Mo
        self.f = activation

        # Input (i), forget (f), and output (o) gates each have input,
        # recurrent, and peephole (cell) weights plus a bias; the candidate
        # cell (c) has no peephole weight.
        self.Wxi = theano.shared(init_weight(Mi, Mo))
        self.Whi = theano.shared(init_weight(Mo, Mo))
        self.Wci = theano.shared(init_weight(Mo, Mo))
        self.bi = theano.shared(np.zeros(Mo))
        self.Wxf = theano.shared(init_weight(Mi, Mo))
        self.Whf = theano.shared(init_weight(Mo, Mo))
        self.Wcf = theano.shared(init_weight(Mo, Mo))
        self.bf = theano.shared(np.zeros(Mo))
        self.Wxc = theano.shared(init_weight(Mi, Mo))
        self.Whc = theano.shared(init_weight(Mo, Mo))
        self.bc = theano.shared(np.zeros(Mo))
        self.Wxo = theano.shared(init_weight(Mi, Mo))
        self.Who = theano.shared(init_weight(Mo, Mo))
        self.Wco = theano.shared(init_weight(Mo, Mo))
        self.bo = theano.shared(np.zeros(Mo))
        self.c0 = theano.shared(np.zeros(Mo))
        self.h0 = theano.shared(np.zeros(Mo))
        self.params = [
            self.Wxi, self.Whi, self.Wci, self.bi,
            self.Wxf, self.Whf, self.Wcf, self.bf,
            self.Wxc, self.Whc, self.bc,
            self.Wxo, self.Who, self.Wco, self.bo,
            self.c0, self.h0,
        ]

    def get_ht_ct(self, xWxi_t, xWxf_t, xWxc_t, xWxo_t, h_t1, c_t1):
        # One LSTM step. Peephole terms: i and f peek at c_{t-1}, while the
        # output gate peeks at the freshly computed c_t.
        i_t = T.nnet.sigmoid(xWxi_t + h_t1.dot(self.Whi) + c_t1.dot(self.Wci) + self.bi)
        f_t = T.nnet.sigmoid(xWxf_t + h_t1.dot(self.Whf) + c_t1.dot(self.Wcf) + self.bf)
        c_t = f_t * c_t1 + i_t * T.tanh(xWxc_t + h_t1.dot(self.Whc) + self.bc)
        o_t = T.nnet.sigmoid(xWxo_t + h_t1.dot(self.Who) + c_t.dot(self.Wco) + self.bo)
        h_t = o_t * T.tanh(c_t)
        return h_t, c_t

    def recurrence(self, xWxi_t, xWxf_t, xWxc_t, xWxo_t, is_start, h_t1, c_t1, h0, c0):
        # At a sequence boundary, restart from (h0, c0). T.switch coerces the
        # (h, c) tuples into a single stacked tensor, which is then unpacked.
        h_t_c_t = T.switch(
            T.eq(is_start, 1),
            self.get_ht_ct(xWxi_t, xWxf_t, xWxc_t, xWxo_t, h0, c0),
            self.get_ht_ct(xWxi_t, xWxf_t, xWxc_t, xWxo_t, h_t1, c_t1),
        )
        return h_t_c_t[0], h_t_c_t[1]

    def output(self, Xflat, startPoints):
        # Xflat: (NT, D). Precompute all four input projections once so the
        # scan loop only does the recurrent work.
        XWxi = Xflat.dot(self.Wxi)
        XWxf = Xflat.dot(self.Wxf)
        XWxc = Xflat.dot(self.Wxc)
        XWxo = Xflat.dot(self.Wxo)

        [h, c], _ = theano.scan(
            fn=self.recurrence,
            sequences=[XWxi, XWxf, XWxc, XWxo, startPoints],
            outputs_info=[self.h0, self.c0],
            non_sequences=[self.h0, self.c0],
            n_steps=Xflat.shape[0],
        )
        return h

0 commit comments

Comments
 (0)