Skip to content

Commit c27c760

Browse files
committed
train: Example of running a training loop in Java.
1 parent 8891b8d commit c27c760

File tree

5 files changed

+197
-0
lines changed

5 files changed

+197
-0
lines changed

train/README.md

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Training [TensorFlow](https://www.tensorflow.org) models in Java
2+
3+
Python is the primary language in which TensorFlow models are typically
4+
developed and trained. TensorFlow does have [bindings for other programming
5+
languages](https://www.tensorflow.org/api_docs/). These bindings have the
6+
low-level primitives that are required to build a more complete API, however,
7+
lack much of the higher-level API richness of the Python bindings, particularly
8+
for defining the model structure.
9+
10+
This file demonstrates taking a model (a TensorFlow graph) created by a Python
11+
program and running the training loop in Java (and saving the trained weights
12+
to disk).
13+
14+
## The model
15+
16+
The model is a trivial one, trying to learn the function: `f(x) = W\*x + b`,
17+
where `W` and `b` are model parameters. The training data is constructed so that
18+
the "true" value of `W` is 3 and that of `b` is 2, i.e., `f(x) = 3 * x + 2`.
19+
20+
Thus, over time, the predicted value for an input of 1, 2, and 3 should tend
21+
towards 5, 8, and 11.
22+
23+
## Quickstart
24+
25+
1. Run the training loop program in Java using:
26+
27+
```
28+
mvn compile exec:java -q -Dexec.args="graph.pb /tmp/checkpoint"
29+
```
30+
31+
Where `graph.pb` is the serialized TenosrFlow graph and `/tmp/checkpoint`
32+
is the directory from which trained weights (the checkpoint) should be
33+
loaded (if available) and saved to (after training).
34+
35+
## Generating the graph
36+
37+
The `graph.pb` file which contains the model definition, and the names of the
38+
tensors in it were generated by running `python model.py`.
39+
40+
41+
## Noteworthy
42+
43+
- The Python APIs for TensorFlow include other conveniences for training (such
44+
as `MonitoredSession` and `tf.train.Estimator`), which make it easier to
45+
configure checkpointing, evaluation loops etc. The examples here aren't that
46+
sophisticated and are focused on basic model training only.
47+
- In this example, we use placeholders and feed dictionaries to feed input,
48+
but you probably want to use the
49+
[`tf.data`](https://www.tensorflow.org/programmers_guide/datasets) API to
50+
cconstruct an input pipeline for providing training data to the model.
51+
- Not demonstrated here, but summaries for TensorBoard can also be produced by
52+
executing the summary operations.

train/graph.pb

14.1 KB
Binary file not shown.

train/model.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import tensorflow as tf
2+
3+
# Batch of input and target output (1x1 matrices)
4+
x = tf.placeholder(tf.float32, shape=[None, 1, 1], name='input')
5+
y = tf.placeholder(tf.float32, shape=[None, 1, 1], name='target')
6+
7+
# Trivial linear model
8+
y_ = tf.identity(tf.layers.dense(x, 1), name='output')
9+
10+
# Optimize loss
11+
loss = tf.reduce_mean(tf.square(y_ - y), name='loss')
12+
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
13+
train_op = optimizer.minimize(loss, name='train')
14+
15+
init = tf.global_variables_initializer()
16+
17+
# tf.train.Saver.__init__ adds operations to the graph to save
18+
# and restore variables.
19+
saver_def = tf.train.Saver().as_saver_def()
20+
21+
print('Run this operation to initialize variables : ', init.name)
22+
print('Run this operation for a train step : ', train_op.name)
23+
print('Feed this tensor to set the checkpoint filename: ', saver_def.filename_tensor_name)
24+
print('Run this operation to save a checkpoint : ', saver_def.save_tensor_name)
25+
print('Run this operation to restore a checkpoint : ', saver_def.restore_op_name)
26+
27+
# Write the graph out to a file.
28+
with open('graph.pb', 'w') as f:
29+
f.write(tf.get_default_graph().as_graph_def().SerializeToString())

train/pom.xml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
<project>
2+
<modelVersion>4.0.0</modelVersion>
3+
<groupId>org.myorg</groupId>
4+
<artifactId>train</artifactId>
5+
<version>1.0-SNAPSHOT</version>
6+
<properties>
7+
<exec.mainClass>Train</exec.mainClass>
8+
<!-- The sample code requires at least JDK 1.7. -->
9+
<!-- The maven compiler plugin defaults to a lower version -->
10+
<maven.compiler.source>1.7</maven.compiler.source>
11+
<maven.compiler.target>1.7</maven.compiler.target>
12+
</properties>
13+
<dependencies>
14+
<dependency>
15+
<groupId>org.tensorflow</groupId>
16+
<artifactId>tensorflow</artifactId>
17+
<version>1.4.0</version>
18+
</dependency>
19+
<dependency>
20+
<groupId>org.tensorflow</groupId>
21+
<artifactId>proto</artifactId>
22+
<version>1.4.0</version>
23+
</dependency>
24+
</dependencies>
25+
</project>

train/src/main/java/Train.java

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import java.nio.file.Files;
2+
import java.nio.file.Paths;
3+
import java.util.Arrays;
4+
import java.util.Random;
5+
import org.tensorflow.Graph;
6+
import org.tensorflow.Session;
7+
import org.tensorflow.Tensor;
8+
import org.tensorflow.Tensors;
9+
10+
public class Train {
11+
12+
public static void main(String[] args) throws Exception {
13+
if (args.length != 2) {
14+
System.err.println("Require two arguments: <graph_def_filename> <directory_for_checkpoints>");
15+
System.exit(1);
16+
}
17+
final byte[] graphDef = Files.readAllBytes(Paths.get(args[0]));
18+
final String checkpointDir = args[1];
19+
final boolean checkpointExists = Files.exists(Paths.get(checkpointDir));
20+
21+
// These names of tensors/operations in the graph (string arguments to feed(), fetch(), and
22+
// addTarget()) would have been printed out by model.py
23+
try (Graph graph = new Graph();
24+
Session sess = new Session(graph);
25+
Tensor<String> checkpointPrefix =
26+
Tensors.create(Paths.get(checkpointDir, "checkpoint").toString())) {
27+
graph.importGraphDef(graphDef);
28+
29+
// Initialize or restore.
30+
if (checkpointExists) {
31+
System.out.println("Restoring variables from checkpoint");
32+
sess.runner().feed("save/Const", checkpointPrefix).addTarget("save/restore_all").run();
33+
} else {
34+
System.out.println("Initializing variables");
35+
sess.runner().addTarget("init").run();
36+
}
37+
38+
System.out.println("Generating initial predictions");
39+
printPredictionsOnTestSet(sess);
40+
41+
System.out.println("Training for a few steps");
42+
final int BATCH_SIZE = 10;
43+
float inputs[][][] = new float[BATCH_SIZE][1][1];
44+
float targets[][][] = new float[BATCH_SIZE][1][1];
45+
for (int i = 0; i < 200; ++i) {
46+
fillNextBatchForTraining(inputs, targets);
47+
try (Tensor<Float> inputBatch = Tensors.create(inputs);
48+
Tensor<Float> targetBatch = Tensors.create(targets)) {
49+
sess.runner()
50+
.feed("input", inputBatch)
51+
.feed("target", targetBatch)
52+
.addTarget("train")
53+
.run();
54+
}
55+
}
56+
57+
System.out.println("Updated predictions");
58+
printPredictionsOnTestSet(sess);
59+
60+
System.out.println("Saving checkpoint");
61+
sess.runner().feed("save/Const", checkpointPrefix).addTarget("save/control_dependency").run();
62+
}
63+
}
64+
65+
public static void printPredictionsOnTestSet(Session sess) {
66+
final float[][][] inputBatch = new float[][][] {{{1.0f}}, {{2.0f}}, {{3.0f}}};
67+
try (Tensor<Float> input = Tensors.create(inputBatch);
68+
Tensor<Float> output =
69+
sess.runner().feed("input", input).fetch("output").run().get(0).expect(Float.class)) {
70+
final long shape[] = output.shape();
71+
final int batchSize = (int) shape[0];
72+
final int rows = (int) shape[1];
73+
final int cols = (int) shape[2];
74+
float[][][] predictions = output.copyTo(new float[batchSize][rows][cols]);
75+
for (int i = 0; i < batchSize; ++i) {
76+
System.out.print("\t x = ");
77+
System.out.print(Arrays.deepToString(inputBatch[i]));
78+
System.out.print(", predicted y = ");
79+
System.out.println(Arrays.deepToString(predictions[i]));
80+
}
81+
}
82+
}
83+
84+
public static void fillNextBatchForTraining(float[][][] inputs, float[][][] targets) {
85+
final Random r = new Random();
86+
for (int i = 0; i < inputs.length; ++i) {
87+
inputs[i][0][0] = r.nextFloat();
88+
targets[i][0][0] = inputs[i][0][0] * 3.0f + 2.0f;
89+
}
90+
}
91+
}

0 commit comments

Comments
 (0)