
Commit 2ef1672

Merge pull request tensorflow#21 from karllessard/upgrade-0.3.1
Update examples to TF Java 0.3.1
2 parents c3f6dbf + f0452dd commit 2ef1672

File tree

7 files changed: +79 -79 lines changed

tensorflow-examples/pom.xml

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@
   <modelVersion>4.0.0</modelVersion>
   <groupId>org.tensorflow.model</groupId>
   <artifactId>tensorflow-examples</artifactId>
-  <version>0.1.0-SNAPSHOT</version>
+  <version>0.3.1-SNAPSHOT</version>
 
   <name>TensorFlow Examples</name>
   <description>A suite of executable examples using TensorFlow Java</description>
@@ -18,12 +18,12 @@
     <dependency>
       <groupId>org.tensorflow</groupId>
       <artifactId>tensorflow-core-platform</artifactId>
-      <version>0.2.0</version>
+      <version>0.3.1</version>
     </dependency>
     <dependency>
       <groupId>org.tensorflow</groupId>
       <artifactId>tensorflow-framework</artifactId>
-      <version>0.2.0</version>
+      <version>0.3.1</version>
     </dependency>
   </dependencies>

tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java

Lines changed: 16 additions & 16 deletions
@@ -88,22 +88,22 @@ public static Graph build(String optimizerName) {
     Ops tf = Ops.create(graph);
 
     // Inputs
-    Placeholder<TUint8> input = tf.withName(INPUT_NAME).placeholder(TUint8.DTYPE,
+    Placeholder<TUint8> input = tf.withName(INPUT_NAME).placeholder(TUint8.class,
         Placeholder.shape(Shape.of(-1, IMAGE_SIZE, IMAGE_SIZE)));
     Reshape<TUint8> input_reshaped = tf
         .reshape(input, tf.array(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS));
-    Placeholder<TUint8> labels = tf.withName(TARGET).placeholder(TUint8.DTYPE);
+    Placeholder<TUint8> labels = tf.withName(TARGET).placeholder(TUint8.class);
 
     // Scaling the features
     Constant<TFloat32> centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f);
     Constant<TFloat32> scalingFactor = tf.constant((float) PIXEL_DEPTH);
     Operand<TFloat32> scaledInput = tf.math
-        .div(tf.math.sub(tf.dtypes.cast(input_reshaped, TFloat32.DTYPE), centeringFactor),
+        .div(tf.math.sub(tf.dtypes.cast(input_reshaped, TFloat32.class), centeringFactor),
             scalingFactor);
 
     // First conv layer
     Variable<TFloat32> conv1Weights = tf.variable(tf.math.mul(tf.random
-        .truncatedNormal(tf.array(5, 5, NUM_CHANNELS, 32), TFloat32.DTYPE,
+        .truncatedNormal(tf.array(5, 5, NUM_CHANNELS, 32), TFloat32.class,
             TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
     Conv2d<TFloat32> conv1 = tf.nn
         .conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
@@ -118,7 +118,7 @@ public static Graph build(String optimizerName) {
 
     // Second conv layer
     Variable<TFloat32> conv2Weights = tf.variable(tf.math.mul(tf.random
-        .truncatedNormal(tf.array(5, 5, 32, 64), TFloat32.DTYPE,
+        .truncatedNormal(tf.array(5, 5, 32, 64), TFloat32.class,
             TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
     Conv2d<TFloat32> conv2 = tf.nn
         .conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
@@ -138,7 +138,7 @@ public static Graph build(String optimizerName) {
 
     // Fully connected layer
     Variable<TFloat32> fc1Weights = tf.variable(tf.math.mul(tf.random
-        .truncatedNormal(tf.array(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat32.DTYPE,
+        .truncatedNormal(tf.array(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat32.class,
             TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
     Variable<TFloat32> fc1Biases = tf
         .variable(tf.fill(tf.array(new int[]{512}), tf.constant(0.1f)));
@@ -147,7 +147,7 @@ public static Graph build(String optimizerName) {
 
     // Softmax layer
     Variable<TFloat32> fc2Weights = tf.variable(tf.math.mul(tf.random
-        .truncatedNormal(tf.array(512, NUM_LABELS), TFloat32.DTYPE,
+        .truncatedNormal(tf.array(512, NUM_LABELS), TFloat32.class,
             TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
     Variable<TFloat32> fc2Biases = tf
         .variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.constant(0.1f)));
@@ -214,17 +214,17 @@ public static void train(Session session, int epochs, int minibatchSize, MnistDa
     // Train the model
     for (int i = 0; i < epochs; i++) {
       for (ImageBatch trainingBatch : dataset.trainingBatches(minibatchSize)) {
-        try (Tensor<TUint8> batchImages = TUint8.tensorOf(trainingBatch.images());
-            Tensor<TUint8> batchLabels = TUint8.tensorOf(trainingBatch.labels());
-            Tensor<TFloat32> loss = session.runner()
+        try (TUint8 batchImages = TUint8.tensorOf(trainingBatch.images());
+            TUint8 batchLabels = TUint8.tensorOf(trainingBatch.labels());
+            TFloat32 loss = (TFloat32)session.runner()
                 .feed(TARGET, batchLabels)
                 .feed(INPUT_NAME, batchImages)
                 .addTarget(TRAIN)
                 .fetch(TRAINING_LOSS)
-                .run().get(0).expect(TFloat32.DTYPE)) {
+                .run().get(0)) {
           if (interval % 100 == 0) {
             logger.log(Level.INFO,
-                "Iteration = " + interval + ", training loss = " + loss.data().getFloat());
+                "Iteration = " + interval + ", training loss = " + loss.getFloat());
           }
         }
         interval++;
@@ -237,17 +237,17 @@ public static void test(Session session, int minibatchSize, MnistDataset dataset
     int[][] confusionMatrix = new int[10][10];
 
     for (ImageBatch trainingBatch : dataset.testBatches(minibatchSize)) {
-      try (Tensor<TUint8> transformedInput = TUint8.tensorOf(trainingBatch.images());
-          Tensor<TFloat32> outputTensor = session.runner()
+      try (TUint8 transformedInput = TUint8.tensorOf(trainingBatch.images());
+          TFloat32 outputTensor = (TFloat32)session.runner()
              .feed(INPUT_NAME, transformedInput)
-             .fetch(OUTPUT_NAME).run().get(0).expect(TFloat32.DTYPE)) {
+             .fetch(OUTPUT_NAME).run().get(0)) {
 
         ByteNdArray labelBatch = trainingBatch.labels();
         for (int k = 0; k < labelBatch.shape().size(0); k++) {
           byte trueLabel = labelBatch.getByte(k);
           int predLabel;
 
-          predLabel = argmax(outputTensor.data().slice(Indices.at(k), Indices.all()));
+          predLabel = argmax(outputTensor.slice(Indices.at(k), Indices.all()));
           if (predLabel == trueLabel) {
             correctCount++;
           }
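
The recurring change in CnnMnist.java is the move from the 0.2.0 Tensor<T>/DTYPE API to the typed-tensor API of 0.3.1: placeholders and casts take a class literal such as TUint8.class, and a fetched tensor is cast straight to its typed interface (TFloat32) instead of going through .expect(TFloat32.DTYPE) and .data(). Below is a minimal, self-contained sketch of that pattern, assuming tensorflow-core-platform 0.3.1 on the classpath; the tiny graph and the op names "x" and "y" are illustrative and not part of CnnMnist.

    import org.tensorflow.Graph;
    import org.tensorflow.Operand;
    import org.tensorflow.Session;
    import org.tensorflow.op.Ops;
    import org.tensorflow.op.core.Placeholder;
    import org.tensorflow.types.TFloat32;

    public class TypedFetchSketch {
      public static void main(String[] args) {
        try (Graph graph = new Graph()) {
          Ops tf = Ops.create(graph);
          // 0.3.1 style: pass the type class, not a DTYPE constant.
          Placeholder<TFloat32> x = tf.withName("x").placeholder(TFloat32.class);
          Operand<TFloat32> y = tf.withName("y").math.mul(x, tf.constant(2.0f));
          try (Session session = new Session(graph);
               TFloat32 input = TFloat32.scalarOf(21.0f);
               // The fetched tensor is cast to its typed interface directly,
               // replacing .run().get(0).expect(TFloat32.DTYPE).
               TFloat32 result = (TFloat32) session.runner()
                   .feed("x", input)
                   .fetch("y")
                   .run()
                   .get(0)) {
            // No .data() indirection: values are read straight off the tensor.
            System.out.println("y = " + result.getFloat());
          }
        }
      }
    }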

tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java

Lines changed: 15 additions & 15 deletions
@@ -83,17 +83,17 @@ public static Graph compile() {
     Ops tf = Ops.create(graph);
 
     // Inputs
-    Placeholder<TUint8> input = tf.withName(INPUT_NAME).placeholder(TUint8.DTYPE,
+    Placeholder<TUint8> input = tf.withName(INPUT_NAME).placeholder(TUint8.class,
         Placeholder.shape(Shape.of(-1, IMAGE_SIZE, IMAGE_SIZE)));
     Reshape<TUint8> input_reshaped = tf
         .reshape(input, tf.array(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS));
-    Placeholder<TUint8> labels = tf.withName(TARGET).placeholder(TUint8.DTYPE);
+    Placeholder<TUint8> labels = tf.withName(TARGET).placeholder(TUint8.class);
 
     // Scaling the features
     Constant<TFloat32> centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f);
     Constant<TFloat32> scalingFactor = tf.constant((float) PIXEL_DEPTH);
     Operand<TFloat32> scaledInput = tf.math
-        .div(tf.math.sub(tf.dtypes.cast(input_reshaped, TFloat32.DTYPE), centeringFactor),
+        .div(tf.math.sub(tf.dtypes.cast(input_reshaped, TFloat32.class), centeringFactor),
            scalingFactor);
 
     Relu<TFloat32> relu1 = vggConv2DLayer("1", tf, scaledInput, new int[]{3, 3, NUM_CHANNELS, 32}, 32);
@@ -137,7 +137,7 @@ public static Add<TFloat32> buildFCLayersAndRegularization(Ops tf, Placeholder<T
     int[] fcWeightShape = {256, fcBiasShape};
 
     Variable<TFloat32> fc1Weights = tf.variable(tf.math.mul(tf.random
-        .truncatedNormal(tf.array(fcWeightShape), TFloat32.DTYPE,
+        .truncatedNormal(tf.array(fcWeightShape), TFloat32.class,
            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
     Variable<TFloat32> fc1Biases = tf
        .variable(tf.fill(tf.array(new int[]{fcBiasShape}), tf.constant(0.1f)));
@@ -146,7 +146,7 @@ public static Add<TFloat32> buildFCLayersAndRegularization(Ops tf, Placeholder<T
 
     // Softmax layer
     Variable<TFloat32> fc2Weights = tf.variable(tf.math.mul(tf.random
-        .truncatedNormal(tf.array(fcBiasShape, NUM_LABELS), TFloat32.DTYPE,
+        .truncatedNormal(tf.array(fcBiasShape, NUM_LABELS), TFloat32.class,
            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
     Variable<TFloat32> fc2Biases = tf
        .variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.constant(0.1f)));
@@ -183,7 +183,7 @@ public static MaxPool<TFloat32> vggMaxPool(Ops tf, Relu<TFloat32> relu1) {
 
   public static Relu<TFloat32> vggConv2DLayer(String layerName, Ops tf, Operand<TFloat32> scaledInput, int[] convWeightsL1Shape, int convBiasL1Shape) {
     Variable<TFloat32> conv1Weights = tf.withName("conv2d_" + layerName).variable(tf.math.mul(tf.random
-        .truncatedNormal(tf.array(convWeightsL1Shape), TFloat32.DTYPE,
+        .truncatedNormal(tf.array(convWeightsL1Shape), TFloat32.class,
            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
     Conv2d<TFloat32> conv = tf.nn
        .conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
@@ -201,17 +201,17 @@ public void train(MnistDataset dataset, int epochs, int minibatchSize) {
     // Train the model
     for (int i = 0; i < epochs; i++) {
       for (ImageBatch trainingBatch : dataset.trainingBatches(minibatchSize)) {
-        try (Tensor<TUint8> batchImages = TUint8.tensorOf(trainingBatch.images());
-            Tensor<TUint8> batchLabels = TUint8.tensorOf(trainingBatch.labels());
-            Tensor<TFloat32> loss = session.runner()
+        try (TUint8 batchImages = TUint8.tensorOf(trainingBatch.images());
+            TUint8 batchLabels = TUint8.tensorOf(trainingBatch.labels());
+            TFloat32 loss = (TFloat32)session.runner()
                .feed(TARGET, batchLabels)
                .feed(INPUT_NAME, batchImages)
                .addTarget(TRAIN)
                .fetch(TRAINING_LOSS)
-               .run().get(0).expect(TFloat32.DTYPE)) {
+               .run().get(0)) {
 
          logger.log(Level.INFO,
-             "Iteration = " + interval + ", training loss = " + loss.data().getFloat());
+             "Iteration = " + interval + ", training loss = " + loss.getFloat());
 
        }
        interval++;
@@ -224,17 +224,17 @@ public void test(MnistDataset dataset, int minibatchSize) {
     int[][] confusionMatrix = new int[10][10];
 
     for (ImageBatch trainingBatch : dataset.testBatches(minibatchSize)) {
-      try (Tensor<TUint8> transformedInput = TUint8.tensorOf(trainingBatch.images());
-          Tensor<TFloat32> outputTensor = session.runner()
+      try (TUint8 transformedInput = TUint8.tensorOf(trainingBatch.images());
+          TFloat32 outputTensor = (TFloat32)session.runner()
              .feed(INPUT_NAME, transformedInput)
-             .fetch(OUTPUT_NAME).run().get(0).expect(TFloat32.DTYPE)) {
+             .fetch(OUTPUT_NAME).run().get(0)) {
 
         ByteNdArray labelBatch = trainingBatch.labels();
         for (int k = 0; k < labelBatch.shape().size(0); k++) {
           byte trueLabel = labelBatch.getByte(k);
           int predLabel;
 
-          predLabel = argmax(outputTensor.data().slice(Indices.at(k), Indices.all()));
+          predLabel = argmax(outputTensor.slice(Indices.at(k), Indices.all()));
           if (predLabel == trueLabel) {
             correctCount++;
           }
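
VGGModel.java sees the same renaming in its layer builders: tf.random.truncatedNormal takes the TFloat32.class literal where 0.2.0 took TFloat32.DTYPE. A minimal sketch of that weight-initialization pattern, assuming tensorflow-core-platform 0.3.1; the kernel shape and seed below are illustrative, not VGG's actual values.

    import org.tensorflow.Graph;
    import org.tensorflow.op.Ops;
    import org.tensorflow.op.core.Variable;
    import org.tensorflow.op.random.TruncatedNormal;
    import org.tensorflow.types.TFloat32;

    public class InitWeightsSketch {
      public static void main(String[] args) {
        try (Graph graph = new Graph()) {
          Ops tf = Ops.create(graph);
          // 5x5 kernels, 1 input channel, 32 filters, drawn from a truncated
          // normal distribution and scaled by 0.1, mirroring the conv layers above.
          Variable<TFloat32> convWeights = tf.variable(tf.math.mul(
              tf.random.truncatedNormal(tf.array(5, 5, 1, 32), TFloat32.class,
                  TruncatedNormal.seed(123456789L)),
              tf.constant(0.1f)));
          System.out.println("Initialized weights with shape "
              + convWeights.asOutput().shape());
        }
      }
    }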

tensorflow-examples/src/main/java/org/tensorflow/model/examples/datasets/mnist/MnistDataset.java

Lines changed: 7 additions & 7 deletions
@@ -27,8 +27,8 @@
 import java.io.IOException;
 import java.util.zip.GZIPInputStream;
 
-import static org.tensorflow.ndarray.index.Indices.from;
-import static org.tensorflow.ndarray.index.Indices.to;
+import static org.tensorflow.ndarray.index.Indices.sliceFrom;
+import static org.tensorflow.ndarray.index.Indices.sliceTo;
 
 /** Common loader and data preprocessor for MNIST and FashionMNIST datasets. */
 public class MnistDataset {
@@ -44,10 +44,10 @@ public static MnistDataset create(int validationSize, String trainingImagesArchi
 
     if (validationSize > 0) {
       return new MnistDataset(
-          trainingImages.slice(from(validationSize)),
-          trainingLabels.slice(from(validationSize)),
-          trainingImages.slice(to(validationSize)),
-          trainingLabels.slice(to(validationSize)),
+          trainingImages.slice(sliceFrom(validationSize)),
+          trainingLabels.slice(sliceFrom(validationSize)),
+          trainingImages.slice(sliceTo(validationSize)),
+          trainingLabels.slice(sliceTo(validationSize)),
           testImages,
           testLabels
       );
@@ -137,6 +137,6 @@ private static ByteNdArray readArchive(String archiveName) throws IOException {
     }
     byte[] bytes = new byte[size];
     archiveStream.readFully(bytes);
-    return NdArrays.wrap(DataBuffers.of(bytes, true, false), Shape.of(dimSizes));
+    return NdArrays.wrap(Shape.of(dimSizes), DataBuffers.of(bytes, true, false));
   }
 }
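
Two ndarray changes drive this file: the index helpers Indices.from and Indices.to were renamed to sliceFrom and sliceTo, and NdArrays.wrap now takes the Shape before the DataBuffer. A small sketch of both calls, assuming the 0.3.1 ndarray artifact pulled in by tensorflow-core-platform; the sample bytes are illustrative.

    import org.tensorflow.ndarray.ByteNdArray;
    import org.tensorflow.ndarray.NdArrays;
    import org.tensorflow.ndarray.Shape;
    import org.tensorflow.ndarray.buffer.DataBuffers;

    import static org.tensorflow.ndarray.index.Indices.sliceFrom;
    import static org.tensorflow.ndarray.index.Indices.sliceTo;

    public class SliceSketch {
      public static void main(String[] args) {
        byte[] bytes = {0, 1, 2, 3, 4, 5};
        // Shape now comes first, the buffer second (the argument order flipped in 0.3.x).
        ByteNdArray data = NdArrays.wrap(Shape.of(6), DataBuffers.of(bytes, true, false));
        ByteNdArray head = data.slice(sliceTo(2));   // elements 0..1
        ByteNdArray tail = data.slice(sliceFrom(2)); // elements 2..5
        System.out.println("head size = " + head.shape().size(0));
        System.out.println("tail size = " + tail.shape().size(0));
      }
    }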

tensorflow-examples/src/main/java/org/tensorflow/model/examples/dense/SimpleMnist.java

Lines changed: 15 additions & 16 deletions
@@ -56,17 +56,17 @@ public void run() {
     Ops tf = Ops.create(graph);
 
     // Create placeholders and variables, which should fit batches of an unknown number of images
-    Placeholder<TFloat32> images = tf.placeholder(TFloat32.DTYPE);
-    Placeholder<TFloat32> labels = tf.placeholder(TFloat32.DTYPE);
+    Placeholder<TFloat32> images = tf.placeholder(TFloat32.class);
+    Placeholder<TFloat32> labels = tf.placeholder(TFloat32.class);
 
     // Create weights with an initial value of 0
     Shape weightShape = Shape.of(dataset.imageSize(), MnistDataset.NUM_CLASSES);
-    Variable<TFloat32> weights = tf.variable(weightShape, TFloat32.DTYPE);
+    Variable<TFloat32> weights = tf.variable(weightShape, TFloat32.class);
     tf.initAdd(tf.assign(weights, tf.zerosLike(weights)));
 
     // Create biases with an initial value of 0
     Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES);
-    Variable<TFloat32> biases = tf.variable(biasShape, TFloat32.DTYPE);
+    Variable<TFloat32> biases = tf.variable(biasShape, TFloat32.class);
     tf.initAdd(tf.assign(biases, tf.zerosLike(biases)));
 
     // Register all variable initializers for single execution
@@ -98,7 +98,7 @@ public void run() {
     // Compute the accuracy of the model
     Operand<TInt64> predicted = tf.math.argMax(softmax, tf.constant(1));
     Operand<TInt64> expected = tf.math.argMax(labels, tf.constant(1));
-    Operand<TFloat32> accuracy = tf.math.mean(tf.dtypes.cast(tf.math.equal(predicted, expected), TFloat32.DTYPE), tf.array(0));
+    Operand<TFloat32> accuracy = tf.math.mean(tf.dtypes.cast(tf.math.equal(predicted, expected), TFloat32.class), tf.array(0));
 
     // Run the graph
     try (Session session = new Session(graph)) {
@@ -108,8 +108,8 @@ public void run() {
 
       // Train the model
       for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
-        try (Tensor<TFloat32> batchImages = preprocessImages(trainingBatch.images());
-            Tensor<TFloat32> batchLabels = preprocessLabels(trainingBatch.labels())) {
+        try (TFloat32 batchImages = preprocessImages(trainingBatch.images());
+            TFloat32 batchLabels = preprocessLabels(trainingBatch.labels())) {
           session.runner()
               .addTarget(minimize)
              .feed(images.asOutput(), batchImages)
@@ -120,16 +120,15 @@ public void run() {
 
      // Test the model
      ImageBatch testBatch = dataset.testBatch();
-      try (Tensor<TFloat32> testImages = preprocessImages(testBatch.images());
-          Tensor<TFloat32> testLabels = preprocessLabels(testBatch.labels());
-          Tensor<TFloat32> accuracyValue = session.runner()
+      try (TFloat32 testImages = preprocessImages(testBatch.images());
+          TFloat32 testLabels = preprocessLabels(testBatch.labels());
+          TFloat32 accuracyValue = (TFloat32)session.runner()
              .fetch(accuracy)
              .feed(images.asOutput(), testImages)
              .feed(labels.asOutput(), testLabels)
              .run()
-             .get(0)
-             .expect(TFloat32.DTYPE)) {
-        System.out.println("Accuracy: " + accuracyValue.data().getFloat());
+             .get(0)) {
+        System.out.println("Accuracy: " + accuracyValue.getFloat());
      }
    }
  }
@@ -138,21 +137,21 @@ public void run() {
  private static final int TRAINING_BATCH_SIZE = 100;
  private static final float LEARNING_RATE = 0.2f;
 
-  private static Tensor<TFloat32> preprocessImages(ByteNdArray rawImages) {
+  private static TFloat32 preprocessImages(ByteNdArray rawImages) {
    Ops tf = Ops.create();
 
    // Flatten images in a single dimension and normalize their pixels as floats.
    long imageSize = rawImages.get(0).shape().size();
    return tf.math.div(
        tf.reshape(
-            tf.dtypes.cast(tf.constant(rawImages), TFloat32.DTYPE),
+            tf.dtypes.cast(tf.constant(rawImages), TFloat32.class),
            tf.array(-1L, imageSize)
        ),
        tf.constant(255.0f)
    ).asTensor();
  }
 
-  private static Tensor<TFloat32> preprocessLabels(ByteNdArray rawLabels) {
+  private static TFloat32 preprocessLabels(ByteNdArray rawLabels) {
    Ops tf = Ops.create();
 
    // Map labels to one hot vectors where only the expected predictions as a value of 1.0
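
SimpleMnist's preprocessing helpers now return the concrete TFloat32 type instead of Tensor<TFloat32>, with the eager Operand turned into a tensor by asTensor(). A minimal sketch of one-hot label preprocessing in that style, assuming tensorflow-core-platform 0.3.1; the class name and the toy labels are illustrative.

    import org.tensorflow.ndarray.ByteNdArray;
    import org.tensorflow.ndarray.NdArrays;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    public class PreprocessSketch {
      // Map byte labels to one-hot float vectors, returning the typed tensor directly.
      static TFloat32 oneHotLabels(ByteNdArray rawLabels, int numClasses) {
        Ops tf = Ops.create(); // eager execution
        return tf.oneHot(
            tf.constant(rawLabels),   // indices
            tf.constant(numClasses),  // depth
            tf.constant(1.0f),        // on value
            tf.constant(0.0f)         // off value
        ).asTensor();
      }

      public static void main(String[] args) {
        ByteNdArray labels = NdArrays.vectorOf((byte) 2, (byte) 0, (byte) 1);
        try (TFloat32 oneHot = oneHotLabels(labels, 3)) {
          // First label is 2, so row 0 is [0, 0, 1].
          System.out.println("row 0: " + oneHot.getFloat(0, 0) + ", "
              + oneHot.getFloat(0, 1) + ", " + oneHot.getFloat(0, 2));
        }
      }
    }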
