Skip to content

Commit 9360efe

Browse files
committed
Add initializers
1 parent e9cd56a commit 9360efe

17 files changed

+2150
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================*/
15+
package org.tensorflow.framework.initializers;
16+
17+
import org.junit.jupiter.api.*;
18+
import org.tensorflow.Operand;
19+
import org.tensorflow.framework.utils.TestSession;
20+
import org.tensorflow.ndarray.Shape;
21+
import org.tensorflow.op.Ops;
22+
import org.tensorflow.types.*;
23+
24+
import static org.junit.jupiter.api.Assertions.assertThrows;
25+
import static org.junit.jupiter.api.Assertions.fail;
26+
27+
/** Test the Constant initializer */
28+
public class ConstantTest {
29+
30+
private final TestSession.Mode tfMode = TestSession.Mode.EAGER;
31+
32+
public ConstantTest() {}
33+
34+
@BeforeAll
35+
public static void setUpClass() {}
36+
37+
@AfterAll
38+
public static void tearDownClass() {}
39+
40+
@BeforeEach
41+
public void setUp() {}
42+
43+
@AfterEach
44+
public void tearDown() {}
45+
46+
/** Test of call method, of class Constant. */
47+
@Test
48+
public void testCallUInt() {
49+
Byte[] expected = {0xf, 0xf, 0xf, 0xf}; // init to constant to make sure they all change to zero
50+
try (TestSession session = TestSession.createTestSession(tfMode)) {
51+
Ops tf = session.getTF();
52+
Shape shape = Shape.of(2, 2);
53+
Constant<TUint8> instance = new Constant<>(tf, 0xf);
54+
Operand<TUint8> operand = instance.call(tf.constant(shape), TUint8.DTYPE);
55+
session.evaluate(expected, operand);
56+
}
57+
}
58+
59+
/** Test of call method, of class Constant. */
60+
@Test
61+
public void testCallInt() {
62+
Integer[] expected = {
63+
0xf, 0xf, 0xf, 0xf
64+
}; // init to constant to make sure they all change to zero
65+
try (TestSession session = TestSession.createTestSession(tfMode)) {
66+
Ops tf = session.getTF();
67+
Shape shape = Shape.of(2, 2);
68+
Constant<TInt32> instance = new Constant<>(tf, 0xf);
69+
Operand<TInt32> operand = instance.call(tf.constant(shape), TInt32.DTYPE);
70+
session.evaluate(expected, operand);
71+
}
72+
}
73+
74+
/** Test of call method, of class Constant. */
75+
@Test
76+
public void testCallLong() {
77+
long[] expected = {
78+
0xffL, 0xffL, 0xffL, 0xffL
79+
}; // init to constant to make sure they all change to zero
80+
try (TestSession session = TestSession.createTestSession(tfMode)) {
81+
Ops tf = session.getTF();
82+
Shape shape = Shape.of(2, 2);
83+
Constant<TInt64> instance = new Constant<>(tf, 0xffL);
84+
Operand<TInt64> operand = instance.call(tf.constant(shape), TInt64.DTYPE);
85+
session.evaluate(expected, operand);
86+
}
87+
}
88+
89+
/** Test of call method, of class Constant. */
90+
@Test
91+
public void testCallFloat() {
92+
float[] expected = {12.f, 12.f, 12.f, 12.f};
93+
try (TestSession session = TestSession.createTestSession(tfMode)) {
94+
Ops tf = session.getTF();
95+
Shape shape = Shape.of(2, 2);
96+
Constant<TFloat32> instance = new Constant<>(tf, 12.F);
97+
Operand<TFloat32> operand = instance.call(tf.constant(shape), TFloat32.DTYPE);
98+
session.evaluate(expected, operand);
99+
}
100+
}
101+
102+
/** Test of call method, of class Constant. */
103+
@Test
104+
public void testCallDouble() {
105+
double[] expected = {11., 11., 11., 11.};
106+
try (TestSession session = TestSession.createTestSession(tfMode)) {
107+
Ops tf = session.getTF();
108+
Shape shape = Shape.of(2, 2);
109+
110+
Constant<TFloat64> instance = new Constant<>(tf, 11.);
111+
Operand<TFloat64> operand = instance.call(tf.constant(shape), TFloat64.DTYPE);
112+
session.evaluate(expected, operand);
113+
}
114+
}
115+
116+
/** Test of call method, of class Constant. */
117+
@Test
118+
public void testCallString() {
119+
assertThrows(
120+
java.lang.IllegalArgumentException.class,
121+
() -> {
122+
try (TestSession session = TestSession.createTestSession(tfMode)) {
123+
Ops tf = session.getTF();
124+
Shape shape = Shape.of(2, 2);
125+
126+
Constant<TString> instance = new Constant<>(tf, 22);
127+
instance.call(tf.constant(shape), TString.DTYPE);
128+
fail("IllegalArgumentException should have been thrown for TString");
129+
}
130+
});
131+
}
132+
133+
/** Test of call method, of class Constant. */
134+
@Test
135+
public void testCallBool() {
136+
try (TestSession session = TestSession.createTestSession(tfMode)) {
137+
Ops tf = session.getTF();
138+
Shape shape = Shape.of(2, 2);
139+
Boolean[] expected = {true, true, true, true};
140+
141+
Constant<TBool> instance = new Constant<>(tf, true);
142+
Operand<TBool> operand = instance.call(tf.constant(shape), TBool.DTYPE);
143+
session.evaluate(expected, operand);
144+
}
145+
}
146+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================*/
15+
package org.tensorflow.framework.initializers;
16+
17+
import org.junit.jupiter.api.*;
18+
import org.tensorflow.Operand;
19+
import org.tensorflow.framework.utils.TestSession;
20+
import org.tensorflow.ndarray.Shape;
21+
import org.tensorflow.op.Ops;
22+
import org.tensorflow.types.TFloat32;
23+
import org.tensorflow.types.TFloat64;
24+
25+
/** Test cases for GlorotNormal initializer */
26+
public class GlorotNormalTest {
27+
28+
private final TestSession.Mode tfMode = TestSession.Mode.EAGER;
29+
30+
private static final long SEED = 1000L;
31+
32+
public GlorotNormalTest() {}
33+
34+
@BeforeAll
35+
public static void setUpClass() {}
36+
37+
@AfterAll
38+
public static void tearDownClass() {}
39+
40+
@BeforeEach
41+
public void setUp() {}
42+
43+
@AfterEach
44+
public void tearDown() {}
45+
46+
/** Test of call method, of class GlorotNormal. */
47+
@Test
48+
public void testCall_Float() {
49+
float[] expected = {-0.52388954F, -0.29329166F, -0.07872587F, -0.31851602F};
50+
try (TestSession session = TestSession.createTestSession(tfMode)) {
51+
Ops tf = session.getTF();
52+
Shape shape = Shape.of(2, 2);
53+
GlorotNormal<TFloat32, TFloat32> instance = new GlorotNormal<>(tf, SEED);
54+
55+
Operand<TFloat32> operand = instance.call(tf.constant(shape), TFloat32.DTYPE);
56+
session.evaluate(expected, operand);
57+
}
58+
}
59+
60+
@Test
61+
public void testCall_Double() {
62+
double[] expected = {
63+
1.4971264721246893, -1.2488522307109322, -0.5409677352523339, 0.4871390504288623
64+
};
65+
try (TestSession session = TestSession.createTestSession(tfMode)) {
66+
Ops tf = session.getTF();
67+
Shape shape = Shape.of(2, 2);
68+
69+
GlorotNormal<TFloat64, TFloat64> instance = new GlorotNormal<>(tf, SEED);
70+
Operand<TFloat64> operand = instance.call(tf.constant(shape), TFloat64.DTYPE);
71+
session.evaluate(expected, operand);
72+
}
73+
}
74+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================*/
15+
package org.tensorflow.framework.initializers;
16+
17+
import java.util.HashMap;
18+
import java.util.Map;
19+
import org.junit.jupiter.api.AfterEach;
20+
import org.junit.jupiter.api.AfterAll;
21+
import org.junit.jupiter.api.BeforeEach;
22+
import org.junit.jupiter.api.BeforeAll;
23+
import org.junit.jupiter.api.Test;
24+
import static org.junit.jupiter.api.Assertions.*;
25+
import org.tensorflow.Operand;
26+
import org.tensorflow.framework.utils.TestSession;
27+
import org.tensorflow.ndarray.Shape;
28+
import org.tensorflow.op.Ops;
29+
import org.tensorflow.types.TFloat32;
30+
import org.tensorflow.types.TFloat64;
31+
32+
/**
33+
* Test cases for GlorotUniform initializer
34+
*/
35+
public class GlorotUniformTest {
36+
37+
private final TestSession.Mode tfMode = TestSession.Mode.EAGER;
38+
39+
private static final long SEED = 1000L;
40+
41+
int counter;
42+
43+
public GlorotUniformTest() {
44+
}
45+
46+
@BeforeAll
47+
public static void setUpClass() {
48+
}
49+
50+
@AfterAll
51+
public static void tearDownClass() {
52+
}
53+
54+
@BeforeEach
55+
public void setUp() {
56+
}
57+
58+
@AfterEach
59+
public void tearDown() {
60+
}
61+
62+
63+
/**
64+
* Test of call method, of class GlorotUniform.
65+
*/
66+
@Test
67+
public void testCall_Float() {
68+
float[] expected = {0.9266439F, 0.8190767F, 1.1268647F, 0.6596042F};
69+
try (TestSession session = TestSession.createTestSession(tfMode)) {
70+
Ops tf = session.getTF();
71+
Shape shape = Shape.of(2, 2);
72+
GlorotUniform<TFloat32, TFloat32> instance
73+
= new GlorotUniform<>(tf, SEED);
74+
Operand<TFloat32> operand = instance.call(tf.constant(shape), TFloat32.DTYPE);
75+
session.evaluate(expected, operand);
76+
}
77+
}
78+
79+
@Test
80+
public void testCall_Double() {
81+
double[] expected = {0.06468193804916589, 0.44170328686673477,
82+
0.06711059208157763, 0.6278720842445181};
83+
try (TestSession session = TestSession.createTestSession(tfMode)) {
84+
Ops tf = session.getTF();
85+
Shape shape = Shape.of(2, 2);
86+
GlorotUniform<TFloat64, TFloat64> instance
87+
= new GlorotUniform<>(tf, SEED);
88+
Operand<TFloat64> operand = instance.call(tf.constant(shape), TFloat64.DTYPE);
89+
session.evaluate(expected, operand);
90+
}
91+
}
92+
93+
}

0 commit comments

Comments
 (0)