-
Notifications
You must be signed in to change notification settings - Fork 1
/
runNNet2.py
46 lines (32 loc) · 936 Bytes
/
runNNet2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import sgd
import nnet
import dataLoader as dl
import numpy as np
import gnumpy as gp
import preprocess as pc
# Select which GPU board gnumpy runs on (board index 1 here).
# NOTE(review): presumably this must be set before the first gnumpy
# operation initializes the GPU; confirm against the gnumpy version in use.
gp.board_id_to_use = 1
def run(inputDim=50, outputDim=10, layerSizes=None, epochs=10000,
        stepSize=1e-2):
    """Train a neural network on PCA-whitened MNIST training images with SGD.

    Loads the MNIST training set via ``dataLoader``, flattens every image
    into a column vector, fits PCA on the training images and whitens them
    to ``inputDim`` components, builds an ``nnet.NNet``, optimizes it with
    ``sgd.SGD`` for ``epochs`` calls to ``SGD.run``, then dumps the
    optimizer trace.

    Every parameter is optional; the defaults reproduce the original
    hard-coded configuration, so ``run()`` behaves as before.

    Args:
        inputDim: number of whitened PCA components fed to the network
            (default 50).
        outputDim: number of network outputs (default 10).
        layerSizes: list of hidden-layer sizes; ``None`` means ``[16, 16]``.
        epochs: number of ``SGD.run`` passes over the training data
            (default 10000).
        stepSize: learning rate passed to ``sgd.SGD`` as ``alpha``
            (default 1e-2).
    """
    if layerSizes is None:
        # Build a fresh list on each call rather than sharing a mutable default.
        layerSizes = [16] * 2

    # Single-argument print(...) behaves identically on Python 2 and 3.
    print("Loading data...")
    # load training data
    trainImages, trainLabels = dl.load_mnist_train()

    # Flatten each image into one column: every axis except the last holds
    # pixels, the last axis indexes examples. For the square
    # (imDim, imDim, N) layout this equals reshape(imDim**2, -1); unlike
    # that form it also handles non-square or already-flattened input.
    # NOTE(review): assumes examples lie on the last axis -- confirm in dl.
    numPixels = int(np.prod(trainImages.shape[:-1]))
    trainImages = trainImages.reshape(numPixels, -1)

    # Fit PCA on the training images, then whiten them; presumably this
    # keeps inputDim components, matching the network's input size.
    pcer = pc.Preprocess()
    pcer.computePCA(trainImages)
    whitenedTrain = pcer.whiten(trainImages, inputDim)

    # The "minibatch" is the total number of examples (columns), so each
    # SGD.run call presumably processes the whole training set as one batch
    # -- confirm in sgd.SGD.
    minibatch = whitenedTrain.shape[1]
    print("minibatch size: %d" % (minibatch))

    model = nnet.NNet(inputDim, outputDim, layerSizes, minibatch)
    model.initParams()
    optimizer = sgd.SGD(model, alpha=stepSize, minibatch=minibatch)

    for e in range(epochs):
        print("Running epoch %d" % e)
        optimizer.run(whitenedTrain, trainLabels)

    # NOTE(review): dumped once after training finishes; move into the loop
    # above if a per-epoch dump was intended.
    optimizer.dumptrace()
if __name__ == "__main__":
    # Kick off training only when executed as a script, not on import.
    run()