Skip to content

Commit b920172

Browse files
committed
[gpu]: Update to 1.4.0
1 parent c27c760 commit b920172

File tree

3 files changed

+12
-10
lines changed

3 files changed

+12
-10
lines changed

gpu/README.md

+5-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
The TensorFlow Java API is distributed in the
44
[org.tensorflow:tensorflow](http://mvnrepository.com/artifact/org.tensorflow/tensorflow)
5-
maven package. As of TensorFlow 1.2.1, this package did not include GPU support
5+
maven package. As of TensorFlow 1.4.0, this package did not include GPU support
66
by default. However, by making the GPU-enabled native libraries available
77
to the JVM, GPUs can be utilized. This example demonstrates that.
88

@@ -26,7 +26,7 @@ to the JVM, GPUs can be utilized. This example demonstrates that.
2626
```
2727
output: (MatMul): /job:localhost/replica:0/task:0/cpu:0
2828
input: (Placeholder): /job:localhost/replica:0/task:0/cpu:0
29-
TensorFlow version: 1.2.1
29+
TensorFlow version: 1.4.0
3030
3131
Input : [[1.0, 2.0], [3.0, 4.0]]
3232
Output: [[7.0, 10.0], [15.0, 22.0]]
@@ -43,8 +43,8 @@ CUDA-enabled Java native libraries need to be made available to the JVM.
4343
1. Download and extract the CUDA-enabled TensorFlow native library:
4444
4545
```
46-
curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.2.1.tar.gz" |
47-
tar -xz ./libtensorflow_jni.so
46+
curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-gpu-linux-x86_64-1.4.0.tar.gz" |
47+
tar -xz
4848
```
4949
5050
(For more detailed instructions see:
@@ -65,7 +65,7 @@ JVM will result in something like this on the console:
6565
```
6666
output: (MatMul): /job:localhost/replica:0/task:0/gpu:0
6767
input: (Placeholder): /job:localhost/replica:0/task:0/gpu:0
68-
TensorFlow version: 1.2.1
68+
TensorFlow version: 1.4.0
6969

7070
Input : [[1.0, 2.0], [3.0, 4.0]]
7171
Output: [[7.0, 10.0], [15.0, 22.0]]

gpu/pom.xml

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
<dependency>
1515
<groupId>org.tensorflow</groupId>
1616
<artifactId>tensorflow</artifactId>
17-
<version>1.2.1</version>
17+
<version>1.4.0</version>
1818
</dependency>
1919
<dependency>
2020
<groupId>org.tensorflow</groupId>
2121
<artifactId>proto</artifactId>
22-
<version>1.2.1</version>
22+
<version>1.4.0</version>
2323
</dependency>
2424
</dependencies>
2525
</project>

gpu/src/main/java/HelloTF.java

+5-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import org.tensorflow.Session;
66
import org.tensorflow.Tensor;
77
import org.tensorflow.TensorFlow;
8+
import org.tensorflow.Tensors;
89
import org.tensorflow.framework.ConfigProto;
910

1011
public class HelloTF {
@@ -18,8 +19,9 @@ public static void main(String[] args) throws Exception {
1819
// Create a config that will dump out device placement of operations.
1920
ConfigProto config = ConfigProto.newBuilder().setLogDevicePlacement(true).build();
2021
try (Session sess = new Session(graph, config.toByteArray())) {
21-
try (Tensor in = Tensor.create(new float[][] {{1, 2}, {3, 4}});
22-
Tensor out = sess.runner().feed("input", in).fetch("output").run().get(0)) {
22+
try (Tensor<Float> in = Tensors.create(new float[][] {{1, 2}, {3, 4}});
23+
Tensor<Float> out =
24+
sess.runner().feed("input", in).fetch("output").run().get(0).expect(Float.class)) {
2325
System.out.println("TensorFlow version: " + TensorFlow.version());
2426
System.out.println();
2527
print2x2Matrix("Input ", in);
@@ -29,7 +31,7 @@ public static void main(String[] args) throws Exception {
2931
}
3032
}
3133

32-
public static void print2x2Matrix(String tag, Tensor t) {
34+
public static void print2x2Matrix(String tag, Tensor<Float> t) {
3335
float[][] m = new float[2][2];
3436
System.out.print(tag + ": ");
3537
System.out.println(Arrays.deepToString(t.copyTo(m)));

0 commit comments

Comments
 (0)