Skip to content

Commit c64dc5d

Browse files
committed
Draft: Java API to use tf.function available on SavedModel.
Python models that contain tf.function is inconvenient to be consumed by Java clients. This proposal provides an API to (a) Invoke a tf.function, given the signature name (b) Retrieve the node name in the graph corresponding to a tf.function
1 parent c065b70 commit c64dc5d

File tree

2 files changed

+190
-0
lines changed

2 files changed

+190
-0
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Copyright 2020 The TensorFlow Authors. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.tensorflow.framework.utils;
17+
18+
import com.google.protobuf.InvalidProtocolBufferException;
19+
20+
import java.util.HashMap;
21+
import java.util.Map;
22+
import java.util.stream.Collectors;
23+
24+
import org.tensorflow.proto.framework.MetaGraphDef;
25+
import org.tensorflow.proto.framework.SignatureDef;
26+
import org.tensorflow.SavedModelBundle;
27+
28+
public class SignatureToNodeName {
29+
30+
public SignatureToNodeName(SavedModelBundle savedModelBundle) {
31+
loadSignatures(savedModelBundle);
32+
}
33+
34+
public Map<String, String> inputNameToNode(String functionSignatureName) {
35+
NameContainer nc = this.functionMap.get(functionSignatureName);
36+
return (nc == null) ? null : nc.inputNameToNode();
37+
}
38+
39+
public Map<String, String> outputNameToNode(String functionSignatureName) {
40+
NameContainer nc = this.functionMap.get(functionSignatureName);
41+
return (nc == null) ? null : nc.outputNameToNode();
42+
}
43+
44+
public String methodName(String functionSignatureName) {
45+
NameContainer nc = this.functionMap.get(functionSignatureName);
46+
return (nc == null) ? null : nc.methodName();
47+
}
48+
49+
private void loadSignatures(SavedModelBundle savedModelBundle) {
50+
MetaGraphDef metaGraph = savedModelBundle.metaGraphDef();
51+
Map<String, SignatureDef> signatureMap = metaGraph.getSignatureDefMap();
52+
53+
// A saved model can contain multiple SignatureDef
54+
for (Map.Entry<String, SignatureDef> entry : signatureMap.entrySet()) {
55+
NameContainer nc = new NameContainer(entry.getValue());
56+
this.functionMap.put(entry.getKey(), nc);
57+
}
58+
}
59+
60+
private Map<String, NameContainer> functionMap = new HashMap<>();
61+
62+
private static final class NameContainer {
63+
NameContainer(SignatureDef sd) {
64+
this.inputNameToNodeName = sd.getInputsMap()
65+
.entrySet()
66+
.stream()
67+
.collect(Collectors.toMap(
68+
e -> e.getKey(),
69+
e -> e.getValue().getName()
70+
));
71+
72+
this.outputNameToNodeName = sd.getOutputsMap()
73+
.entrySet()
74+
.stream()
75+
.collect(Collectors.toMap(
76+
e -> e.getKey(),
77+
e -> e.getValue().getName()
78+
));
79+
80+
this.method = sd.getMethodName();
81+
}
82+
83+
public Map<String, String> inputNameToNode() {
84+
return this.inputNameToNodeName;
85+
}
86+
87+
public Map<String, String> outputNameToNode() {
88+
return this.outputNameToNodeName;
89+
}
90+
91+
public String methodName() {
92+
return this.method;
93+
}
94+
95+
private Map<String, String> inputNameToNodeName;
96+
private Map<String, String> outputNameToNodeName;
97+
private String method;
98+
}
99+
}
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Copyright 2020 The TensorFlow Authors. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.tensorflow.framework.utils;
17+
18+
import com.google.protobuf.InvalidProtocolBufferException;
19+
20+
import java.util.List;
21+
import java.util.ListIterator;
22+
import java.util.HashMap;
23+
import java.util.Map;
24+
25+
import org.tensorflow.SavedModelBundle;
26+
import org.tensorflow.Session;
27+
import org.tensorflow.Tensor;
28+
29+
public class TfFunction {
30+
31+
public TfFunction(SavedModelBundle savedModelBundle) {
32+
this.nameToNode = new SignatureToNodeName(savedModelBundle);
33+
this.session = savedModelBundle.session();
34+
}
35+
36+
/**
37+
* Caller is responsible for closing all Tensors
38+
*/
39+
public Map<String, Tensor<?>> call(
40+
String functionSignatureName,
41+
Map<String, Tensor<?>> arguments) throws IllegalArgumentException {
42+
43+
Session.Runner runner = this.session.runner();
44+
45+
Map<String, String> inputToNode = this.nameToNode.inputNameToNode(functionSignatureName);
46+
47+
if (inputToNode == null) {
48+
throw new IllegalArgumentException(
49+
String.format("Function [%s] is missing input", functionSignatureName));
50+
}
51+
52+
// Join arguments.key, inputToNodeName.key
53+
for (Map.Entry<String, String> entry: inputToNode.entrySet()) {
54+
String argName = entry.getKey();
55+
Tensor<?> tensor = arguments.get(argName);
56+
57+
if (tensor == null) {
58+
throw new IllegalArgumentException(String.format("Missing argument [%s]", argName));
59+
}
60+
61+
// Node name in the tensorflow graph, corresponding to the tf.function argument
62+
runner = runner.feed(entry.getValue(), tensor);
63+
}
64+
65+
Map<String, String> outputToNode = this.nameToNode.outputNameToNode(functionSignatureName);
66+
if (outputToNode == null) {
67+
throw new IllegalArgumentException(
68+
String.format("Function [%] is missing output", functionSignatureName));
69+
}
70+
71+
for (String nodeName: outputToNode.values()) {
72+
// Node names corresponding to the return value
73+
runner = runner.fetch(nodeName);
74+
}
75+
76+
List<Tensor<?>> resultTensors = runner.run();
77+
ListIterator<Tensor<?>> resultTensorIter = resultTensors.listIterator();
78+
79+
Map<String, Tensor<?>> returnMap = new HashMap<String, Tensor<?>>();
80+
81+
// Use the output names as present in the signature definition
82+
for (String nodeName: outputToNode.keySet()) {
83+
returnMap.put(nodeName, resultTensorIter.next());
84+
}
85+
86+
return returnMap;
87+
}
88+
89+
private Session session;
90+
private SignatureToNodeName nameToNode;
91+
}

0 commit comments

Comments
 (0)