Skip to content

Commit 6d35fa0

Browse files
author
User
committed
update
1 parent 4e97d48 commit 6d35fa0

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

ann_class2/momentum.py

+3
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def main():
5656
losses_batch = []
5757
errors_batch = []
5858
for i in range(max_iter):
59+
Xtrain, Ytrain, Ytrain_ind = shuffle(Xtrain, Ytrain, Ytrain_ind)
5960
for j in range(n_batches):
6061
Xbatch = Xtrain[j*batch_sz:(j*batch_sz + batch_sz),]
6162
Ybatch = Ytrain_ind[j*batch_sz:(j*batch_sz + batch_sz),]
@@ -100,6 +101,7 @@ def main():
100101
dW1 = 0
101102
db1 = 0
102103
for i in range(max_iter):
104+
Xtrain, Ytrain, Ytrain_ind = shuffle(Xtrain, Ytrain, Ytrain_ind)
103105
for j in range(n_batches):
104106
Xbatch = Xtrain[j*batch_sz:(j*batch_sz + batch_sz),]
105107
Ybatch = Ytrain_ind[j*batch_sz:(j*batch_sz + batch_sz),]
@@ -151,6 +153,7 @@ def main():
151153
vW1 = 0
152154
vb1 = 0
153155
for i in range(max_iter):
156+
Xtrain, Ytrain, Ytrain_ind = shuffle(Xtrain, Ytrain, Ytrain_ind)
154157
for j in range(n_batches):
155158
Xbatch = Xtrain[j*batch_sz:(j*batch_sz + batch_sz),]
156159
Ybatch = Ytrain_ind[j*batch_sz:(j*batch_sz + batch_sz),]

0 commit comments

Comments
 (0)