Skip to content

Commit 748435b

Browse files
authored
Fixed the issue that each invocation of model.fit/evaluate/predict modifies the (tensorflow#23280)
graph. PiperOrigin-RevId: 218793646
1 parent f90c214 commit 748435b

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

tensorflow/contrib/tpu/python/tpu/keras_support.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,25 @@
9797

9898
# TODO(b/114775106): temporary shim to optionally initialize the TPU
9999
# This increases the odds our session is initialized, but shouldn't be needed.
100+
_TEST_REWRITE_OP = None
101+
102+
100103
def _maybe_initialize_tpu(session):
101104
"""Initialize the TPU if it has not already been initialized."""
105+
global _TEST_REWRITE_OP
102106
try:
107+
# Try to use cached version to avoid another ground of graph optimization.
108+
test_rewrite_op = _TEST_REWRITE_OP
109+
if (test_rewrite_op is None or
110+
test_rewrite_op[0].graph != ops.get_default_graph()):
111+
112+
def test_op():
113+
return constant_op.constant(1) + constant_op.constant(1)
103114

104-
def test_op():
105-
return constant_op.constant(1) + constant_op.constant(1)
115+
test_rewrite_op = tpu.rewrite(test_op)
116+
_TEST_REWRITE_OP = test_rewrite_op
106117

107-
session.run(tpu.rewrite(test_op))
118+
session.run(test_rewrite_op)
108119
except errors.FailedPreconditionError as _:
109120
session.run(tpu.initialize_system())
110121

0 commit comments

Comments
 (0)