
Commit 77be09a

[math] Add get JIT weight matrix methods (Uniform & Normal) for `brainpy.dnn.linear` (#673)
* [math] Add get JIT connect matrix methods for `brainpy.dnn.linear`
* Update
* Update linear.py
* [math] Add get JIT weight matrix methods (Uniform & Normal) for `brainpy.dnn.linear`
* Update linear.py
* Update
* Update linear.py
* Add test for `get_conn_matrix` in the `brainpy.dnn.linear` module
* Update linear.py
* Update
* Fix bugs
1 parent 3a09b7b commit 77be09a

6 files changed (+523 lines, -73 lines)


brainpy/_src/dnn/linear.py

Lines changed: 32 additions & 11 deletions
@@ -980,15 +980,36 @@ def __init__(
       self.sharding = sharding
 
 
-class JitFPLinear(Layer):
+class JitLinear(Layer):
   def get_conn_matrix(self):
-    return bm.jitconn.get_conn_matrix(self.prob, self.seed,
-                                      shape=(self.num_out, self.num_in),
-                                      transpose=self.transpose,
-                                      outdim_parallel=not self.atomic)
+    pass
+
+
+class JitFPHomoLayer(JitLinear):
+  def get_conn_matrix(self):
+    return bm.jitconn.get_homo_weight_matrix(self.weight, self.prob, self.seed,
+                                              shape=(self.num_out, self.num_in),
+                                              transpose=self.transpose,
+                                              outdim_parallel=not self.atomic)
+
+
+class JitFPUniformLayer(JitLinear):
+  def get_conn_matrix(self):
+    return bm.jitconn.get_uniform_weight_matrix(self.w_low, self.w_high, self.prob, self.seed,
+                                                 shape=(self.num_out, self.num_in),
+                                                 transpose=self.transpose,
+                                                 outdim_parallel=not self.atomic)
+
+
+class JitFPNormalLayer(JitLinear):
+  def get_conn_matrix(self):
+    return bm.jitconn.get_normal_weight_matrix(self.w_mu, self.w_sigma, self.prob, self.seed,
+                                                shape=(self.num_out, self.num_in),
+                                                transpose=self.transpose,
+                                                outdim_parallel=not self.atomic)
 
 
-class JitFPHomoLinear(JitFPLinear):
+class JitFPHomoLinear(JitFPHomoLayer):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.
 
   It performs the computation of:

@@ -1067,7 +1088,7 @@ def _batch_mv(self, x):
                              outdim_parallel=not self.atomic)
 
 
-class JitFPUniformLinear(JitFPLinear):
+class JitFPUniformLinear(JitFPUniformLayer):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.
 
   It performs the computation of:

@@ -1147,7 +1168,7 @@ def _batch_mv(self, x):
                              outdim_parallel=not self.atomic)
 
 
-class JitFPNormalLinear(JitFPLinear):
+class JitFPNormalLinear(JitFPNormalLayer):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.
 
   It performs the computation of:

@@ -1227,7 +1248,7 @@ def _batch_mv(self, x):
                              outdim_parallel=not self.atomic)
 
 
-class EventJitFPHomoLinear(JitFPLinear):
+class EventJitFPHomoLinear(JitFPHomoLayer):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.
 
   It performs the computation of:

@@ -1306,7 +1327,7 @@ def _batch_mv(self, x):
                              outdim_parallel=not self.atomic)
 
 
-class EventJitFPUniformLinear(JitFPLinear):
+class EventJitFPUniformLinear(JitFPUniformLayer):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.
 
   It performs the computation of:

@@ -1386,7 +1407,7 @@ def _batch_mv(self, x):
                              outdim_parallel=not self.atomic)
 
 
-class EventJitFPNormalLinear(JitFPLinear):
+class EventJitFPNormalLinear(JitFPNormalLayer):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.
 
   It performs the computation of:
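For orientation, here is a minimal usage sketch of the new get_conn_matrix method (not part of the diff). The layer sizes, prob=0.1, weight=0.01, and the input shape are made-up example values, and the constructor signature is assumed to mirror the event-driven variant exercised in the tests below.

import brainpy as bp
import brainpy.math as bm

bm.random.seed()
# Hypothetical example: a homogeneous-weight JIT layer mapping 100 inputs to 200 outputs.
f = bp.dnn.JitFPHomoLinear(100, 200, 0.1, 0.01, seed=123)

x = bm.random.random((8, 100))
y = f(x)

# New in this commit: materialize the just-in-time connectivity as a dense weight matrix.
W = f.get_conn_matrix()

# The dense matrix reproduces the layer's output, as the updated tests assert.
assert bm.allclose(y, x @ W.T)

bm.clear_buffer_memory()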

brainpy/_src/dnn/tests/test_linear.py

Lines changed: 26 additions & 3 deletions
@@ -141,6 +141,11 @@ def test_JitFPHomoLinear(self, prob, weight, shape):
     x = bm.random.random(shape + (100,))
     y = f(x)
     self.assertTrue(y.shape == shape + (200,))
+
+    conn_matrix = f.get_conn_matrix()
+    self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
+    # print(conn_matrix.shape)
+    # self.assertTrue(conn_matrix.shape == (200, 100))
     bm.clear_buffer_memory()
 
   @parameterized.product(

@@ -155,6 +160,9 @@ def test_JitFPUniformLinear(self, prob, w_low, w_high, shape):
     x = bm.random.random(shape + (100,))
     y = f(x)
     self.assertTrue(y.shape == shape + (200,))
+
+    conn_matrix = f.get_conn_matrix()
+    self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
     bm.clear_buffer_memory()
 
   @parameterized.product(

@@ -169,6 +177,9 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
     x = bm.random.random(shape + (100,))
     y = f(x)
     self.assertTrue(y.shape == shape + (200,))
+
+    conn_matrix = f.get_conn_matrix()
+    self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
     bm.clear_buffer_memory()
 
   @parameterized.product(

@@ -179,11 +190,15 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
   def test_EventJitFPHomoLinear(self, prob, weight, shape):
     bm.random.seed()
     f = bp.dnn.EventJitFPHomoLinear(100, 200, prob, weight, seed=123)
-    y = f(bm.random.random(shape + (100,)) < 0.1)
+    x = bm.random.random(shape + (100,)) < 0.1
+    y = f(x)
     self.assertTrue(y.shape == shape + (200,))
 
     y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float))
     self.assertTrue(y2.shape == shape + (200,))
+
+    conn_matrix = f.get_conn_matrix()
+    self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
     bm.clear_buffer_memory()
 
   @parameterized.product(

@@ -195,11 +210,15 @@ def test_EventJitFPHomoLinear(self, prob, weight, shape):
   def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape):
     bm.random.seed()
     f = bp.dnn.EventJitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123)
-    y = f(bm.random.random(shape + (100,)) < 0.1)
+    x = bm.random.random(shape + (100,)) < 0.1
+    y = f(x)
     self.assertTrue(y.shape == shape + (200,))
 
     y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float))
     self.assertTrue(y2.shape == shape + (200,))
+
+    conn_matrix = f.get_conn_matrix()
+    self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
     bm.clear_buffer_memory()
 
   @parameterized.product(

@@ -211,11 +230,15 @@ def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape):
   def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
     bm.random.seed()
     f = bp.dnn.EventJitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123)
-    y = f(bm.random.random(shape + (100,)) < 0.1)
+    x = bm.random.random(shape + (100,)) < 0.1
+    y = f(x)
     self.assertTrue(y.shape == shape + (200,))
 
     y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float))
     self.assertTrue(y2.shape == shape + (200,))
+
+    conn_matrix = f.get_conn_matrix()
+    self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
     bm.clear_buffer_memory()
 
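The event-driven layers expose the same method. The sketch below follows the tests above (the layer sizes, prob, weight, and input shape are hypothetical example values): feed a boolean spike vector, then check the output against the dense matrix returned by get_conn_matrix.

import brainpy as bp
import brainpy.math as bm

bm.random.seed()
f = bp.dnn.EventJitFPHomoLinear(100, 200, 0.1, 0.01, seed=123)

# Boolean "spike" input, as in the updated tests.
x = bm.random.random((8, 100)) < 0.1
y = f(x)

# The dense connectivity matrix reproduces the event-driven product as well.
W = f.get_conn_matrix()
assert bm.allclose(y, x @ W.T)

bm.clear_buffer_memory()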
