Skip to content

Commit 7fb3015

Browse files
committed
Update code with new API
1 parent dc62e25 commit 7fb3015

File tree

2 files changed

+16
-19
lines changed

2 files changed

+16
-19
lines changed

tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/CnnMnist.java

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,15 @@ public static void train(Session session, int epochs, int minibatchSize, MnistDa
210210
for (ImageBatch trainingBatch : dataset.trainingBatches(minibatchSize)) {
211211
try (Tensor<TUint8> batchImages = TUint8.tensorOf(trainingBatch.images());
212212
Tensor<TUint8> batchLabels = TUint8.tensorOf(trainingBatch.labels());
213-
Tensor<?> loss = session.runner()
213+
Tensor<TFloat32> loss = session.runner()
214214
.feed(TARGET, batchLabels)
215215
.feed(INPUT_NAME, batchImages)
216216
.addTarget(TRAIN)
217217
.fetch(TRAINING_LOSS)
218-
.run().get(0)) {
218+
.run().get(0).expect(TFloat32.DTYPE)) {
219219
if (interval % 100 == 0) {
220220
logger.log(Level.INFO,
221-
"Iteration = " + interval + ", training loss = " + loss.floatValue());
221+
"Iteration = " + interval + ", training loss = " + loss.data().getFloat());
222222
}
223223
}
224224
interval++;
@@ -227,30 +227,27 @@ public static void train(Session session, int epochs, int minibatchSize, MnistDa
227227
}
228228

229229
public static void test(Session session, int minibatchSize, MnistDataset dataset) {
230-
TFloat32 prediction;
231-
232230
int correctCount = 0;
233231
int[][] confusionMatrix = new int[10][10];
234232

235233
for (ImageBatch trainingBatch : dataset.testBatches(minibatchSize)) {
236234
try (Tensor<TUint8> transformedInput = TUint8.tensorOf(trainingBatch.images());
237-
Tensor<?> outputTensor = session.runner()
235+
Tensor<TFloat32> outputTensor = session.runner()
238236
.feed(INPUT_NAME, transformedInput)
239-
.fetch(OUTPUT_NAME).run().get(0)) {
240-
prediction = (TFloat32) outputTensor.data();
241-
}
237+
.fetch(OUTPUT_NAME).run().get(0).expect(TFloat32.DTYPE)) {
242238

243-
ByteNdArray labelBatch = trainingBatch.labels();
244-
for (int k = 0; k < labelBatch.shape().size(0); k++) {
245-
byte trueLabel = labelBatch.getByte(k);
246-
int predLabel;
239+
ByteNdArray labelBatch = trainingBatch.labels();
240+
for (int k = 0; k < labelBatch.shape().size(0); k++) {
241+
byte trueLabel = labelBatch.getByte(k);
242+
int predLabel;
247243

248-
predLabel = argmax(prediction.slice(Indices.at(k), Indices.all()));
249-
if (predLabel == trueLabel) {
250-
correctCount++;
251-
}
244+
predLabel = argmax(outputTensor.data().slice(Indices.at(k), Indices.all()));
245+
if (predLabel == trueLabel) {
246+
correctCount++;
247+
}
252248

253-
confusionMatrix[trueLabel][predLabel]++;
249+
confusionMatrix[trueLabel][predLabel]++;
250+
}
254251
}
255252
}
256253

tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/MnistDataset.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,6 @@ private static ByteNdArray readArchive(String archiveName) throws IOException {
138138
}
139139
byte[] bytes = new byte[size];
140140
archiveStream.readFully(bytes);
141-
return NdArrays.wrap(DataBuffers.from(bytes, true, false), Shape.of(dimSizes));
141+
return NdArrays.wrap(DataBuffers.of(bytes, true, false), Shape.of(dimSizes));
142142
}
143143
}

0 commit comments

Comments
 (0)