Add inference request cancellation APIs #249

Merged
12 commits, merged Sep 7, 2023
27 changes: 26 additions & 1 deletion include/triton/core/tritonbackend.h
@@ -94,7 +94,7 @@ struct TRITONBACKEND_Batcher;
/// }
///
#define TRITONBACKEND_API_VERSION_MAJOR 1
#define TRITONBACKEND_API_VERSION_MINOR 15
#define TRITONBACKEND_API_VERSION_MINOR 16

/// Get the TRITONBACKEND API version supported by Triton. This value
/// can be compared against the TRITONBACKEND_API_VERSION_MAJOR and
@@ -375,6 +375,31 @@ TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_OutputBufferAttributes(
TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_RequestId(
TRITONBACKEND_Request* request, const char** id);

/// Query whether the request is cancelled or not.
///
/// If possible, the backend should terminate any processing and
/// send an error response with cancelled status.
///
/// \param request The inference request.
/// \param is_cancelled Returns true if the request is cancelled,
/// false otherwise.
/// \return a TRITONSERVER_Error indicating success or failure.
TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_RequestIsCancelled(
TRITONBACKEND_Request* request, bool* is_cancelled);

/// Query whether the response factory is cancelled or not.
///
/// If possible, the backend should terminate any processing and
/// send an error response with cancelled status.
///
/// \param factory The response factory.
/// \param is_cancelled Returns true if the response factory is cancelled,
/// false otherwise.
/// \return a TRITONSERVER_Error indicating success or failure.
TRITONBACKEND_DECLSPEC TRITONSERVER_Error*
TRITONBACKEND_ResponseFactoryIsCancelled(
TRITONBACKEND_ResponseFactory* factory, bool* is_cancelled);
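
A minimal sketch, not part of this PR, of how a backend's execute path might poll the new query between units of work; the step loop and the do_step callback are illustrative placeholders, and error handling is reduced to propagating the query failure:

#include <cstddef>
#include <functional>

#include "triton/core/tritonbackend.h"
#include "triton/core/tritonserver.h"

// Run 'do_step' up to 'num_steps' times, checking for cancellation between
// steps and stopping early with a CANCELLED error if the client has
// cancelled the request.
TRITONSERVER_Error*
RunWithCancellationChecks(
    TRITONBACKEND_Request* request, size_t num_steps,
    const std::function<void()>& do_step)
{
  for (size_t step = 0; step < num_steps; ++step) {
    bool is_cancelled = false;
    TRITONSERVER_Error* err =
        TRITONBACKEND_RequestIsCancelled(request, &is_cancelled);
    if (err != nullptr) {
      return err;  // failed to query the cancellation status
    }
    if (is_cancelled) {
      return TRITONSERVER_ErrorNew(
          TRITONSERVER_ERROR_CANCELLED, "request was cancelled");
    }
    do_step();
  }
  return nullptr;  // success
}

A decoupled backend that holds a detached TRITONBACKEND_ResponseFactory could poll TRITONBACKEND_ResponseFactoryIsCancelled in the same way between streamed responses.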

/// Get the correlation ID of the request if it is an unsigned integer.
/// Zero indicates that the request does not have a correlation ID.
/// Returns failure if correlation ID for given request is not an unsigned
33 changes: 31 additions & 2 deletions include/triton/core/tritonserver.h
@@ -91,7 +91,7 @@ struct TRITONSERVER_MetricFamily;
/// }
///
#define TRITONSERVER_API_VERSION_MAJOR 1
#define TRITONSERVER_API_VERSION_MINOR 24
#define TRITONSERVER_API_VERSION_MINOR 25

/// Get the TRITONBACKEND API version supported by the Triton shared
/// library. This value can be compared against the
@@ -308,7 +308,8 @@ typedef enum TRITONSERVER_errorcode_enum {
TRITONSERVER_ERROR_INVALID_ARG,
TRITONSERVER_ERROR_UNAVAILABLE,
TRITONSERVER_ERROR_UNSUPPORTED,
TRITONSERVER_ERROR_ALREADY_EXISTS
TRITONSERVER_ERROR_ALREADY_EXISTS,
TRITONSERVER_ERROR_CANCELLED
} TRITONSERVER_Error_Code;

/// Create a new error object. The caller takes ownership of the
@@ -1091,6 +1092,34 @@ TRITONSERVER_InferenceRequestSetCorrelationIdString(
struct TRITONSERVER_InferenceRequest* inference_request,
const char* correlation_id);

/// Cancel an inference request. Requests are cancelled on a
/// best-effort basis and no guarantee is provided that cancelling a
/// request will result in early termination. Note that the
/// inference request cancellation status will be reset after
/// TRITONSERVER_InferAsync is run. This means that if you cancel
/// the request before calling TRITONSERVER_InferAsync,
/// the request will not be cancelled.
///
/// \param inference_request The request object.
/// \return a TRITONSERVER_Error indicating success or failure.
TRITONSERVER_DECLSPEC struct TRITONSERVER_Error*
TRITONSERVER_InferenceRequestCancel(
struct TRITONSERVER_InferenceRequest* inference_request);

/// Query whether the request is cancelled or not.
///
/// If possible, the backend should terminate any processing and
/// send an error response with cancelled status.
///
/// \param inference_request The request object.
/// \param is_cancelled Returns true if the inference request is cancelled,
/// false otherwise.
/// \return a TRITONSERVER_Error indicating success or failure.
TRITONSERVER_DECLSPEC struct TRITONSERVER_Error*
TRITONSERVER_InferenceRequestIsCancelled(
struct TRITONSERVER_InferenceRequest* inference_request,
bool* is_cancelled);
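
A minimal sketch, not part of this PR, of how an in-process API client might use these calls; it assumes a TRITONSERVER_Server ('server') and a fully prepared TRITONSERVER_InferenceRequest ('irequest', with inputs, allocator, and response/release callbacks) were created elsewhere:

#include "triton/core/tritonserver.h"

// Submit a prepared request and then ask for best-effort cancellation.
// Because the cancellation status is reset when TRITONSERVER_InferAsync runs,
// Cancel is only called after the request has been submitted.
TRITONSERVER_Error*
SubmitThenCancel(
    TRITONSERVER_Server* server, TRITONSERVER_InferenceRequest* irequest)
{
  TRITONSERVER_Error* err =
      TRITONSERVER_ServerInferAsync(server, irequest, nullptr /* trace */);
  if (err != nullptr) {
    return err;
  }
  err = TRITONSERVER_InferenceRequestCancel(irequest);
  if (err != nullptr) {
    return err;
  }
  bool is_cancelled = false;
  err = TRITONSERVER_InferenceRequestIsCancelled(irequest, &is_cancelled);
  // Even if is_cancelled is true, the backend may still complete the request
  // normally, since cancellation is best effort.
  return err;
}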

/// Deprecated. See TRITONSERVER_InferenceRequestPriorityUInt64 instead.
///
/// Get the priority for a request. The default is 0 indicating that
21 changes: 21 additions & 0 deletions src/backend_model.cc
@@ -1017,6 +1017,16 @@ TRITONBACKEND_RequestFlags(TRITONBACKEND_Request* request, uint32_t* flags)
return nullptr; // success
}

TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONBACKEND_RequestIsCancelled(
TRITONBACKEND_Request* request, bool* is_cancelled)
{
InferenceRequest* tr = reinterpret_cast<InferenceRequest*>(request);
*is_cancelled = tr->IsCancelled();
return nullptr;
}


TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONBACKEND_RequestCorrelationIdString(
TRITONBACKEND_Request* request, const char** id)
@@ -1365,6 +1375,17 @@ TRITONBACKEND_ResponseFactorySendFlags(
return nullptr; // success
}

TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONBACKEND_ResponseFactoryIsCancelled(
TRITONBACKEND_ResponseFactory* factory, bool* is_cancelled)
{
std::shared_ptr<InferenceResponseFactory>* response_factory =
reinterpret_cast<std::shared_ptr<InferenceResponseFactory>*>(factory);
*is_cancelled = (*response_factory)->IsCancelled();
return nullptr; // success
}


///
/// TRITONBACKEND_Response
///
80 changes: 38 additions & 42 deletions src/infer_request.cc
@@ -106,86 +106,81 @@ InferenceRequest::InferenceRequest(
: needs_normalization_(true), model_raw_(model),
requested_model_version_(requested_model_version), flags_(0),
correlation_id_(0), batch_size_(0), timeout_us_(0), collect_stats_(true),
state_(InferenceRequest::State::INITIALIZED), null_request_(false),
decrement_pending_count_(false)
state_(InferenceRequest::State::INITIALIZED), null_request_(false)
{
SetPriority(0);
}

InferenceRequest::~InferenceRequest()
{
// If request has been enqueued but hasn't started executing by destruction
// time, an error occurred and the pending request count will need to be
// decremented.
DecrementPendingRequestCount();
}


Status
InferenceRequest::SetState(InferenceRequest::State new_state)
{
LOG_VERBOSE(1) << LogRequest() << "Setting state from " << state_ << " to "
<< new_state;
// No-op if this is already the current state, or if this is a null request.
if (new_state == state_ || null_request_) {
return Status::Success;
}

// Allow RELEASED state transition from any state for now.
// Not all requests will follow linear transition, such as null requests
// used for padding batches, and ensemble requests.
if (new_state == InferenceRequest::State::RELEASED) {
state_ = new_state;
return Status::Success;
}

// Generate error when called rather than copying it into every case below.
const auto generate_error = [&]() {
std::stringstream ss;
ss << LogRequest() << "Invalid request state transition from " << state_
<< " to " << new_state;
return Status(Status::Code::INVALID_ARG, ss.str());
return Status(Status::Code::INTERNAL, ss.str());
};

// Define state transitions
switch (state_) {
case InferenceRequest::State::INITIALIZED: {
if (new_state != InferenceRequest::State::STARTED) {
if (new_state == InferenceRequest::State::PENDING) {
IncrementPendingRequestCount();
} else if (new_state == InferenceRequest::State::RELEASED) {
// No-op when moving from initialized to released, just releasing early.
} else {
return generate_error();
}
state_ = new_state;
IncrementPendingRequestCount();
break;
}
case InferenceRequest::State::STARTED: {
if (new_state != InferenceRequest::State::EXECUTING) {
case InferenceRequest::State::PENDING: {
// A request may move from pending to executing when scheduled on a
// backend, or be released early due to some error.
if (new_state == InferenceRequest::State::EXECUTING ||
new_state == InferenceRequest::State::RELEASED) {
DecrementPendingRequestCount();
} else {
// Unexpected state transition
return generate_error();
}
state_ = new_state;
DecrementPendingRequestCount();
break;
}
case InferenceRequest::State::EXECUTING: {
if (new_state != InferenceRequest::State::RELEASED) {
return generate_error();
}
state_ = new_state;
break;
}
case InferenceRequest::State::RELEASED: {
// No state transition currently supported after release.
return generate_error();
if (new_state != InferenceRequest::State::INITIALIZED) {
// The only transition currently supported after release is to start over
// again, such as when re-using request objects for multiple inferences.
return generate_error();
}
break;
}
}
state_ = new_state;
return Status::Success;
}

void
InferenceRequest::IncrementPendingRequestCount()
{
#ifdef TRITON_ENABLE_METRICS
// Pending request count should always be 0 or 1 per-request. If a request
// increments the count, it should not be incremented again until decremented.
auto reporter = model_raw_->MetricReporter();
if (reporter) {
reporter->IncrementGauge(kPendingRequestMetric, 1);
decrement_pending_count_ = true;
}
#endif // TRITON_ENABLE_METRICS
}
@@ -194,13 +189,11 @@ void
InferenceRequest::DecrementPendingRequestCount()
{
#ifdef TRITON_ENABLE_METRICS
// Only decrement if count has been incremented, and not already decremented.
if (decrement_pending_count_) {
auto reporter = model_raw_->MetricReporter();
if (reporter) {
reporter->DecrementGauge(kPendingRequestMetric, 1);
}
decrement_pending_count_ = false;
// Pending request count should always be 0 or 1 per-request. A request should
// not decrement the count unless it has already been incremented.
auto reporter = model_raw_->MetricReporter();
if (reporter) {
reporter->DecrementGauge(kPendingRequestMetric, 1);
}
#endif // TRITON_ENABLE_METRICS
}
@@ -376,7 +369,7 @@ InferenceRequest::OutputBufferProperties(
Status
InferenceRequest::Run(std::unique_ptr<InferenceRequest>& request)
{
RETURN_IF_ERROR(request->SetState(InferenceRequest::State::STARTED));
RETURN_IF_ERROR(request->SetState(InferenceRequest::State::PENDING));
return request->model_raw_->Enqueue(request);
}

@@ -826,6 +819,7 @@ InferenceRequest::PrepareForInference()
// inference execution.
inputs_.clear();
override_inputs_.clear();
ResetCancel();

// Renormalize if anything has changed in the inference request in a
// way that could impact renormalization.
@@ -849,8 +843,10 @@
request_start_ns_ = 0;
#endif // TRITON_ENABLE_STATS

LOG_VERBOSE(1) << LogRequest() << "prepared: " << *this;
// Help enforce that PrepareForInference() is called prior to Run().
RETURN_IF_ERROR(SetState(InferenceRequest::State::INITIALIZED));

LOG_VERBOSE(1) << LogRequest() << "prepared: " << *this;
return Status::Success;
}

@@ -1580,8 +1576,8 @@ operator<<(std::ostream& out, const InferenceRequest::State& state)
out << "INITIALIZED";
break;
}
case InferenceRequest::State::STARTED: {
out << "STARTED";
case InferenceRequest::State::PENDING: {
out << "PENDING";
break;
}
case InferenceRequest::State::EXECUTING: {
13 changes: 7 additions & 6 deletions src/infer_request.h
@@ -63,7 +63,7 @@ class InferenceRequest {
INITIALIZED,

// The request has been enqueued, but is not yet executing.
STARTED,
PENDING,

// The request has been picked up by a backend model instance for execution,
// but hasn't been released yet.
@@ -291,7 +291,6 @@ class InferenceRequest {
const int64_t requested_model_version);

InferenceRequest(Model* model, const int64_t requested_model_version);
~InferenceRequest();

const std::string& ModelName() const;
int64_t RequestedModelVersion() const { return requested_model_version_; }
@@ -680,6 +679,11 @@ class InferenceRequest {
secondary_stats_aggregator_ = secondary_stats_aggregator;
}

void Cancel() { response_factory_->Cancel(); }
void ResetCancel() { response_factory_->ResetCancel(); }

bool IsCancelled() { return response_factory_->IsCancelled(); }

#endif // TRITON_ENABLE_STATS

private:
@@ -795,13 +799,10 @@ class InferenceRequest {
std::shared_ptr<SequenceStates> sequence_states_;

// The state of the request.
InferenceRequest::State state_;
std::atomic<InferenceRequest::State> state_;
// Whether this is a null request used for direct sequence batch padding or
// not.
bool null_request_;
// Catch-all to correctly decrement pending count if needed on destruction
// if request doesn't follow normal execution path (error, unused, ensembles)
bool decrement_pending_count_;
};

std::ostream& operator<<(std::ostream& out, const InferenceRequest& request);
9 changes: 8 additions & 1 deletion src/infer_response.h
@@ -59,10 +59,16 @@ class InferenceResponseFactory {
std::unique_ptr<InferenceResponse>&&, const uint32_t)>& delegator)
: model_(model), id_(id), allocator_(allocator),
alloc_userp_(alloc_userp), response_fn_(response_fn),
response_userp_(response_userp), response_delegator_(delegator)
response_userp_(response_userp), response_delegator_(delegator),
is_cancelled_(false)
{
}

void Cancel() { is_cancelled_ = true; }
void ResetCancel() { is_cancelled_ = false; }

bool IsCancelled() { return is_cancelled_; }

const ResponseAllocator* Allocator() { return allocator_; }
void* AllocatorUserp() { return alloc_userp_; }

@@ -118,6 +124,7 @@ class InferenceResponseFactory {
std::function<void(std::unique_ptr<InferenceResponse>&&, const uint32_t)>
response_delegator_;

std::atomic<bool> is_cancelled_;

#ifdef TRITON_ENABLE_TRACING
// Inference trace associated with this response.
6 changes: 4 additions & 2 deletions src/status.cc
@@ -48,7 +48,8 @@ TritonCodeToStatusCode(TRITONSERVER_Error_Code code)
return Status::Code::UNSUPPORTED;
case TRITONSERVER_ERROR_ALREADY_EXISTS:
return Status::Code::ALREADY_EXISTS;

case TRITONSERVER_ERROR_CANCELLED:
return Status::Code::CANCELLED;
default:
break;
}
@@ -74,7 +75,8 @@ StatusCodeToTritonCode(Status::Code status_code)
return TRITONSERVER_ERROR_UNSUPPORTED;
case Status::Code::ALREADY_EXISTS:
return TRITONSERVER_ERROR_ALREADY_EXISTS;

case Status::Code::CANCELLED:
return TRITONSERVER_ERROR_CANCELLED;
default:
break;
}