Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tensorflow-examples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.tensorflow.model</groupId>
<artifactId>tensorflow-examples</artifactId>
<version>0.1.0-SNAPSHOT</version>
<version>0.3.1-SNAPSHOT</version>

<name>TensorFlow Examples</name>
<description>A suite of executable examples using TensorFlow Java</description>
Expand All @@ -18,12 +18,12 @@
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-platform</artifactId>
<version>0.2.0</version>
<version>0.3.1</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-framework</artifactId>
<version>0.2.0</version>
<version>0.3.1</version>
</dependency>
</dependencies>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,22 @@ public static Graph build(String optimizerName) {
Ops tf = Ops.create(graph);

// Inputs
Placeholder<TUint8> input = tf.withName(INPUT_NAME).placeholder(TUint8.DTYPE,
Placeholder<TUint8> input = tf.withName(INPUT_NAME).placeholder(TUint8.class,
Placeholder.shape(Shape.of(-1, IMAGE_SIZE, IMAGE_SIZE)));
Reshape<TUint8> input_reshaped = tf
.reshape(input, tf.array(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS));
Placeholder<TUint8> labels = tf.withName(TARGET).placeholder(TUint8.DTYPE);
Placeholder<TUint8> labels = tf.withName(TARGET).placeholder(TUint8.class);

// Scaling the features
Constant<TFloat32> centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f);
Constant<TFloat32> scalingFactor = tf.constant((float) PIXEL_DEPTH);
Operand<TFloat32> scaledInput = tf.math
.div(tf.math.sub(tf.dtypes.cast(input_reshaped, TFloat32.DTYPE), centeringFactor),
.div(tf.math.sub(tf.dtypes.cast(input_reshaped, TFloat32.class), centeringFactor),
scalingFactor);

// First conv layer
Variable<TFloat32> conv1Weights = tf.variable(tf.math.mul(tf.random
.truncatedNormal(tf.array(5, 5, NUM_CHANNELS, 32), TFloat32.DTYPE,
.truncatedNormal(tf.array(5, 5, NUM_CHANNELS, 32), TFloat32.class,
TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
Conv2d<TFloat32> conv1 = tf.nn
.conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
Expand All @@ -118,7 +118,7 @@ public static Graph build(String optimizerName) {

// Second conv layer
Variable<TFloat32> conv2Weights = tf.variable(tf.math.mul(tf.random
.truncatedNormal(tf.array(5, 5, 32, 64), TFloat32.DTYPE,
.truncatedNormal(tf.array(5, 5, 32, 64), TFloat32.class,
TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
Conv2d<TFloat32> conv2 = tf.nn
.conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
Expand All @@ -138,7 +138,7 @@ public static Graph build(String optimizerName) {

// Fully connected layer
Variable<TFloat32> fc1Weights = tf.variable(tf.math.mul(tf.random
.truncatedNormal(tf.array(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat32.DTYPE,
.truncatedNormal(tf.array(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat32.class,
TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
Variable<TFloat32> fc1Biases = tf
.variable(tf.fill(tf.array(new int[]{512}), tf.constant(0.1f)));
Expand All @@ -147,7 +147,7 @@ public static Graph build(String optimizerName) {

// Softmax layer
Variable<TFloat32> fc2Weights = tf.variable(tf.math.mul(tf.random
.truncatedNormal(tf.array(512, NUM_LABELS), TFloat32.DTYPE,
.truncatedNormal(tf.array(512, NUM_LABELS), TFloat32.class,
TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
Variable<TFloat32> fc2Biases = tf
.variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.constant(0.1f)));
Expand Down Expand Up @@ -214,17 +214,17 @@ public static void train(Session session, int epochs, int minibatchSize, MnistDa
// Train the model
for (int i = 0; i < epochs; i++) {
for (ImageBatch trainingBatch : dataset.trainingBatches(minibatchSize)) {
try (Tensor<TUint8> batchImages = TUint8.tensorOf(trainingBatch.images());
Tensor<TUint8> batchLabels = TUint8.tensorOf(trainingBatch.labels());
Tensor<TFloat32> loss = session.runner()
try (TUint8 batchImages = TUint8.tensorOf(trainingBatch.images());
TUint8 batchLabels = TUint8.tensorOf(trainingBatch.labels());
TFloat32 loss = (TFloat32)session.runner()
.feed(TARGET, batchLabels)
.feed(INPUT_NAME, batchImages)
.addTarget(TRAIN)
.fetch(TRAINING_LOSS)
.run().get(0).expect(TFloat32.DTYPE)) {
.run().get(0)) {
if (interval % 100 == 0) {
logger.log(Level.INFO,
"Iteration = " + interval + ", training loss = " + loss.data().getFloat());
"Iteration = " + interval + ", training loss = " + loss.getFloat());
}
}
interval++;
Expand All @@ -237,17 +237,17 @@ public static void test(Session session, int minibatchSize, MnistDataset dataset
int[][] confusionMatrix = new int[10][10];

for (ImageBatch trainingBatch : dataset.testBatches(minibatchSize)) {
try (Tensor<TUint8> transformedInput = TUint8.tensorOf(trainingBatch.images());
Tensor<TFloat32> outputTensor = session.runner()
try (TUint8 transformedInput = TUint8.tensorOf(trainingBatch.images());
TFloat32 outputTensor = (TFloat32)session.runner()
.feed(INPUT_NAME, transformedInput)
.fetch(OUTPUT_NAME).run().get(0).expect(TFloat32.DTYPE)) {
.fetch(OUTPUT_NAME).run().get(0)) {

ByteNdArray labelBatch = trainingBatch.labels();
for (int k = 0; k < labelBatch.shape().size(0); k++) {
byte trueLabel = labelBatch.getByte(k);
int predLabel;

predLabel = argmax(outputTensor.data().slice(Indices.at(k), Indices.all()));
predLabel = argmax(outputTensor.slice(Indices.at(k), Indices.all()));
if (predLabel == trueLabel) {
correctCount++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,17 @@ public static Graph compile() {
Ops tf = Ops.create(graph);

// Inputs
Placeholder<TUint8> input = tf.withName(INPUT_NAME).placeholder(TUint8.DTYPE,
Placeholder<TUint8> input = tf.withName(INPUT_NAME).placeholder(TUint8.class,
Placeholder.shape(Shape.of(-1, IMAGE_SIZE, IMAGE_SIZE)));
Reshape<TUint8> input_reshaped = tf
.reshape(input, tf.array(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS));
Placeholder<TUint8> labels = tf.withName(TARGET).placeholder(TUint8.DTYPE);
Placeholder<TUint8> labels = tf.withName(TARGET).placeholder(TUint8.class);

// Scaling the features
Constant<TFloat32> centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f);
Constant<TFloat32> scalingFactor = tf.constant((float) PIXEL_DEPTH);
Operand<TFloat32> scaledInput = tf.math
.div(tf.math.sub(tf.dtypes.cast(input_reshaped, TFloat32.DTYPE), centeringFactor),
.div(tf.math.sub(tf.dtypes.cast(input_reshaped, TFloat32.class), centeringFactor),
scalingFactor);

Relu<TFloat32> relu1 = vggConv2DLayer("1", tf, scaledInput, new int[]{3, 3, NUM_CHANNELS, 32}, 32);
Expand Down Expand Up @@ -137,7 +137,7 @@ public static Add<TFloat32> buildFCLayersAndRegularization(Ops tf, Placeholder<T
int[] fcWeightShape = {256, fcBiasShape};

Variable<TFloat32> fc1Weights = tf.variable(tf.math.mul(tf.random
.truncatedNormal(tf.array(fcWeightShape), TFloat32.DTYPE,
.truncatedNormal(tf.array(fcWeightShape), TFloat32.class,
TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
Variable<TFloat32> fc1Biases = tf
.variable(tf.fill(tf.array(new int[]{fcBiasShape}), tf.constant(0.1f)));
Expand All @@ -146,7 +146,7 @@ public static Add<TFloat32> buildFCLayersAndRegularization(Ops tf, Placeholder<T

// Softmax layer
Variable<TFloat32> fc2Weights = tf.variable(tf.math.mul(tf.random
.truncatedNormal(tf.array(fcBiasShape, NUM_LABELS), TFloat32.DTYPE,
.truncatedNormal(tf.array(fcBiasShape, NUM_LABELS), TFloat32.class,
TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
Variable<TFloat32> fc2Biases = tf
.variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.constant(0.1f)));
Expand Down Expand Up @@ -183,7 +183,7 @@ public static MaxPool<TFloat32> vggMaxPool(Ops tf, Relu<TFloat32> relu1) {

public static Relu<TFloat32> vggConv2DLayer(String layerName, Ops tf, Operand<TFloat32> scaledInput, int[] convWeightsL1Shape, int convBiasL1Shape) {
Variable<TFloat32> conv1Weights = tf.withName("conv2d_" + layerName).variable(tf.math.mul(tf.random
.truncatedNormal(tf.array(convWeightsL1Shape), TFloat32.DTYPE,
.truncatedNormal(tf.array(convWeightsL1Shape), TFloat32.class,
TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
Conv2d<TFloat32> conv = tf.nn
.conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
Expand All @@ -201,17 +201,17 @@ public void train(MnistDataset dataset, int epochs, int minibatchSize) {
// Train the model
for (int i = 0; i < epochs; i++) {
for (ImageBatch trainingBatch : dataset.trainingBatches(minibatchSize)) {
try (Tensor<TUint8> batchImages = TUint8.tensorOf(trainingBatch.images());
Tensor<TUint8> batchLabels = TUint8.tensorOf(trainingBatch.labels());
Tensor<TFloat32> loss = session.runner()
try (TUint8 batchImages = TUint8.tensorOf(trainingBatch.images());
TUint8 batchLabels = TUint8.tensorOf(trainingBatch.labels());
TFloat32 loss = (TFloat32)session.runner()
.feed(TARGET, batchLabels)
.feed(INPUT_NAME, batchImages)
.addTarget(TRAIN)
.fetch(TRAINING_LOSS)
.run().get(0).expect(TFloat32.DTYPE)) {
.run().get(0)) {

logger.log(Level.INFO,
"Iteration = " + interval + ", training loss = " + loss.data().getFloat());
"Iteration = " + interval + ", training loss = " + loss.getFloat());

}
interval++;
Expand All @@ -224,17 +224,17 @@ public void test(MnistDataset dataset, int minibatchSize) {
int[][] confusionMatrix = new int[10][10];

for (ImageBatch trainingBatch : dataset.testBatches(minibatchSize)) {
try (Tensor<TUint8> transformedInput = TUint8.tensorOf(trainingBatch.images());
Tensor<TFloat32> outputTensor = session.runner()
try (TUint8 transformedInput = TUint8.tensorOf(trainingBatch.images());
TFloat32 outputTensor = (TFloat32)session.runner()
.feed(INPUT_NAME, transformedInput)
.fetch(OUTPUT_NAME).run().get(0).expect(TFloat32.DTYPE)) {
.fetch(OUTPUT_NAME).run().get(0)) {

ByteNdArray labelBatch = trainingBatch.labels();
for (int k = 0; k < labelBatch.shape().size(0); k++) {
byte trueLabel = labelBatch.getByte(k);
int predLabel;

predLabel = argmax(outputTensor.data().slice(Indices.at(k), Indices.all()));
predLabel = argmax(outputTensor.slice(Indices.at(k), Indices.all()));
if (predLabel == trueLabel) {
correctCount++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
import java.io.IOException;
import java.util.zip.GZIPInputStream;

import static org.tensorflow.ndarray.index.Indices.from;
import static org.tensorflow.ndarray.index.Indices.to;
import static org.tensorflow.ndarray.index.Indices.sliceFrom;
import static org.tensorflow.ndarray.index.Indices.sliceTo;

/** Common loader and data preprocessor for MNIST and FashionMNIST datasets. */
public class MnistDataset {
Expand All @@ -44,10 +44,10 @@ public static MnistDataset create(int validationSize, String trainingImagesArchi

if (validationSize > 0) {
return new MnistDataset(
trainingImages.slice(from(validationSize)),
trainingLabels.slice(from(validationSize)),
trainingImages.slice(to(validationSize)),
trainingLabels.slice(to(validationSize)),
trainingImages.slice(sliceFrom(validationSize)),
trainingLabels.slice(sliceFrom(validationSize)),
trainingImages.slice(sliceTo(validationSize)),
trainingLabels.slice(sliceTo(validationSize)),
testImages,
testLabels
);
Expand Down Expand Up @@ -137,6 +137,6 @@ private static ByteNdArray readArchive(String archiveName) throws IOException {
}
byte[] bytes = new byte[size];
archiveStream.readFully(bytes);
return NdArrays.wrap(DataBuffers.of(bytes, true, false), Shape.of(dimSizes));
return NdArrays.wrap(Shape.of(dimSizes), DataBuffers.of(bytes, true, false));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,17 @@ public void run() {
Ops tf = Ops.create(graph);

// Create placeholders and variables, which should fit batches of an unknown number of images
Placeholder<TFloat32> images = tf.placeholder(TFloat32.DTYPE);
Placeholder<TFloat32> labels = tf.placeholder(TFloat32.DTYPE);
Placeholder<TFloat32> images = tf.placeholder(TFloat32.class);
Placeholder<TFloat32> labels = tf.placeholder(TFloat32.class);

// Create weights with an initial value of 0
Shape weightShape = Shape.of(dataset.imageSize(), MnistDataset.NUM_CLASSES);
Variable<TFloat32> weights = tf.variable(weightShape, TFloat32.DTYPE);
Variable<TFloat32> weights = tf.variable(weightShape, TFloat32.class);
tf.initAdd(tf.assign(weights, tf.zerosLike(weights)));

// Create biases with an initial value of 0
Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES);
Variable<TFloat32> biases = tf.variable(biasShape, TFloat32.DTYPE);
Variable<TFloat32> biases = tf.variable(biasShape, TFloat32.class);
tf.initAdd(tf.assign(biases, tf.zerosLike(biases)));

// Register all variable initializers for single execution
Expand Down Expand Up @@ -98,7 +98,7 @@ public void run() {
// Compute the accuracy of the model
Operand<TInt64> predicted = tf.math.argMax(softmax, tf.constant(1));
Operand<TInt64> expected = tf.math.argMax(labels, tf.constant(1));
Operand<TFloat32> accuracy = tf.math.mean(tf.dtypes.cast(tf.math.equal(predicted, expected), TFloat32.DTYPE), tf.array(0));
Operand<TFloat32> accuracy = tf.math.mean(tf.dtypes.cast(tf.math.equal(predicted, expected), TFloat32.class), tf.array(0));

// Run the graph
try (Session session = new Session(graph)) {
Expand All @@ -108,8 +108,8 @@ public void run() {

// Train the model
for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
try (Tensor<TFloat32> batchImages = preprocessImages(trainingBatch.images());
Tensor<TFloat32> batchLabels = preprocessLabels(trainingBatch.labels())) {
try (TFloat32 batchImages = preprocessImages(trainingBatch.images());
TFloat32 batchLabels = preprocessLabels(trainingBatch.labels())) {
session.runner()
.addTarget(minimize)
.feed(images.asOutput(), batchImages)
Expand All @@ -120,16 +120,15 @@ public void run() {

// Test the model
ImageBatch testBatch = dataset.testBatch();
try (Tensor<TFloat32> testImages = preprocessImages(testBatch.images());
Tensor<TFloat32> testLabels = preprocessLabels(testBatch.labels());
Tensor<TFloat32> accuracyValue = session.runner()
try (TFloat32 testImages = preprocessImages(testBatch.images());
TFloat32 testLabels = preprocessLabels(testBatch.labels());
TFloat32 accuracyValue = (TFloat32)session.runner()
.fetch(accuracy)
.feed(images.asOutput(), testImages)
.feed(labels.asOutput(), testLabels)
.run()
.get(0)
.expect(TFloat32.DTYPE)) {
System.out.println("Accuracy: " + accuracyValue.data().getFloat());
.get(0)) {
System.out.println("Accuracy: " + accuracyValue.getFloat());
}
}
}
Expand All @@ -138,21 +137,21 @@ public void run() {
private static final int TRAINING_BATCH_SIZE = 100;
private static final float LEARNING_RATE = 0.2f;

private static Tensor<TFloat32> preprocessImages(ByteNdArray rawImages) {
private static TFloat32 preprocessImages(ByteNdArray rawImages) {
Ops tf = Ops.create();

// Flatten images in a single dimension and normalize their pixels as floats.
long imageSize = rawImages.get(0).shape().size();
return tf.math.div(
tf.reshape(
tf.dtypes.cast(tf.constant(rawImages), TFloat32.DTYPE),
tf.dtypes.cast(tf.constant(rawImages), TFloat32.class),
tf.array(-1L, imageSize)
),
tf.constant(255.0f)
).asTensor();
}

private static Tensor<TFloat32> preprocessLabels(ByteNdArray rawLabels) {
private static TFloat32 preprocessLabels(ByteNdArray rawLabels) {
Ops tf = Ops.create();

// Map labels to one hot vectors where only the expected predictions as a value of 1.0
Expand Down
Loading