@@ -141,6 +141,11 @@ def test_JitFPHomoLinear(self, prob, weight, shape):
141
141
x = bm .random .random (shape + (100 ,))
142
142
y = f (x )
143
143
self .assertTrue (y .shape == shape + (200 ,))
144
+
145
+ conn_matrix = f .get_conn_matrix ()
146
+ self .assertTrue (bm .allclose (y , x @ conn_matrix .T ))
147
+ # print(conn_matrix.shape)
148
+ # self.assertTrue(conn_matrix.shape == (200, 100))
144
149
bm .clear_buffer_memory ()
145
150
146
151
@parameterized .product (
@@ -155,6 +160,9 @@ def test_JitFPUniformLinear(self, prob, w_low, w_high, shape):
155
160
x = bm .random .random (shape + (100 ,))
156
161
y = f (x )
157
162
self .assertTrue (y .shape == shape + (200 ,))
163
+
164
+ conn_matrix = f .get_conn_matrix ()
165
+ self .assertTrue (bm .allclose (y , x @ conn_matrix .T ))
158
166
bm .clear_buffer_memory ()
159
167
160
168
@parameterized .product (
@@ -169,6 +177,9 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
169
177
x = bm .random .random (shape + (100 ,))
170
178
y = f (x )
171
179
self .assertTrue (y .shape == shape + (200 ,))
180
+
181
+ conn_matrix = f .get_conn_matrix ()
182
+ self .assertTrue (bm .allclose (y , x @ conn_matrix .T ))
172
183
bm .clear_buffer_memory ()
173
184
174
185
@parameterized .product (
@@ -179,11 +190,15 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
179
190
def test_EventJitFPHomoLinear (self , prob , weight , shape ):
180
191
bm .random .seed ()
181
192
f = bp .dnn .EventJitFPHomoLinear (100 , 200 , prob , weight , seed = 123 )
182
- y = f (bm .random .random (shape + (100 ,)) < 0.1 )
193
+ x = bm .random .random (shape + (100 ,)) < 0.1
194
+ y = f (x )
183
195
self .assertTrue (y .shape == shape + (200 ,))
184
196
185
197
y2 = f (bm .as_jax (bm .random .random (shape + (100 ,)) < 0.1 , dtype = float ))
186
198
self .assertTrue (y2 .shape == shape + (200 ,))
199
+
200
+ conn_matrix = f .get_conn_matrix ()
201
+ self .assertTrue (bm .allclose (y , x @ conn_matrix .T ))
187
202
bm .clear_buffer_memory ()
188
203
189
204
@parameterized .product (
@@ -195,11 +210,15 @@ def test_EventJitFPHomoLinear(self, prob, weight, shape):
195
210
def test_EventJitFPUniformLinear (self , prob , w_low , w_high , shape ):
196
211
bm .random .seed ()
197
212
f = bp .dnn .EventJitFPUniformLinear (100 , 200 , prob , w_low , w_high , seed = 123 )
198
- y = f (bm .random .random (shape + (100 ,)) < 0.1 )
213
+ x = bm .random .random (shape + (100 ,)) < 0.1
214
+ y = f (x )
199
215
self .assertTrue (y .shape == shape + (200 ,))
200
216
201
217
y2 = f (bm .as_jax (bm .random .random (shape + (100 ,)) < 0.1 , dtype = float ))
202
218
self .assertTrue (y2 .shape == shape + (200 ,))
219
+
220
+ conn_matrix = f .get_conn_matrix ()
221
+ self .assertTrue (bm .allclose (y , x @ conn_matrix .T ))
203
222
bm .clear_buffer_memory ()
204
223
205
224
@parameterized .product (
@@ -211,11 +230,15 @@ def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape):
211
230
def test_EventJitFPNormalLinear (self , prob , w_mu , w_sigma , shape ):
212
231
bm .random .seed ()
213
232
f = bp .dnn .EventJitFPNormalLinear (100 , 200 , prob , w_mu , w_sigma , seed = 123 )
214
- y = f (bm .random .random (shape + (100 ,)) < 0.1 )
233
+ x = bm .random .random (shape + (100 ,)) < 0.1
234
+ y = f (x )
215
235
self .assertTrue (y .shape == shape + (200 ,))
216
236
217
237
y2 = f (bm .as_jax (bm .random .random (shape + (100 ,)) < 0.1 , dtype = float ))
218
238
self .assertTrue (y2 .shape == shape + (200 ,))
239
+
240
+ conn_matrix = f .get_conn_matrix ()
241
+ self .assertTrue (bm .allclose (y , x @ conn_matrix .T ))
219
242
bm .clear_buffer_memory ()
220
243
221
244
0 commit comments