Skip to content

Commit 3513937

Browse files
committed
Fix rand function (instead of generating integral and fractional, generate a number in [0, MAX_INT] and multiply by max); adjust tests to match
1 parent 81e7ec1 commit 3513937

File tree

3 files changed

+47
-29
lines changed

3 files changed

+47
-29
lines changed

mnist-clojure.iml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
<exclude-output />
77
<content url="file://$MODULE_DIR$">
88
<sourceFolder url="file://$MODULE_DIR$/resources" isTestSource="false" />
9-
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
109
<sourceFolder url="file://$MODULE_DIR$/dev-resources" isTestSource="false" />
10+
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
1111
<sourceFolder url="file://$MODULE_DIR$/test" isTestSource="true" />
1212
<excludeFolder url="file://$MODULE_DIR$/target/default" />
1313
</content>

src/mnist_clojure/core.clj

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@
3030
(+ 256 byte)
3131
byte))
3232

33-
(defn read-mnist-file [file-name]
33+
(defn read-mnist-file
3434
"Reads the idx[x]-ubyte format and parses it into a byte array"
35+
[file-name]
3536
(let [file (io/file file-name)
3637
b-array (byte-array (.length file))]
3738
(with-open [stream (io/input-stream file)]
@@ -104,8 +105,9 @@
104105
"Creates a list of specified length with random decimal values between 0 to max"
105106
{:test (fn []
106107
(is= (rand-d-list 4 1 0.4)
107-
[1961823115700386051
108-
'(0.1430590959000676 0.8178222714074991 0.5044600700514671 0.01660157088963388)]))}
108+
[-9136436700791295257
109+
[0.20178402802058684 0.02002609726974093
110+
0.006640628355853552 1.8626451500983188E-10]]))}
109111
[length seed max]
110112
(reduce (fn [[seed rand-vals] _]
111113
(let [[next-seed rand-val] (rand seed max)]
@@ -119,7 +121,10 @@
119121
(is= (rand-list-max-mag 4 2 0)
120122
[-9197343212719499864 `(0 0 0 0)])
121123
(is= (rand-list-max-mag 3 4 23)
122-
[-7068052242903947273 `(-6.004336956424794 13.63314510538855 -18.93359371783845)]))}
124+
[77986490623247743
125+
[-22.18838978334721
126+
-19.945311020568624
127+
-22.999999914318323]]))}
123128
[length seed max]
124129
(let [[new-seed rand-list] (rand-d-list length seed (* 2 max))]
125130
[new-seed (map (partial + (- max)) rand-list)]))
@@ -128,11 +133,18 @@
128133
"Initializes one layer of the neural network. Note the + 1 adjusts for the bias"
129134
; y = Wx
130135
{:test (fn []
131-
(is= (initialize-layer 1 3 1 0) [-8728512804673154413 `((0 0) (0 0) (0 0))])
136+
(is= (initialize-layer 1 3 1 0)
137+
[-8728512804673154413 [[0 0] [0 0] [0 0]]])
132138
(is= (initialize-layer 3 2 24 1)
133-
[7786394753034826687
134-
'((-0.32256568051994106 0.5576858765248609 0.5806762038640101 -0.8371181152002505)
135-
(0.6701701863995613 0.3397951933274954 -0.48321265703216776 -0.6015623093589966))]))}
139+
[2665986749794895764
140+
[[0.34034037279912277
141+
-0.24479991162419323
142+
-0.32040961334500895
143+
-0.38945634541542096]
144+
[0.03357468593566448
145+
-0.9542951490517217
146+
-0.20312461871799303
147+
-0.9999999776482582]]]))}
136148
[input-count output-count seed max]
137149
(reduce (fn [[seed rand-lists] _]
138150
(let [[new-seed rand-list] (rand-list-max-mag (+ input-count 1) seed max)]
@@ -246,6 +258,6 @@
246258
"resources/train-labels.idx1-ubyte" 0.2 10 1))))
247259

248260
(comment (time (train-and-test "resources/train-images.idx3-ubyte"
249-
"resources/train-labels.idx1-ubyte"
250-
"resources/test-images.idx3-ubyte"
251-
"resources/test-labels.idx1-ubyte")))
261+
"resources/train-labels.idx1-ubyte"
262+
"resources/test-images.idx3-ubyte"
263+
"resources/test-labels.idx1-ubyte")))

src/mnist_clojure/random.clj

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,40 @@
33
[ysera.test :refer [is is= is-not]])
44
(:refer-clojure :exclude [rand rand-int]))
55

6+
(defn get-new-seed
7+
"Returns a new seed"
8+
{:test (fn []
9+
(is= (get-new-seed 0) 1)
10+
(is= (get-new-seed 1) 35651602))}
11+
[seed]
12+
(let [[new-seed _] (get-random-int seed 1)]
13+
new-seed))
14+
615
(defn rand-int
716
"Returns [new-seed random-int] such that 0 <= rand-int < max."
817
{:test (fn []
18+
(is= (rand-int 1 1) [35651602 0])
19+
(is= (rand-int 1 0.1) [35651602 0])
920
(is= (rand-int 1 3) [35651602 1])
10-
(is= (rand-int 3158 13984) [108024351031 3158])
1121
; rand-int always returns 0 without errors no matter the seed.
1222
(is= (rand-int 3 0) [106954804 0])
13-
(is= (rand-int 1513123 0) [49270036527376 0]))}
23+
(is= (rand-int -3 0) [94371886 0])
24+
)}
1425
[seed max]
1526
(if (< max 1)
16-
(let [[new-seed _] (get-random-int seed 1)]
17-
[new-seed 0])
27+
[(get-new-seed seed) 0]
1828
(get-random-int seed max)))
1929

2030
(defn rand
2131
"Returns [new-seed rand] such that 0 <= rand < max."
2232
{:test (fn []
23-
(is= (rand 3 4) [3390894109721333 3.049804711737579])
24-
(is= (rand 1023193 9) [5555018649264297881 1.3268604484977482])
25-
(is= (rand 1319 0) [46890243928 0])
26-
(is= (rand 19220393 0) [684066761437437 0]))}
33+
(is= (rand 1 0) [35651602 0])
34+
(is= (rand 1 1) [35651602 4.656612875245797E-10])
35+
(is= (rand 1 0.5) [35651602 2.3283064376228985E-10])
36+
(is= (rand 1 0.1) [35651602 4.656612875245797E-11]))}
2737
[seed max]
28-
(if (= max 0)
29-
(rand-int seed max)
30-
(let [[new-seed random-int] (rand-int seed max)
31-
[nn-seed new-random-int] (rand-int new-seed Integer/MAX_VALUE)]
32-
(if (= new-random-int 0)
33-
[nn-seed random-int]
34-
[nn-seed (+ random-int (double (/ new-random-int Integer/MAX_VALUE)))]))))
35-
36-
38+
(if (<= max 0)
39+
[(get-new-seed seed) 0]
40+
(let [[seed n] (rand-int seed Integer/MAX_VALUE)
41+
weight (double (/ n Integer/MAX_VALUE))]
42+
[seed (* max weight)])))

0 commit comments

Comments
 (0)