@@ -162,6 +162,9 @@ class MasterSession : public MasterSessionInterface {
162162 // nodes) are unique across all sub-graphs within this session.
163163 int64 next_node_id_ GUARDED_BY (mu_) = 0;
164164
165+ // Used to cancel running steps on Close().
166+ CancellationManager* cancellation_manager_;
167+
165168 // Private dtor. The client must call Close().
166169 virtual ~MasterSession ();
167170
@@ -219,7 +222,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
219222 int64 execution_count,
220223 SimpleGraphExecutionState* execution_state,
221224 PerStepState* pss, CallOptions* opts,
222- const RunStepRequest& req, RunStepResponse* resp);
225+ const RunStepRequest& req, RunStepResponse* resp,
226+ CancellationManager* cm);
223227
224228 // Calls workers to cleanup states for the step "step_id". Waits
225229 // till all cleanup rpcs complete.
@@ -504,7 +508,8 @@ class RunManyGraphs {
504508Status MasterSession::ReffedClientGraph::RunPartitions (
505509 const MasterEnv* env, int64 step_id, int64 execution_count,
506510 SimpleGraphExecutionState* execution_state, PerStepState* pss,
507- CallOptions* call_opts, const RunStepRequest& req, RunStepResponse* resp) {
511+ CallOptions* call_opts, const RunStepRequest& req, RunStepResponse* resp,
512+ CancellationManager* cm) {
508513 VLOG (2 ) << " RunPartitions step_id " << step_id << " execution_count "
509514 << execution_count;
510515 // Builds an index for feeds provided by the client.
@@ -560,7 +565,14 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
560565
561566 // Waits for the RunGraph calls.
562567 call_opts->SetCancelCallback ([&calls]() { calls.StartCancel (); });
568+ auto token = cm->get_cancellation_token ();
569+ bool success =
570+ cm->RegisterCallback (token, [&calls]() { calls.StartCancel (); });
571+ if (!success) {
572+ return errors::Cancelled (" Step was cancelled" );
573+ }
563574 calls.Wait ();
575+ cm->DeregisterCallback (token);
564576 call_opts->ClearCancelCallback ();
565577
566578 // Collects fetches.
@@ -696,7 +708,8 @@ MasterSession::MasterSession(const SessionOptions& opt, const MasterEnv* env,
696708 env_ (env),
697709 handle_(strings::FpToString(random::New64())),
698710 graph_version_(0 ),
699- runs_(5 ) {
711+ runs_(5 ),
712+ cancellation_manager_(new CancellationManager) {
700713 UpdateLastAccessTime ();
701714
702715 swap (remote_devs_, *remote_devs);
@@ -717,6 +730,7 @@ MasterSession::MasterSession(const SessionOptions& opt, const MasterEnv* env,
717730}
718731
719732MasterSession::~MasterSession () {
733+ delete cancellation_manager_;
720734 for (const auto & iter : runs_) iter.second ->Unref ();
721735 for (const auto & iter : obsolete_) iter.second ->Unref ();
722736 delete flib_def_;
@@ -892,8 +906,9 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
892906 const uint64 step_id = (random::New64 () & ((1uLL << 56 ) - 1 )) | (1uLL << 56 );
893907 TRACEPRINTF (" stepid %llu" , step_id);
894908
895- TF_RETURN_IF_ERROR (rcg->RunPartitions (
896- env_, step_id, count, execution_state_.get (), &pss, opts, *req, resp));
909+ TF_RETURN_IF_ERROR (rcg->RunPartitions (env_, step_id, count,
910+ execution_state_.get (), &pss, opts,
911+ *req, resp, cancellation_manager_));
897912
898913 pss.end_micros = Env::Default ()->NowMicros ();
899914
@@ -914,6 +929,7 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
914929}
915930
916931Status MasterSession::Close () {
932+ cancellation_manager_->StartCancel ();
917933 std::vector<ReffedClientGraph*> to_unref;
918934 {
919935 mutex_lock l (mu_);
0 commit comments