|  | 
|  | 1 | +/* | 
|  | 2 | + * Copyright 2020 The TensorFlow Authors. All rights reserved. | 
|  | 3 | + * | 
|  | 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 5 | + * you may not use this file except in compliance with the License. | 
|  | 6 | + * You may obtain a copy of the License at | 
|  | 7 | + * | 
|  | 8 | + *     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 9 | + * | 
|  | 10 | + * Unless required by applicable law or agreed to in writing, software | 
|  | 11 | + * distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 13 | + * See the License for the specific language governing permissions and | 
|  | 14 | + * limitations under the License. | 
|  | 15 | + */ | 
|  | 16 | +package org.tensorflow; | 
|  | 17 | + | 
|  | 18 | +import com.google.protobuf.InvalidProtocolBufferException; | 
|  | 19 | + | 
|  | 20 | +import java.util.List; | 
|  | 21 | +import java.util.ListIterator; | 
|  | 22 | +import java.util.HashMap; | 
|  | 23 | +import java.util.Map; | 
|  | 24 | + | 
|  | 25 | +/** | 
|  | 26 | + * Invoke <a href="https://www.tensorflow.org/api_docs/python/tf/function">tf.function</a> | 
|  | 27 | + * defined in a {@link SavedModelBundle}. | 
|  | 28 | + * | 
|  | 29 | + * <pre>{@code | 
|  | 30 | + * TfFunction myFunction = savedModelBundle.function("myFunctionSignatureName"); | 
|  | 31 | + * Map<String, Tensor<?>> outputTensorMap = myFunction.call(inputTensorMap); | 
|  | 32 | + * }</pre> | 
|  | 33 | + * | 
|  | 34 | + */ | 
|  | 35 | +public class TfFunction { | 
|  | 36 | + | 
|  | 37 | +  public TfFunction( | 
|  | 38 | +    String functionSignatureName, | 
|  | 39 | +    SavedModelBundle.SignatureToNodeName nameToNode, Session session) { | 
|  | 40 | +    this.nameToNode = nameToNode; | 
|  | 41 | +    this.session = session; | 
|  | 42 | +    this.functionSignatureName = functionSignatureName; | 
|  | 43 | +  } | 
|  | 44 | + | 
|  | 45 | +  /** | 
|  | 46 | +   * Invokes a tf.function. | 
|  | 47 | +   * Caller is responsible for closing all Tensors. | 
|  | 48 | +   * | 
|  | 49 | +   * @param arguments map of input tensors | 
|  | 50 | +   * @return map of output tensors | 
|  | 51 | +   */ | 
|  | 52 | +  public Map<String, Tensor<?>> call( | 
|  | 53 | +    Map<String, Tensor<?>> arguments) throws IllegalArgumentException { | 
|  | 54 | + | 
|  | 55 | +    Session.Runner runner = this.session.runner(); | 
|  | 56 | + | 
|  | 57 | +    Map<String, String> inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName); | 
|  | 58 | + | 
|  | 59 | +    if (inputToNode == null) { | 
|  | 60 | +      throw new IllegalArgumentException( | 
|  | 61 | +        String.format("Function [%s] is missing input", this.functionSignatureName)); | 
|  | 62 | +    } | 
|  | 63 | + | 
|  | 64 | +    // Join arguments.key, inputToNodeName.key | 
|  | 65 | +    for (Map.Entry<String, String> entry: inputToNode.entrySet()) { | 
|  | 66 | +      String argName = entry.getKey(); | 
|  | 67 | +      Tensor<?> tensor = arguments.get(argName); | 
|  | 68 | + | 
|  | 69 | +      if (tensor == null) { | 
|  | 70 | +        throw new IllegalArgumentException(String.format("Missing argument [%s]", argName)); | 
|  | 71 | +      } | 
|  | 72 | + | 
|  | 73 | +      // Node name in the tensorflow graph, corresponding to the tf.function argument | 
|  | 74 | +      runner = runner.feed(entry.getValue(), tensor); | 
|  | 75 | +    } | 
|  | 76 | + | 
|  | 77 | +    Map<String, String> outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName); | 
|  | 78 | +    if (outputToNode == null) { | 
|  | 79 | +      throw new IllegalArgumentException( | 
|  | 80 | +        String.format("Function [%] is missing output", this.functionSignatureName)); | 
|  | 81 | +    } | 
|  | 82 | + | 
|  | 83 | +    for (String nodeName: outputToNode.values()) { | 
|  | 84 | +      // Node names corresponding to the return value | 
|  | 85 | +      runner = runner.fetch(nodeName); | 
|  | 86 | +    } | 
|  | 87 | + | 
|  | 88 | +    List<Tensor<?>> resultTensors = runner.run(); | 
|  | 89 | +    ListIterator<Tensor<?>> resultTensorIter = resultTensors.listIterator(); | 
|  | 90 | + | 
|  | 91 | +    Map<String, Tensor<?>> returnMap = new HashMap<String, Tensor<?>>(); | 
|  | 92 | + | 
|  | 93 | +    // Use the output names as present in the signature definition | 
|  | 94 | +    for (String nodeName: outputToNode.keySet()) { | 
|  | 95 | +      returnMap.put(nodeName, resultTensorIter.next()); | 
|  | 96 | +    } | 
|  | 97 | + | 
|  | 98 | +    return returnMap; | 
|  | 99 | +  } | 
|  | 100 | + | 
|  | 101 | +  /** | 
|  | 102 | +   * Invokes a tf.function. | 
|  | 103 | +   * Caller is responsible for closing all Tensors. | 
|  | 104 | +   * | 
|  | 105 | +   * Throws IllegalArgumentException if there are multiple input or output parameters defined | 
|  | 106 | +   * in the tf.function | 
|  | 107 | +   * | 
|  | 108 | +   * @param tensor input tensor | 
|  | 109 | +   * @return output tensor | 
|  | 110 | +   */ | 
|  | 111 | +  public Tensor<?> call(Tensor<?> tensor) throws IllegalArgumentException { | 
|  | 112 | +    Session.Runner runner = this.session.runner(); | 
|  | 113 | + | 
|  | 114 | +    Map<String, String> inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName); | 
|  | 115 | + | 
|  | 116 | +    if (inputToNode == null) { | 
|  | 117 | +      throw new IllegalArgumentException( | 
|  | 118 | +        String.format("Function [%s] is missing input", this.functionSignatureName)); | 
|  | 119 | +    } | 
|  | 120 | + | 
|  | 121 | +    if (inputToNode.size() != 1) { | 
|  | 122 | +      throw new IllegalArgumentException( | 
|  | 123 | +        String.format("Function [%s] requires multiple inputs", this.functionSignatureName)); | 
|  | 124 | +    } | 
|  | 125 | + | 
|  | 126 | +    // Feed the single argument | 
|  | 127 | +    for (Map.Entry<String, String> entry: inputToNode.entrySet()) { | 
|  | 128 | +      // Node name in the tensorflow graph, corresponding to the tf.function argument | 
|  | 129 | +      runner = runner.feed(entry.getValue(), tensor); | 
|  | 130 | +    } | 
|  | 131 | + | 
|  | 132 | +    Map<String, String> outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName); | 
|  | 133 | +    if (outputToNode == null) { | 
|  | 134 | +      throw new IllegalArgumentException( | 
|  | 135 | +        String.format("Function [%] is missing output", this.functionSignatureName)); | 
|  | 136 | +    } | 
|  | 137 | + | 
|  | 138 | +    if (outputToNode.size() != 1) { | 
|  | 139 | +      throw new IllegalArgumentException( | 
|  | 140 | +        String.format("Function [%s] has multiple outputs", this.functionSignatureName)); | 
|  | 141 | +    } | 
|  | 142 | + | 
|  | 143 | +    // Fetch the single return tensor | 
|  | 144 | +    for (String nodeName: outputToNode.values()) { | 
|  | 145 | +      // Node names corresponding to the return value | 
|  | 146 | +      runner = runner.fetch(nodeName); | 
|  | 147 | +    } | 
|  | 148 | + | 
|  | 149 | +    List<Tensor<?>> resultTensors = runner.run(); | 
|  | 150 | + | 
|  | 151 | +    return resultTensors.get(0); | 
|  | 152 | +  } | 
|  | 153 | + | 
|  | 154 | +  private final Session session; | 
|  | 155 | +  private final SavedModelBundle.SignatureToNodeName nameToNode; | 
|  | 156 | +  private final String functionSignatureName; | 
|  | 157 | +} | 
0 commit comments