1616package org .tensorflow ;
1717
1818import static org .tensorflow .Graph .resolveOutputs ;
19- import static org .tensorflow .internal .c_api .global .tensorflow .TF_CloseSession ;
20- import static org .tensorflow .internal .c_api .global .tensorflow .TF_DeleteSession ;
21- import static org .tensorflow .internal .c_api .global .tensorflow .TF_NewSession ;
19+ import static org .tensorflow .internal .c_api .global .tensorflow .TF_OperationGetAttrType ;
2220import static org .tensorflow .internal .c_api .global .tensorflow .TF_SessionRun ;
2321import static org .tensorflow .internal .c_api .global .tensorflow .TF_SetConfig ;
2422
3836import org .tensorflow .internal .c_api .TF_SessionOptions ;
3937import org .tensorflow .internal .c_api .TF_Status ;
4038import org .tensorflow .internal .c_api .TF_Tensor ;
39+ import org .tensorflow .internal .types .registry .TensorTypeRegistry ;
4140import org .tensorflow .op .Op ;
41+ import org .tensorflow .op .Ops ;
42+ import org .tensorflow .op .core .ReadVariableOp ;
4243import org .tensorflow .proto .framework .ConfigProto ;
44+ import org .tensorflow .proto .framework .DataType ;
4345import org .tensorflow .proto .framework .RunMetadata ;
4446import org .tensorflow .proto .framework .RunOptions ;
4547import org .tensorflow .proto .util .SaverDef ;
@@ -192,6 +194,11 @@ public Runner feed(String operation, int index, Tensor t) {
192194 * @return this session runner
193195 */
194196 public Runner feed (Operand <?> operand , Tensor t ) {
197+ if (operand .env () != graph ) {
198+ throw new IllegalStateException ("Can't feed value for operand " + operand + ", it is from " +
199+ (operand .env ().isEager () ? "an eager session" : "a different graph" ) + "." );
200+ }
201+
195202 inputs .add (operand .asOutput ());
196203 inputTensors .add (t );
197204 return this ;
@@ -200,6 +207,8 @@ public Runner feed(Operand<?> operand, Tensor t) {
200207 /**
201208 * Make {@link #run()} return the output of {@code operation}.
202209 *
210+ * If the output is a resource variable, will fetch the value.
211+ *
203212 * @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code
204213 * fetch(operation, 0)}, or it is a string of the form
205214 * <tt>operation_name:output_index</tt> , in which case this method acts like {@code
@@ -215,6 +224,8 @@ public Runner fetch(String operation) {
215224 /**
216225 * Make {@link #run()} return the {@code index}-th output of {@code operation}.
217226 *
227+ * If the output is a resource variable, will fetch the value.
228+ *
218229 * <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
219230 * one to return.
220231 *
@@ -225,24 +236,61 @@ public Runner fetch(String operation) {
225236 */
226237 public Runner fetch (String operation , int index ) {
227238 Operation op = graph .operationOrThrow (operation );
228- outputs .add (op .output (index ));
229- return this ;
239+ return fetch (op .output (index ));
230240 }
231241
232242 /**
233243 * Makes {@link #run()} return the Tensor referred to by {@code output}.
234244 *
245+ * If {@code output} is a resource variable, will fetch the value.
246+ *
235247 * @param output the node to fetch the tensor from
236248 * @return this session runner
237249 */
238250 public Runner fetch (Output <?> output ) {
239- outputs .add (output );
251+ if (output .env () != graph ) {
252+ throw new IllegalStateException ("Can't fetch output " + output + ", it is from " +
253+ (output .env ().isEager () ? "an eager session" : "a different graph" ) + "." );
254+ }
255+
256+ if (output .dataType () == DataType .DT_RESOURCE ) {
257+ int [] rawDt = new int [1 ];
258+
259+ GraphOperation graphOp = (GraphOperation ) output .op ();
260+
261+ try (PointerScope scope = new PointerScope ()) {
262+ TF_Status status = TF_Status .newStatus ();
263+ TF_OperationGetAttrType (graphOp .getUnsafeNativeHandle (), "dtype" , rawDt , status );
264+ status .throwExceptionIfNotOK ();
265+ }
266+
267+ DataType valueDt = DataType .forNumber (rawDt [0 ]);
268+
269+ Operand <?> read = null ;
270+ for (GraphOperation op : graphOp .consumers ()) {
271+ if (op .dtype (0 ) == valueDt && op .type ().equals (ReadVariableOp .OP_NAME )) {
272+ read = op .output (0 );
273+ break ;
274+ }
275+ }
276+
277+ if (read == null ) {
278+ read = Ops .create (graph ).withSubScope ("session_reads" ).withName (output .op ().name () + "_read" )
279+ .readVariableOp (output , TensorTypeRegistry .find (valueDt ).type ());
280+ }
281+
282+ outputs .add (read .asOutput ());
283+ } else {
284+ outputs .add (output );
285+ }
240286 return this ;
241287 }
242288
243289 /**
244290 * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}.
245291 *
292+ * If {@code operand} is a resource variable, will fetch the value.
293+ *
246294 * @param operand the node to fetch the tensor from, as an operand
247295 * @return this session runner
248296 */
@@ -258,9 +306,7 @@ public Runner fetch(Operand<?> operand) {
258306 * @throws IllegalArgumentException if no operation exists with the provided name
259307 */
260308 public Runner addTarget (String operation ) {
261- GraphOperation op = graph .operationOrThrow (operation );
262- targets .add (op );
263- return this ;
309+ return addTarget (graph .operationOrThrow (operation ));
264310 }
265311
266312 /**
@@ -269,13 +315,12 @@ public Runner addTarget(String operation) {
269315 * @param operation the operation to execute
270316 * @return this session runner
271317 * @throws IllegalArgumentException if the operation is not a {@link GraphOperation}
318+ * @throws IllegalStateException if the operation is not from the session's graph.
272319 */
273320 public Runner addTarget (Operation operation ) {
274- if (!(operation instanceof GraphOperation )) {
275- throw new IllegalArgumentException (
276- "Operation of type "
277- + operation .getClass ().getName ()
278- + " is not supported in graph sessions" );
321+ if (operation .env () != graph ) {
322+ throw new IllegalStateException ("Can't target operation " + operation + ", it is from " +
323+ (operation .env ().isEager () ? "an eager session" : "a different graph" ) + "." );
279324 }
280325 targets .add ((GraphOperation ) operation );
281326 return this ;
@@ -594,12 +639,12 @@ private static void delete(TF_Session handle) {
594639 *
595640 * @param handle to the C API TF_Session object (Session.nativeHandle)
596641 * @param runOptions A RunOptions protocol buffer, or null
597- * @param inputOpHandles (see inputOpIndices)
598- * @param inputOpIndices (see inputTensorHandles)
599642 * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values that are being "fed"
600643 * (do not need to be computed) during graph execution. inputTensorHandles[i] (which corresponds to a
601644 * Tensor.nativeHandle) is considered to be the inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus,
602645 * it is required that inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length.
646+ * @param inputOpHandles (see inputOpIndices)
647+ * @param inputOpIndices (see inputTensorHandles)
603648 * @param outputOpHandles (see outputOpIndices)
604649 * @param outputOpIndices together with outputOpHandles identifies the set of values that should be computed. The
605650 * outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is required that outputOpHandles.length ==
0 commit comments