Skip to content

Commit 55547dd

Browse files
authored
Metrics init scope (#382)
1 parent ef72244 commit 55547dd

File tree

87 files changed

+2734
-2362
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

87 files changed

+2734
-2362
lines changed

tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java

Lines changed: 132 additions & 181 deletions
Large diffs are not rendered by default.

tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616

1717
import static org.tensorflow.framework.utils.CastHelper.cast;
1818

19+
import org.tensorflow.Graph;
1920
import org.tensorflow.Operand;
2021
import org.tensorflow.framework.losses.impl.LossTuple;
2122
import org.tensorflow.framework.metrics.impl.LossMetric;
22-
import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
23+
import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper;
2324
import org.tensorflow.framework.metrics.impl.MetricsHelper;
2425
import org.tensorflow.ndarray.Shape;
2526
import org.tensorflow.op.Ops;
@@ -36,50 +37,54 @@
3637
*
3738
* @param <T> The data type for the metric result
3839
*/
39-
public class Accuracy<T extends TNumber> extends MeanMetricWrapper<T> implements LossMetric<T> {
40+
public class Accuracy<T extends TNumber> extends MeanBaseMetricWrapper<T> implements LossMetric {
4041

4142
/**
4243
* Creates an Accuracy Metric using {@link Class#getSimpleName()} for the metric name
4344
*
44-
* @param tf the TensorFlow Ops
4545
* @param seed the seed for random number generation. An initializer created with a given seed
4646
* will always produce the same random tensor for a given shape and data type.
4747
* @param type the data type for the variables
4848
*/
49-
public Accuracy(Ops tf, long seed, Class<T> type) {
50-
this(tf, null, seed, type);
49+
public Accuracy(long seed, Class<T> type) {
50+
this(null, seed, type);
5151
}
5252

5353
/**
5454
* Creates an Accuracy Metric
5555
*
56-
* @param tf the TensorFlow Ops
5756
* @param name the name of the metric, if null then {@link Class#getSimpleName()} is used
5857
* @param seed the seed for random number generation. An initializer created with a given seed
5958
* will always produce the same random tensor for a given shape and data type.
6059
* @param type the data type for the variables
6160
*/
62-
public Accuracy(Ops tf, String name, long seed, Class<T> type) {
63-
super(tf, name, seed, type);
61+
public Accuracy(String name, long seed, Class<T> type) {
62+
super(name, seed, type);
6463
setLoss(this);
6564
}
6665

6766
/**
6867
* Calculates how often predictions equals labels. {@code labels} and {@code predictions} must
6968
* have compatible shapes, see {@link Shape @isCompatibleWith}.
7069
*
70+
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
7171
* @param labels the truth values or labels
7272
* @param predictions the predictions
73-
* @throws IllegalArgumentException if predictions and labels shapes are not compatible.
73+
* @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph
74+
* environment.
7475
* @return the loss
7576
*/
7677
@Override
77-
public Operand<T> call(
78-
Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions) {
79-
Operand<T> tLabels = cast(getTF(), labels, getResultType());
80-
Operand<T> tPredictions = cast(getTF(), predictions, getResultType());
78+
public <U extends TNumber> Operand<U> call(
79+
Ops tf,
80+
Operand<? extends TNumber> labels,
81+
Operand<? extends TNumber> predictions,
82+
Class<U> resultType) {
83+
init(tf);
84+
Operand<T> tLabels = cast(tf, labels, getInternalType());
85+
Operand<T> tPredictions = cast(tf, predictions, getInternalType());
8186
LossTuple<T> tuple =
82-
MetricsHelper.raggedAssertCompatibleAndGetFlatValues(getTF(), tLabels, tPredictions);
87+
MetricsHelper.raggedAssertCompatibleAndGetFlatValues(tf, tLabels, tPredictions);
8388
tLabels = tuple.getLabels();
8489
tPredictions = tuple.getTarget();
8590

@@ -91,6 +96,6 @@ public Operand<T> call(
9196
}
9297

9398
// cast TBool to result type
94-
return cast(getTF(), getTF().math.equal(tLabels, tPredictions), getResultType());
99+
return cast(tf, tf.math.equal(tLabels, tPredictions), resultType);
95100
}
96101
}
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================*/
15+
package org.tensorflow.framework.metrics;
16+
17+
import java.util.Collections;
18+
import java.util.List;
19+
import org.tensorflow.Graph;
20+
import org.tensorflow.Operand;
21+
import org.tensorflow.op.Op;
22+
import org.tensorflow.op.Ops;
23+
import org.tensorflow.types.family.TNumber;
24+
25+
/** Base class for Metrics */
26+
public abstract class BaseMetric implements Metric {
27+
28+
/** The seed for random number generation */
29+
private final long seed;
30+
31+
private String name;
32+
33+
private boolean initialized;
34+
35+
private Ops tf;
36+
37+
/**
38+
* Creates a Metric with a name of {@link Class#getSimpleName()}
39+
*
40+
* @param seed the seed for random number generation. An initializer created with a given seed
41+
* will always produce the same random tensor for a given shape and data type.
42+
*/
43+
protected BaseMetric(long seed) {
44+
this(null, seed);
45+
}
46+
47+
/**
48+
* Creates a Metric
49+
*
50+
* @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}.
51+
* @param seed the seed for random number generation. An initializer created with a given seed
52+
* will always produce the same random tensor for a given shape and data type.
53+
*/
54+
protected BaseMetric(String name, long seed) {
55+
56+
this.seed = seed;
57+
this.name = name != null ? name : this.getClass().getSimpleName();
58+
}
59+
60+
/**
61+
* Creates a List of Operations to update the metric state based on input values.
62+
*
63+
* <p>This is an empty implementation that should be overridden in a subclass, if needed.
64+
*
65+
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
66+
* @param values the inputs to be passed to update state, this may not be null
67+
* @param sampleWeights sample weights to be applied to the values, may be null.
68+
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment.
69+
* @return a List of Operations to update the metric state
70+
*/
71+
@SuppressWarnings({"unchecked", "unused"})
72+
@Override
73+
public List<Op> updateStateList(
74+
Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights) {
75+
checkIsGraph(tf);
76+
return Collections.EMPTY_LIST;
77+
}
78+
79+
/**
80+
* Creates a List of Operations to update the metric state based on labels and predictions.
81+
*
82+
* <p>This is an empty implementation that should be overridden in a subclass, if needed.
83+
*
84+
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
85+
* @param labels the labels
86+
* @param predictions the predictions
87+
* @param sampleWeights sample weights to be applied to the metric values, may be null.
88+
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment.
89+
* @return a List of Operations to update the metric state
90+
*/
91+
@Override
92+
@SuppressWarnings({"unchecked", "unused"})
93+
public List<Op> updateStateList(
94+
Ops tf,
95+
Operand<? extends TNumber> labels,
96+
Operand<? extends TNumber> predictions,
97+
Operand<? extends TNumber> sampleWeights) {
98+
checkIsGraph(tf);
99+
return Collections.EMPTY_LIST;
100+
}
101+
102+
/**
103+
* Creates a NoOp Operation with control dependencies to update the metric state
104+
*
105+
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
106+
* @param values the inputs to be passed to update state, this may not be null
107+
* @param sampleWeights sample weights to be applied to the values, may be null.
108+
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment.
109+
* @return the Operation to update the metric state
110+
*/
111+
public final Op updateState(
112+
Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights) {
113+
checkIsGraph(tf);
114+
List<Op> controlOps = updateStateList(tf, values, sampleWeights);
115+
return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp();
116+
}
117+
118+
/**
119+
* Creates a NoOp Operation with control dependencies to update the metric state
120+
*
121+
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
122+
* @param labels the labels
123+
* @param predictions the predictions
124+
* @param sampleWeights sample weights to be applied to the metric values, may be null.
125+
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment.
126+
* @return the Operation to update the metric state
127+
*/
128+
public final Op updateState(
129+
Ops tf,
130+
Operand<? extends TNumber> labels,
131+
Operand<? extends TNumber> predictions,
132+
Operand<? extends TNumber> sampleWeights) {
133+
List<Op> controlOps = updateStateList(tf, labels, predictions, sampleWeights);
134+
return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp();
135+
}
136+
137+
/**
138+
* Calls update state once, followed by a call to get the result
139+
*
140+
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
141+
* @param values the inputs to be passed to update state, this may not be null
142+
* @param sampleWeights sample weights to be applied to the values, may be null.
143+
* @param <T> The data type for the metric result
144+
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment.
145+
* @return the result, possibly with control dependencies
146+
*/
147+
@Override
148+
public final <T extends TNumber> Operand<T> callOnce(
149+
Ops tf,
150+
Operand<? extends TNumber> values,
151+
Operand<? extends TNumber> sampleWeights,
152+
Class<T> type) {
153+
checkIsGraph(tf);
154+
List<Op> controlOps = updateStateList(tf, values, sampleWeights);
155+
Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps);
156+
return ltf.identity(result(ltf, type));
157+
}
158+
159+
/**
160+
* Gets a formatted name for a variable, in the form {@link #name} + "_" + varName.
161+
*
162+
* @param varName the base name for the variable
163+
* @return the formatted variable name
164+
*/
165+
protected String getVariableName(String varName) {
166+
return String.format("%s_%s", this.name, varName);
167+
}
168+
169+
/**
170+
* The name for this metric. Defaults to {@link Class#getSimpleName()}.
171+
*
172+
* <p>Gets the name of this metric.
173+
*
174+
* @return the name of this metric
175+
*/
176+
public String getName() {
177+
return name;
178+
}
179+
180+
/**
181+
* Sets the metric name
182+
*
183+
* @param name the metric name
184+
*/
185+
public void setName(String name) {
186+
this.name = name;
187+
}
188+
189+
/**
190+
* Gets the random number generator seed value
191+
*
192+
* @return the random number generator seed value
193+
*/
194+
public long getSeed() {
195+
return seed;
196+
}
197+
198+
/**
199+
* Initialize the TensorFlow Ops
200+
*
201+
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
202+
* @throws IllegalArgumentException if the TensorFlow Ops does not have a Graph environment,
203+
*/
204+
protected abstract void init(Ops tf);
205+
206+
/**
207+
* Gets the TensorFlow Ops for this metric
208+
*
209+
* @return the TensorFlow Ops for this metric.
210+
*/
211+
protected Ops getTF() {
212+
return tf;
213+
}
214+
215+
/**
216+
* Sets the TensorFlow Ops for this metric.
217+
*
218+
* <p>This should be set from the {@link #init(Ops)} implementation.
219+
*
220+
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
221+
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment.
222+
*/
223+
protected void setTF(Ops tf) {
224+
checkIsGraph(tf);
225+
this.tf = tf;
226+
}
227+
228+
/**
229+
* Checks whether the Metric is initialized or not.
230+
*
231+
* @return true if the Metric has been initialized.
232+
*/
233+
public boolean isInitialized() {
234+
return initialized;
235+
}
236+
237+
/**
238+
* Sets the initialized indicator
239+
*
240+
* @param initialized the initialized indicator
241+
*/
242+
protected void setInitialized(boolean initialized) {
243+
this.initialized = initialized;
244+
}
245+
246+
/**
247+
* Checks if the TensorFlow Ops encapsulates a {@link Graph} environment.
248+
*
249+
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
250+
* @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph
251+
* environment.
252+
*/
253+
protected void checkIsGraph(Ops tf) {
254+
if (!tf.scope().env().isGraph()) {
255+
throw new IllegalArgumentException(
256+
"The Ops environment is not a Graph, Graph is required for metrics.");
257+
}
258+
}
259+
}

0 commit comments

Comments
 (0)