Skip to content

Commit 0cfc679

Browse files
committed
Fix metric test failures (tensorflow#414)
* Migrate metric tests from randomUniform to statelessRandomUniform * pom updates. * Spotless changes.
1 parent e6720e5 commit 0cfc679

File tree

9 files changed

+35
-32
lines changed

9 files changed

+35
-32
lines changed

pom.xml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
<maven.javadoc.skip>true</maven.javadoc.skip>
4747
<maven.source.skip>true</maven.source.skip>
4848
<gpg.skip>true</gpg.skip>
49-
<spotless.version>2.11.1</spotless.version>
49+
<spotless.version>2.20.2</spotless.version>
5050
</properties>
5151

5252
<repositories>
@@ -371,7 +371,9 @@
371371

372372
<lineEndings/>
373373
<java>
374-
<googleJavaFormat/>
374+
<googleJavaFormat>
375+
<version>1.14.0</version>
376+
</googleJavaFormat>
375377

376378
<removeUnusedImports/>
377379
</java>

tensorflow-core/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
4444
Bumped to newer version to patch a CVE only present in protobuf-java
4545
-->
46-
<protobuf.version>3.19.2</protobuf.version>
46+
<protobuf.version>3.19.4</protobuf.version>
4747

4848
<native.classifier>${javacpp.platform}${javacpp.platform.extension}</native.classifier>
4949
<javacpp.build.skip>false</javacpp.build.skip> <!-- To skip execution of build.sh: -Djavacpp.build.skip=true -->

tensorflow-framework/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@
9393
<configuration>
9494
<forkCount>1</forkCount>
9595
<reuseForks>false</reuseForks>
96-
<argLine>-Xmx2G -XX:MaxPermSize=256m</argLine>
96+
<argLine>-Xmx2G</argLine>
9797
<includes>
9898
<include>**/*Test.java</include>
9999
</includes>

tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionAtRecallTest.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.tensorflow.ndarray.Shape;
2525
import org.tensorflow.op.Op;
2626
import org.tensorflow.op.Ops;
27-
import org.tensorflow.op.random.RandomUniform;
2827
import org.tensorflow.types.TFloat32;
2928
import org.tensorflow.types.TInt64;
3029

@@ -39,11 +38,11 @@ public void testValueIsIdempotent() {
3938
PrecisionAtRecall<TFloat32> instance = new PrecisionAtRecall<>(0.7f, 1001L, TFloat32.class);
4039

4140
Operand<TFloat32> predictions =
42-
tf.random.randomUniform(
43-
tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L));
41+
tf.random.statelessRandomUniform(
42+
tf.constant(Shape.of(10, 3)), tf.constant(new long[] {1L, 0L}), TFloat32.class);
4443
Operand<TFloat32> labels =
45-
tf.random.randomUniform(
46-
tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L));
44+
tf.random.statelessRandomUniform(
45+
tf.constant(Shape.of(10, 3)), tf.constant(new long[] {1L, 0L}), TFloat32.class);
4746

4847
Op update = instance.updateState(tf, labels, predictions, null);
4948

tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import org.tensorflow.ndarray.Shape;
2323
import org.tensorflow.op.Op;
2424
import org.tensorflow.op.Ops;
25-
import org.tensorflow.op.random.RandomUniform;
2625
import org.tensorflow.types.TFloat32;
2726
import org.tensorflow.types.TFloat64;
2827
import org.tensorflow.types.TInt32;
@@ -39,11 +38,11 @@ public void testValueIsIdempotent() {
3938
Precision<TFloat64> instance =
4039
new Precision<>(new float[] {0.3f, 0.72f}, 1001L, TFloat64.class);
4140
Operand<TFloat32> predictions =
42-
tf.random.randomUniform(
43-
tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1001L));
41+
tf.random.statelessRandomUniform(
42+
tf.constant(Shape.of(10, 3)), tf.constant(new long[] {1001L, 0L}), TFloat32.class);
4443
Operand<TFloat32> labels =
45-
tf.random.randomUniform(
46-
tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1001L));
44+
tf.random.statelessRandomUniform(
45+
tf.constant(Shape.of(10, 3)), tf.constant(new long[] {1001L, 0L}), TFloat32.class);
4746

4847
Op update = instance.updateState(tf, labels, predictions, null);
4948

@@ -81,7 +80,11 @@ public void testUnweightedAllIncorrect() {
8180
Precision<TFloat32> instance = new Precision<>(0.5f, 1001L, TFloat32.class);
8281

8382
Operand<TInt32> predictions =
84-
tf.random.randomUniformInt(tf.constant(Shape.of(100, 1)), tf.constant(0), tf.constant(2));
83+
tf.random.statelessMultinomial(
84+
tf.constant(new float[][] {{0.5f, 0.5f}}),
85+
tf.constant(100),
86+
tf.constant(new long[] {1001L, 0L}),
87+
TInt32.class);
8588
Operand<TInt32> labels = tf.math.sub(tf.constant(1), predictions);
8689
Op update = instance.updateState(tf, labels, predictions, null);
8790
session.run(update);

tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallAtPrecisionTest.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.tensorflow.ndarray.Shape;
2525
import org.tensorflow.op.Op;
2626
import org.tensorflow.op.Ops;
27-
import org.tensorflow.op.random.RandomUniform;
2827
import org.tensorflow.types.TFloat32;
2928
import org.tensorflow.types.TInt64;
3029

@@ -39,11 +38,11 @@ public void testValueIsIdempotent() {
3938
RecallAtPrecision<TFloat32> instance = new RecallAtPrecision<>(0.7f, 1001L, TFloat32.class);
4039

4140
Operand<TFloat32> predictions =
42-
tf.random.randomUniform(
43-
tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L));
41+
tf.random.statelessRandomUniform(
42+
tf.constant(Shape.of(10, 3)), tf.constant(new long[] {1L, 0L}), TFloat32.class);
4443
Operand<TFloat32> labels =
45-
tf.random.randomUniform(
46-
tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L));
44+
tf.random.statelessRandomUniform(
45+
tf.constant(Shape.of(10, 3)), tf.constant(new long[] {1L, 0L}), TFloat32.class);
4746
labels = tf.math.mul(labels, tf.constant(2.0f));
4847

4948
Op update = instance.updateState(tf, labels, predictions);

tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@ public void testValueIsIdempotent() {
3636
Recall<TFloat32> instance = new Recall<>(new float[] {0.3f, 0.72f}, 1001L, TFloat32.class);
3737

3838
Operand<TFloat32> predictions =
39-
tf.random.randomUniform(tf.constant(Shape.of(10, 3)), TFloat32.class);
39+
tf.random.statelessRandomUniform(
40+
tf.constant(Shape.of(10, 3)), tf.constant(new long[] {1L, 0L}), TFloat32.class);
4041
Operand<TFloat32> labels =
41-
tf.random.randomUniform(tf.constant(Shape.of(10, 3)), TFloat32.class);
42+
tf.random.statelessRandomUniform(
43+
tf.constant(Shape.of(10, 3)), tf.constant(new long[] {1L, 0L}), TFloat32.class);
4244
Op update = instance.updateState(tf, labels, predictions, null);
4345

4446
for (int i = 0; i < 10; i++) session.run(update);

tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SensitivityAtSpecificityTest.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.tensorflow.ndarray.Shape;
2525
import org.tensorflow.op.Op;
2626
import org.tensorflow.op.Ops;
27-
import org.tensorflow.op.random.RandomUniform;
2827
import org.tensorflow.types.TFloat32;
2928
import org.tensorflow.types.TFloat64;
3029
import org.tensorflow.types.TInt64;
@@ -40,11 +39,11 @@ public void testValueIsIdempotent() {
4039
SensitivityAtSpecificity<TFloat32> instance =
4140
new SensitivityAtSpecificity<>(0.7f, 1001L, TFloat32.class);
4241
Operand<TFloat32> predictions =
43-
tf.random.randomUniform(
44-
tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L));
42+
tf.random.statelessRandomUniform(
43+
tf.constant(Shape.of(10, 3)), tf.constant(new long[] {1L, 0L}), TFloat32.class);
4544
Operand<TFloat32> labels =
46-
tf.random.randomUniform(
47-
tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L));
45+
tf.random.statelessRandomUniform(
46+
tf.constant(Shape.of(10, 3)), tf.constant(new long[] {1L, 0L}), TFloat32.class);
4847
labels = tf.math.mul(labels, tf.constant(2.0f));
4948

5049
// instance.setDebug(session.getGraphSession());

tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SpecificityAtSensitivityTest.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.tensorflow.ndarray.Shape;
2525
import org.tensorflow.op.Op;
2626
import org.tensorflow.op.Ops;
27-
import org.tensorflow.op.random.RandomUniform;
2827
import org.tensorflow.types.TFloat32;
2928
import org.tensorflow.types.TFloat64;
3029
import org.tensorflow.types.TInt32;
@@ -42,11 +41,11 @@ public void testValueIsIdempotent() {
4241
new SpecificityAtSensitivity<>(0.7f, 1001L, TFloat32.class);
4342

4443
Operand<TFloat32> predictions =
45-
tf.random.randomUniform(
46-
tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L));
44+
tf.random.statelessRandomUniform(
45+
tf.constant(Shape.of(10, 3)), tf.constant(new long[] {1L, 0L}), TFloat32.class);
4746
Operand<TFloat32> labels =
48-
tf.random.randomUniform(
49-
tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L));
47+
tf.random.statelessRandomUniform(
48+
tf.constant(Shape.of(10, 3)), tf.constant(new long[] {1L, 0L}), TFloat32.class);
5049

5150
// instance.setDebug(session.getGraphSession());
5251
Op update = instance.updateState(tf, labels, predictions, null);

0 commit comments

Comments
 (0)