|  | 
|  | 1 | +/* | 
|  | 2 | + * Copyright 2020 The TensorFlow Authors. All rights reserved. | 
|  | 3 | + * | 
|  | 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 5 | + * you may not use this file except in compliance with the License. | 
|  | 6 | + * You may obtain a copy of the License at | 
|  | 7 | + * | 
|  | 8 | + *     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 9 | + * | 
|  | 10 | + * Unless required by applicable law or agreed to in writing, software | 
|  | 11 | + * distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 13 | + * See the License for the specific language governing permissions and | 
|  | 14 | + * limitations under the License. | 
|  | 15 | + */ | 
|  | 16 | +package org.tensorflow.framework.utils; | 
|  | 17 | + | 
|  | 18 | +import com.google.protobuf.InvalidProtocolBufferException; | 
|  | 19 | + | 
|  | 20 | +import java.util.List; | 
|  | 21 | +import java.util.ListIterator; | 
|  | 22 | +import java.util.HashMap; | 
|  | 23 | +import java.util.Map; | 
|  | 24 | + | 
|  | 25 | +import org.tensorflow.SavedModelBundle; | 
|  | 26 | +import org.tensorflow.Session; | 
|  | 27 | +import org.tensorflow.Tensor; | 
|  | 28 | + | 
|  | 29 | +public class TfFunction { | 
|  | 30 | + | 
|  | 31 | +  public TfFunction(SavedModelBundle savedModelBundle) { | 
|  | 32 | +    this.nameToNode = new SignatureToNodeName(savedModelBundle); | 
|  | 33 | +    this.session = savedModelBundle.session(); | 
|  | 34 | +  } | 
|  | 35 | + | 
|  | 36 | +  /** | 
|  | 37 | +   * Caller is responsible for closing all Tensors | 
|  | 38 | +   */ | 
|  | 39 | +  public Map<String, Tensor<?>> call( | 
|  | 40 | +    String functionSignatureName, | 
|  | 41 | +    Map<String, Tensor<?>> arguments) throws IllegalArgumentException { | 
|  | 42 | + | 
|  | 43 | +    Session.Runner runner = this.session.runner(); | 
|  | 44 | + | 
|  | 45 | +    Map<String, String> inputToNode = this.nameToNode.inputNameToNode(functionSignatureName); | 
|  | 46 | + | 
|  | 47 | +    if (inputToNode == null) { | 
|  | 48 | +      throw new IllegalArgumentException( | 
|  | 49 | +        String.format("Function [%s] is missing input", functionSignatureName)); | 
|  | 50 | +    } | 
|  | 51 | + | 
|  | 52 | +    // Join arguments.key, inputToNodeName.key | 
|  | 53 | +    for (Map.Entry<String, String> entry: inputToNode.entrySet()) { | 
|  | 54 | +      String argName = entry.getKey(); | 
|  | 55 | +      Tensor<?> tensor = arguments.get(argName); | 
|  | 56 | + | 
|  | 57 | +      if (tensor == null) { | 
|  | 58 | +        throw new IllegalArgumentException(String.format("Missing argument [%s]", argName)); | 
|  | 59 | +      } | 
|  | 60 | + | 
|  | 61 | +      // Node name in the tensorflow graph, corresponding to the tf.function argument | 
|  | 62 | +      runner = runner.feed(entry.getValue(), tensor); | 
|  | 63 | +    } | 
|  | 64 | + | 
|  | 65 | +    Map<String, String> outputToNode = this.nameToNode.outputNameToNode(functionSignatureName); | 
|  | 66 | +    if (outputToNode == null) { | 
|  | 67 | +      throw new IllegalArgumentException( | 
|  | 68 | +        String.format("Function [%] is missing output", functionSignatureName)); | 
|  | 69 | +    } | 
|  | 70 | + | 
|  | 71 | +    for (String nodeName: outputToNode.values()) { | 
|  | 72 | +      // Node names corresponding to the return value | 
|  | 73 | +      runner = runner.fetch(nodeName); | 
|  | 74 | +    } | 
|  | 75 | + | 
|  | 76 | +    List<Tensor<?>> resultTensors = runner.run(); | 
|  | 77 | +    ListIterator<Tensor<?>> resultTensorIter = resultTensors.listIterator(); | 
|  | 78 | + | 
|  | 79 | +    Map<String, Tensor<?>> returnMap = new HashMap<String, Tensor<?>>(); | 
|  | 80 | + | 
|  | 81 | +    // Use the output names as present in the signature definition | 
|  | 82 | +    for (String nodeName: outputToNode.keySet()) { | 
|  | 83 | +      returnMap.put(nodeName, resultTensorIter.next()); | 
|  | 84 | +    } | 
|  | 85 | + | 
|  | 86 | +    return returnMap; | 
|  | 87 | +  } | 
|  | 88 | + | 
|  | 89 | +  private Session session; | 
|  | 90 | +  private SignatureToNodeName nameToNode; | 
|  | 91 | +} | 
0 commit comments