Save models as functions #103
FunctionGraph.java (new file, package org.tensorflow)
@@ -0,0 +1,274 @@
/*
 * Copyright 2020 The TensorFlow Authors. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.tensorflow;

import java.util.HashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.function.Function;
import org.tensorflow.op.Ops;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;

/**
 * A graph that can be invoked as a single function, with an input and output signature.
 *
 * <p>A function can also invoke a
 * <a href="https://www.tensorflow.org/api_docs/python/tf/function">tf.function</a>
 * defined in a {@link SavedModelBundle}.
 *
 * <pre>{@code
 * FunctionGraph myFunction = savedModelBundle.function("myFunctionSignatureName");
 * Map<String, Tensor<?>> outputTensorMap = myFunction.call(inputTensorMap);
 * }</pre>
 */
public class FunctionGraph implements AutoCloseable {

  /**
   * Creates a function by building a new graph.
   *
   * <p>The {@code functionBuilder} must initialize the function graph from the provided
   * {@link Ops} instance and return a valid signature that will be used to feed the input tensors
   * and fetch the output tensors on execution.
   *
   * <p>The function will be the owner of the new graph and its resulting session. Therefore,
   * the function must be enclosed properly with a try-with-resources block to guarantee that
   * all native resources will be freed once the function is discarded. For example:
   *
   * <pre>{@code
   * public class MyModel {
   *
   *   public static Signature addTwo(Ops tf) {
   *     Placeholder<TFloat32> input = tf.placeholder(TFloat32.DTYPE);
   *     Add<TFloat32> output = tf.math.add(input, tf.constant(2.0f));
   *     return Signature.builder("addTwo").input("x", input).output("y", output).build();
   *   }
   *
   *   public static void main(String[] args) {
   *     try (FunctionGraph function = FunctionGraph.create(MyModel::addTwo);
   *         Tensor<TFloat32> x = TFloat32.scalarOf(2.0f)) {
   *       assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
   *     }
   *   }
   * }
   * }</pre>
   *
   * @param functionBuilder function builder
   * @return the new function
   */
  public static FunctionGraph create(Function<Ops, Signature> functionBuilder) {
    Graph graph = new Graph();
    try {
      Ops tf = Ops.create(graph);
      Signature signature = functionBuilder.apply(tf);
      return new FunctionGraph(signature, graph, new Session(graph), Ownership.GRAPH);
    } catch (Exception e) {
      graph.close();
      throw e;
    }
  }

  /**
   * Creates a function from a signature and an existing graph.
   *
   * <p>The function will keep the ownership of the session used to run the graph but not
   * the graph itself, meaning that the lifetime of the latter can extend beyond the scope
   * of the function. For example:
   *
   * <pre>{@code
   * try (Graph g = new Graph()) {
   *   Ops tf = Ops.create(g);
   *   Placeholder<TFloat32> input = tf.placeholder(TFloat32.DTYPE);
   *   Add<TFloat32> output = tf.math.add(input, tf.constant(2.0f));
   *   Signature signature = Signature.builder("addTwo").input("x", input).output("y", output).build();
   *
   *   try (FunctionGraph f = FunctionGraph.create(signature, g);
   *       Tensor<TFloat32> x = TFloat32.scalarOf(2.0f)) {
   *     assertEquals(4.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat());
   *   }
   *   // Graph g is still valid at this point
   * }
   * }</pre>
   *
   * @param signature signature of the function to create
   * @param graph a valid and initialized graph
   * @return a new function
   */
  public static FunctionGraph create(Signature signature, Graph graph) {
    return new FunctionGraph(signature, graph, new Session(graph), Ownership.SESSION);
  }

  /**
   * Creates a function from a signature and a valid graph session.
   *
   * <p>The function will not own the session nor its graph, meaning that their lifetime
   * can extend beyond the scope of the function. Therefore, the function does not need to be
   * closed after its usage. For example:
   *
   * <pre>{@code
   * try (Graph g = new Graph()) {
   *   Ops tf = Ops.create(g);
   *   Placeholder<TFloat32> input = tf.placeholder(TFloat32.DTYPE);
   *   Add<TFloat32> output = tf.math.add(input, tf.constant(2.0f));
   *   Signature signature = Signature.builder("addTwo").input("x", input).output("y", output).build();
   *
   *   try (Session s = new Session(g)) {
   *     // Auto-closing the function just as an example but this is not required since it has
   *     // no effect
   *     try (FunctionGraph f = FunctionGraph.create(signature, s);
   *         Tensor<TFloat32> t = TFloat32.scalarOf(2.0f)) {
   *       assertEquals(4.0f, f.call(t).expect(TFloat32.DTYPE).data().getFloat());
   *     }
   *     // Session s is still valid at this point
   *   }
   *   // Graph g is still valid at this point
   * }
   * }</pre>
   *
   * @param signature signature of the function to create
   * @param session a valid session to an initialized graph
   * @return a new function
   */
  public static FunctionGraph create(Signature signature, Session session) {
    return new FunctionGraph(signature, session.graph(), session, Ownership.NONE);
  }

  /**
   * Returns the signature of this function.
   */
  public Signature signature() {
    return signature;
  }

  /**
   * Invokes a function, mapping input and output tensors by their signature names.
   *
   * <p>Caller is responsible for closing all returned Tensors.
   *
   * @param arguments map of input tensors, keyed by their name in the signature
   * @return map of output tensors, keyed by their name in the signature
   * @throws IllegalArgumentException if an input tensor required by the signature is missing
   */
  public Map<String, Tensor<?>> call(Map<String, Tensor<?>> arguments)
      throws IllegalArgumentException {

    final SignatureDef signatureDef = signature.asSignatureDef();
    final Session.Runner runner = session.runner();

    signatureDef.getInputsMap().forEach((argName, t) -> {
      Tensor<?> tensor = arguments.get(argName);
      if (tensor == null) {
        throw new IllegalArgumentException(String.format("Missing argument [%s]", argName));
      }
      runner.feed(t.getName(), tensor);
    });

    Map<String, TensorInfo> outputToNode = signatureDef.getOutputsMap();
    outputToNode.values().forEach(t -> runner.fetch(t.getName()));

    List<Tensor<?>> resultTensors = runner.run();
    try {
      ListIterator<Tensor<?>> resultTensorIter = resultTensors.listIterator();
      Map<String, Tensor<?>> returnMap = new HashMap<>();

      // Use the output names as present in the signature definition
      for (String nodeName : outputToNode.keySet()) {
        returnMap.put(nodeName, resultTensorIter.next());
      }
      return returnMap;

    } catch (Exception e) {
      // Release tensors before throwing exception
      for (Tensor<?> t : resultTensors) {
        t.close();
      }
      throw e;
    }
  }

  /**
   * Invokes a function with a single input and output.
   *
   * <p>Caller is responsible for closing the returned Tensor.
   *
   * @param tensor input tensor
   * @return output tensor
   * @throws IllegalArgumentException if the function does not define exactly one input
   *     and one output parameter
   */
  public Tensor<?> call(Tensor<?> tensor) throws IllegalArgumentException {
    final SignatureDef signatureDef = signature.asSignatureDef();

    if (signatureDef.getInputsCount() != 1) {
      throw new IllegalArgumentException(
          String.format("Function [%s] requires exactly one input but has %d",
              signatureDef.getMethodName(), signatureDef.getInputsCount()));
    }
    String inputNodeName = signatureDef.getInputsMap().values().iterator().next().getName();

    if (signatureDef.getOutputsCount() != 1) {
      throw new IllegalArgumentException(
          String.format("Function [%s] requires exactly one output but has %d",
              signatureDef.getMethodName(), signatureDef.getOutputsCount()));
    }
    String outputNodeName = signatureDef.getOutputsMap().values().iterator().next().getName();

    return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0);
  }

  /**
   * Returns the session used to execute the graph when calling this function.
   *
   * <p>In general, a user does not need to access the session of a function directly and can
   * rely on {@link #call(Map)} to execute the graph instead. But in some cases, direct access
   * to the session might be necessary, as it gives access to additional run options.
   *
   * @return the function session
   */
  public Session session() {
    return session;
  }

  /**
   * Returns the graph of this function.
   */
  public Graph graph() {
    return graph;
  }

  @Override
  public void close() {
    if (ownership != Ownership.NONE) {
      session.close();
      if (ownership == Ownership.GRAPH) {
        graph.close();
      }
    }
  }

  private enum Ownership {
    GRAPH, SESSION, NONE
  }

  private final Graph graph;
  private final Session session;
  private final Signature signature;
  private final Ownership ownership;

  FunctionGraph(Signature signature, Graph graph, Session session, Ownership ownership) {
    this.graph = graph;
    this.session = session;
    this.signature = signature;
    this.ownership = ownership;
  }
}
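For reference, a minimal sketch (not part of the diff) of driving this class through the map-based `call(Map)`: the `addTwo` signature and the `"x"`/`"y"` names come from the Javadoc examples above, and `MyModel` is the hypothetical holder class from those same examples.

```java
import java.util.HashMap;
import java.util.Map;
import org.tensorflow.FunctionGraph;
import org.tensorflow.Tensor;
import org.tensorflow.types.TFloat32;

public class MyModelMain {

  public static void main(String[] args) {
    // The function owns its graph and session, so close it with try-with-resources
    try (FunctionGraph function = FunctionGraph.create(MyModel::addTwo);
        Tensor<TFloat32> x = TFloat32.scalarOf(2.0f)) {
      Map<String, Tensor<?>> inputs = new HashMap<>();
      inputs.put("x", x); // keyed by the input name declared in the signature
      Map<String, Tensor<?>> outputs = function.call(inputs);
      try {
        float y = outputs.get("y").expect(TFloat32.DTYPE).data().getFloat();
        System.out.println(y); // 4.0
      } finally {
        outputs.values().forEach(Tensor::close); // the caller owns the result tensors
      }
    }
  }
}
```

When more control over execution is needed, `function.session()` exposes the underlying `Session` and its `Runner` for direct use.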
Graph.java (package org.tensorflow)
@@ -43,8 +43,17 @@
import org.tensorflow.internal.c_api.TF_Output;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.TF_WhileParams;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.NoOp;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.train.Restore;
import org.tensorflow.op.train.Save;
import org.tensorflow.proto.framework.GraphDef;
import org.tensorflow.proto.util.SaverDef;
import org.tensorflow.types.TString;

/**
@@ -67,6 +76,11 @@ public Graph() {
    this.nativeHandle = nativeHandle;
  }

  Graph(TF_Graph nativeHandle, SaverDef saverDef) {
    this(nativeHandle);
    this.saverDef = saverDef;
  }

  /**
   * Release resources associated with the Graph.
   *
@@ -287,6 +301,17 @@ public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
    return addGradients(null, new Output<?>[] {y}, x, null);
  }

  /**
   * Returns a {@link SaverDef} for saving and restoring the variables of this graph, creating
   * the required save/restore ops on first call (double-checked locking keeps this thread-safe).
   */
  public SaverDef saverDef() {
    if (saverDef == null) {
      synchronized (this) {
        if (saverDef == null) {
          saverDef = addVariableSaver(this);
        }
      }
    }
    return saverDef;
  }

  /**
   * Used to instantiate an abstract class which overrides the buildSubgraph method to build a
   * conditional or body subgraph for a while loop. After Java 8, this can alternatively be used to
@@ -405,6 +430,7 @@ public Output<?>[] whileLoop(
  private final Object nativeHandleLock = new Object();
  private TF_Graph nativeHandle;
  private int refcount = 0;
  private SaverDef saverDef;

  private final List<Op> initializers = new ArrayList<>();

@@ -726,6 +752,53 @@ private static Object[] whileLoop(
    }
  }

  private static SaverDef addVariableSaver(Graph graph) {
    Ops tf = Ops.create(graph).withSubScope("save");

    List<String> varNames = new ArrayList<>();
    List<Operand<?>> varOutputs = new ArrayList<>();
    List<DataType<?>> varTypes = new ArrayList<>();

    // Collect the name, output and type of every variable found in the graph
    for (Iterator<Operation> iter = graph.operations(); iter.hasNext();) {
      Operation op = iter.next();
      if (op.type().equals("VariableV2")) {
        varNames.add(op.name());
        varOutputs.add(op.output(0));
        varTypes.add(op.output(0).dataType());
      }
    }

    // FIXME Need an easier way to initialize an NdArray from a list
    String[] tmp = new String[varNames.size()];
    Constant<TString> varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp)));
    Operand<TString> varSlices = tf.zerosLike(varNamesTensor);

    Placeholder<TString> saveFilename = tf.placeholder(TString.DTYPE);
    Save saveVariables = tf.train.save(
        saveFilename,
        varNamesTensor,
        varSlices,
        varOutputs
    );
    Restore restoreVariables = tf.train.restore(
        saveFilename,
        varNamesTensor,
        varSlices,
        varTypes
    );

    // Assign each restored tensor back to its variable, then group all assignments
    // under a single NoOp so they can be triggered together
    List<Op> restoreOps = new ArrayList<>(varOutputs.size());
    for (int i = 0; i < varOutputs.size(); ++i) {
      restoreOps.add(tf.assign(varOutputs.get(i), (Operand) restoreVariables.tensors().get(i)));
    }
    NoOp restoreAll = tf.withControlDependencies(restoreOps).noOp();

    return SaverDef.newBuilder()
        .setFilenameTensorName(saveFilename.op().name())
        .setSaveTensorName(saveVariables.op().name())
        .setRestoreOpName(restoreAll.op().name())
        .build();
  }

  static {
    TensorFlow.init();
  }
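As a rough illustration (not part of the diff) of how the generated `SaverDef` could be consumed: the checkpoint prefix path below is hypothetical, the accessors are the standard generated proto getters, and since this PR records the save entry as an op name, it is run as a target rather than fetched.

```java
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.proto.util.SaverDef;
import org.tensorflow.types.TString;

public class SaverDefSketch {

  static void saveAndRestore(Graph graph, Session session) {
    SaverDef saverDef = graph.saverDef();
    // The filename tensor is a string scalar placeholder; feed it the checkpoint prefix
    try (Tensor<TString> prefix = TString.scalarOf("/tmp/model/ckpt")) {
      // Run the save op recorded in the SaverDef to write all variables
      session.runner()
          .feed(saverDef.getFilenameTensorName(), prefix)
          .addTarget(saverDef.getSaveTensorName())
          .run();
      // Run the restore NoOp, which depends on all the Assign ops, to read them back
      session.runner()
          .feed(saverDef.getFilenameTensorName(), prefix)
          .addTarget(saverDef.getRestoreOpName())
          .run();
    }
  }
}
```

Storing only the three op/tensor names keeps the `SaverDef` independent of the Java API: any runtime that can execute the graph can save or restore it.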
Maybe `FunctionSession`, as it's a specialised session? That way it conceptually lives next to `EagerSession` and `Session`, rather than next to `Graph`. I feel like this is much closer to a session than a graph.
Right now, it is true that all functions loaded from a saved model share the same graph and are just used to execute it with a given signature (therefore they are acting more as a session). Nonetheless, I still have a preference for preserving the `FunctionGraph` naming. If we ignore the saving part (which is limited due to the actual state of the C API), conceptually these functions also allow you to build your graphs, replacing the need to explicitly allocate a `Graph` instance, e.g. the sketch below.
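A minimal sketch of that builder-style usage, reusing the `addTwo`-style signature from the Javadoc examples above (names here are illustrative only):

```java
try (FunctionGraph f = FunctionGraph.create(tf -> {
      Placeholder<TFloat32> x = tf.placeholder(TFloat32.DTYPE);
      Add<TFloat32> y = tf.math.add(x, tf.constant(2.0f));
      return Signature.builder("addTwo").input("x", x).output("y", y).build();
    });
    Tensor<TFloat32> two = TFloat32.scalarOf(2.0f);
    Tensor<?> result = f.call(two)) {
  // No explicit Graph allocation: the function builds and owns its own graph
}
```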
Each function has its own graph and is therefore very coupled with that concept (where `EagerSession` has no graph at all). They will also appear as separate graphs when exported into a saved model, the same way Python does it (in fact, Python saves each function as one or more "objects", which are then linked to their distinct graph in the function library).
Another idea would be to call this class a `ConcreteFunction`. Concrete functions and functions are two very similar but distinct concepts in TF Python. The former is a typed realization of a function graph while the latter is its polymorphic version, only acting as a facade (e.g. a function is eventually composed of multiple concrete functions, one for each type of operands the function has been called with). In our scenario here, the function graphs are strongly typed and refer to a single graph, so they behave more like a concrete function in the Python paradigm, and it would make sense to name them after it.

We can probably support polymorphic functions too in the future, where we pass the types of the input tensors as parameters to the function builder. So we will still need to find a proper name for that concept, which will probably result in a class encapsulating one or more `ConcreteFunction`s. Maybe simply `PolymorphicFunction`?
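A purely hypothetical sketch of what such a wrapper could look like; none of these names or signatures exist in the codebase, and `ConcreteFunction` stands for the class discussed above:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import org.tensorflow.DataType;
import org.tensorflow.Tensor;

// Hypothetical: dispatches to one ConcreteFunction per input tensor type
public class PolymorphicFunction implements AutoCloseable {

  private final Map<DataType<?>, ConcreteFunction> instances = new HashMap<>();
  private final Function<DataType<?>, ConcreteFunction> builder;

  public PolymorphicFunction(Function<DataType<?>, ConcreteFunction> builder) {
    this.builder = builder;
  }

  public Tensor<?> call(Tensor<?> input) {
    // Build and cache a concrete (typed) realization on first call with a new input type
    return instances.computeIfAbsent(input.dataType(), builder).call(input);
  }

  @Override
  public void close() {
    instances.values().forEach(ConcreteFunction::close);
  }
}
```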
My preference for a non-graph name is that `Graph`s aren't executable, and are missing necessary state because the `Variable`s live in the `Session`. So `FunctionGraph` is more like a `Session` because it contains all the necessary bits to execute. However, I'm also fine with the `ConcreteFunction` name as that's just a new concept that people have to learn.