@@ -210,15 +210,15 @@ public static void train(Session session, int epochs, int minibatchSize, MnistDa
210
210
for (ImageBatch trainingBatch : dataset .trainingBatches (minibatchSize )) {
211
211
try (Tensor <TUint8 > batchImages = TUint8 .tensorOf (trainingBatch .images ());
212
212
Tensor <TUint8 > batchLabels = TUint8 .tensorOf (trainingBatch .labels ());
213
- Tensor <? > loss = session .runner ()
213
+ Tensor <TFloat32 > loss = session .runner ()
214
214
.feed (TARGET , batchLabels )
215
215
.feed (INPUT_NAME , batchImages )
216
216
.addTarget (TRAIN )
217
217
.fetch (TRAINING_LOSS )
218
- .run ().get (0 )) {
218
+ .run ().get (0 ). expect ( TFloat32 . DTYPE ) ) {
219
219
if (interval % 100 == 0 ) {
220
220
logger .log (Level .INFO ,
221
- "Iteration = " + interval + ", training loss = " + loss .floatValue ());
221
+ "Iteration = " + interval + ", training loss = " + loss .data (). getFloat ());
222
222
}
223
223
}
224
224
interval ++;
@@ -227,30 +227,27 @@ public static void train(Session session, int epochs, int minibatchSize, MnistDa
227
227
}
228
228
229
229
public static void test (Session session , int minibatchSize , MnistDataset dataset ) {
230
- TFloat32 prediction ;
231
-
232
230
int correctCount = 0 ;
233
231
int [][] confusionMatrix = new int [10 ][10 ];
234
232
235
233
for (ImageBatch trainingBatch : dataset .testBatches (minibatchSize )) {
236
234
try (Tensor <TUint8 > transformedInput = TUint8 .tensorOf (trainingBatch .images ());
237
- Tensor <? > outputTensor = session .runner ()
235
+ Tensor <TFloat32 > outputTensor = session .runner ()
238
236
.feed (INPUT_NAME , transformedInput )
239
- .fetch (OUTPUT_NAME ).run ().get (0 )) {
240
- prediction = (TFloat32 ) outputTensor .data ();
241
- }
237
+ .fetch (OUTPUT_NAME ).run ().get (0 ).expect (TFloat32 .DTYPE )) {
242
238
243
- ByteNdArray labelBatch = trainingBatch .labels ();
244
- for (int k = 0 ; k < labelBatch .shape ().size (0 ); k ++) {
245
- byte trueLabel = labelBatch .getByte (k );
246
- int predLabel ;
239
+ ByteNdArray labelBatch = trainingBatch .labels ();
240
+ for (int k = 0 ; k < labelBatch .shape ().size (0 ); k ++) {
241
+ byte trueLabel = labelBatch .getByte (k );
242
+ int predLabel ;
247
243
248
- predLabel = argmax (prediction .slice (Indices .at (k ), Indices .all ()));
249
- if (predLabel == trueLabel ) {
250
- correctCount ++;
251
- }
244
+ predLabel = argmax (outputTensor . data () .slice (Indices .at (k ), Indices .all ()));
245
+ if (predLabel == trueLabel ) {
246
+ correctCount ++;
247
+ }
252
248
253
- confusionMatrix [trueLabel ][predLabel ]++;
249
+ confusionMatrix [trueLabel ][predLabel ]++;
250
+ }
254
251
}
255
252
}
256
253
0 commit comments