Skip to content

Commit e876504

Browse files
mrrytensorflower-gardener
authored andcommitted
Adds support for cancellation when closing a gRPC Session.
This brings the behavior of the gRPC Session in line with the in-process session. Change: 122447672
1 parent 75b1f2f commit e876504

File tree

2 files changed

+44
-5
lines changed

2 files changed

+44
-5
lines changed

tensorflow/core/distributed_runtime/master_session.cc

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,9 @@ class MasterSession : public MasterSessionInterface {
162162
// nodes) are unique across all sub-graphs within this session.
163163
int64 next_node_id_ GUARDED_BY(mu_) = 0;
164164

165+
// Used to cancel running steps on Close().
166+
CancellationManager* cancellation_manager_;
167+
165168
// Private dtor. The client must call Close().
166169
virtual ~MasterSession();
167170

@@ -219,7 +222,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
219222
int64 execution_count,
220223
SimpleGraphExecutionState* execution_state,
221224
PerStepState* pss, CallOptions* opts,
222-
const RunStepRequest& req, RunStepResponse* resp);
225+
const RunStepRequest& req, RunStepResponse* resp,
226+
CancellationManager* cm);
223227

224228
// Calls workers to cleanup states for the step "step_id". Waits
225229
// till all cleanup rpcs complete.
@@ -504,7 +508,8 @@ class RunManyGraphs {
504508
Status MasterSession::ReffedClientGraph::RunPartitions(
505509
const MasterEnv* env, int64 step_id, int64 execution_count,
506510
SimpleGraphExecutionState* execution_state, PerStepState* pss,
507-
CallOptions* call_opts, const RunStepRequest& req, RunStepResponse* resp) {
511+
CallOptions* call_opts, const RunStepRequest& req, RunStepResponse* resp,
512+
CancellationManager* cm) {
508513
VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
509514
<< execution_count;
510515
// Builds an index for feeds provided by the client.
@@ -560,7 +565,14 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
560565

561566
// Waits for the RunGraph calls.
562567
call_opts->SetCancelCallback([&calls]() { calls.StartCancel(); });
568+
auto token = cm->get_cancellation_token();
569+
bool success =
570+
cm->RegisterCallback(token, [&calls]() { calls.StartCancel(); });
571+
if (!success) {
572+
return errors::Cancelled("Step was cancelled");
573+
}
563574
calls.Wait();
575+
cm->DeregisterCallback(token);
564576
call_opts->ClearCancelCallback();
565577

566578
// Collects fetches.
@@ -696,7 +708,8 @@ MasterSession::MasterSession(const SessionOptions& opt, const MasterEnv* env,
696708
env_(env),
697709
handle_(strings::FpToString(random::New64())),
698710
graph_version_(0),
699-
runs_(5) {
711+
runs_(5),
712+
cancellation_manager_(new CancellationManager) {
700713
UpdateLastAccessTime();
701714

702715
swap(remote_devs_, *remote_devs);
@@ -717,6 +730,7 @@ MasterSession::MasterSession(const SessionOptions& opt, const MasterEnv* env,
717730
}
718731

719732
MasterSession::~MasterSession() {
733+
delete cancellation_manager_;
720734
for (const auto& iter : runs_) iter.second->Unref();
721735
for (const auto& iter : obsolete_) iter.second->Unref();
722736
delete flib_def_;
@@ -892,8 +906,9 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
892906
const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
893907
TRACEPRINTF("stepid %llu", step_id);
894908

895-
TF_RETURN_IF_ERROR(rcg->RunPartitions(
896-
env_, step_id, count, execution_state_.get(), &pss, opts, *req, resp));
909+
TF_RETURN_IF_ERROR(rcg->RunPartitions(env_, step_id, count,
910+
execution_state_.get(), &pss, opts,
911+
*req, resp, cancellation_manager_));
897912

898913
pss.end_micros = Env::Default()->NowMicros();
899914

@@ -914,6 +929,7 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
914929
}
915930

916931
Status MasterSession::Close() {
932+
cancellation_manager_->StartCancel();
917933
std::vector<ReffedClientGraph*> to_unref;
918934
{
919935
mutex_lock l(mu_);

tensorflow/python/training/server_lib_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import time
21+
2022
import numpy as np
2123
import tensorflow as tf
2224

@@ -79,6 +81,27 @@ def testLargeFeed(self):
7981
self.assertEqual(0.5, min_val)
8082
self.assertEqual(0.5, max_val)
8183

84+
def testCloseCancelsBlockingOperation(self):
85+
server = tf.train.Server.create_local_server()
86+
sess = tf.Session(server.target)
87+
88+
q = tf.FIFOQueue(10, [tf.float32])
89+
enqueue_op = q.enqueue(37.0)
90+
dequeue_t = q.dequeue()
91+
92+
sess.run(enqueue_op)
93+
sess.run(dequeue_t)
94+
95+
def blocking_dequeue():
96+
with self.assertRaises(tf.errors.CancelledError):
97+
sess.run(dequeue_t)
98+
99+
blocking_thread = self.checkedThread(blocking_dequeue)
100+
blocking_thread.start()
101+
time.sleep(0.5)
102+
sess.close()
103+
blocking_thread.join()
104+
82105
def testInvalidHostname(self):
83106
with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, "port"):
84107
_ = tf.train.Server({"local": ["localhost"]},

0 commit comments

Comments
 (0)