@@ -56,6 +56,7 @@ def main():
56
56
losses_batch = []
57
57
errors_batch = []
58
58
for i in range (max_iter ):
59
+ Xtrain , Ytrain , Ytrain_ind = shuffle (Xtrain , Ytrain , Ytrain_ind )
59
60
for j in range (n_batches ):
60
61
Xbatch = Xtrain [j * batch_sz :(j * batch_sz + batch_sz ),]
61
62
Ybatch = Ytrain_ind [j * batch_sz :(j * batch_sz + batch_sz ),]
@@ -100,6 +101,7 @@ def main():
100
101
dW1 = 0
101
102
db1 = 0
102
103
for i in range (max_iter ):
104
+ Xtrain , Ytrain , Ytrain_ind = shuffle (Xtrain , Ytrain , Ytrain_ind )
103
105
for j in range (n_batches ):
104
106
Xbatch = Xtrain [j * batch_sz :(j * batch_sz + batch_sz ),]
105
107
Ybatch = Ytrain_ind [j * batch_sz :(j * batch_sz + batch_sz ),]
@@ -151,6 +153,7 @@ def main():
151
153
vW1 = 0
152
154
vb1 = 0
153
155
for i in range (max_iter ):
156
+ Xtrain , Ytrain , Ytrain_ind = shuffle (Xtrain , Ytrain , Ytrain_ind )
154
157
for j in range (n_batches ):
155
158
Xbatch = Xtrain [j * batch_sz :(j * batch_sz + batch_sz ),]
156
159
Ybatch = Ytrain_ind [j * batch_sz :(j * batch_sz + batch_sz ),]
0 commit comments