Extend request objects' lifetime and fix possible segmentation fault #6620

Merged: 9 commits, Nov 22, 2023
27 changes: 18 additions & 9 deletions src/grpc/infer_handler.cc
@@ -640,10 +640,11 @@ InferRequestComplete(
{
LOG_VERBOSE(1) << "ModelInferHandler::InferRequestComplete";

RequestReleasePayload* request_release_payload =
static_cast<RequestReleasePayload*>(userp);

if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
LOG_TRITONSERVER_ERROR(
TRITONSERVER_InferenceRequestDelete(request),
"deleting GRPC inference request");
delete request_release_payload;
}
}
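With this change, InferRequestComplete no longer deletes the request handle itself. Once TRITONSERVER_REQUEST_RELEASE_ALL is reported, it only deletes the RequestReleasePayload passed in as userp, which drops one shared_ptr reference; the handle is freed when the last reference goes away, typically the one held by the handler state. Below is a minimal, self-contained sketch of that mechanic, using stand-in names (DummyRequest, DummyRequestDelete, Payload, ReleaseCallback) rather than the real Triton API.

#include <iostream>
#include <memory>

// Stand-ins for the Triton request handle and its delete function; these are
// illustrative only, not real TRITONSERVER symbols.
struct DummyRequest {};
void DummyRequestDelete(DummyRequest* r)
{
  std::cout << "request handle freed\n";
  delete r;
}

// Plays the role of RequestReleasePayload: owns one reference to the request.
struct Payload {
  std::shared_ptr<DummyRequest> request;
};

// Plays the role of InferRequestComplete when RELEASE_ALL is set.
void ReleaseCallback(void* userp)
{
  delete static_cast<Payload*>(userp);  // drops the payload's reference only
}

int main()
{
  // Like state->inference_request_: the state keeps its own reference and a
  // custom deleter that frees the underlying handle.
  std::shared_ptr<DummyRequest> state_ref(new DummyRequest(), DummyRequestDelete);

  auto* payload = new Payload{state_ref};  // second reference, handed to the callback as userp
  ReleaseCallback(payload);
  std::cout << "still alive after release callback\n";

  state_ref.reset();  // last reference gone -> DummyRequestDelete runs exactly once
  return 0;
}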

@@ -861,6 +862,12 @@ ModelInferHandler::Execute(InferHandler::State* state)
}

if (err == nullptr) {
state->inference_request_ = {
irequest, [](TRITONSERVER_InferenceRequest* request) {
LOG_TRITONSERVER_ERROR(
TRITONSERVER_InferenceRequestDelete(request),
"deleting gRPC inference request");
}};
err = SetInferenceRequestMetadata(irequest, request, state->parameters_);
}

@@ -881,9 +888,13 @@ ModelInferHandler::Execute(InferHandler::State* state)
tritonserver_, shm_manager_, request, std::move(serialized_data),
response_queue, &state->alloc_payload_);
}

auto request_release_payload =
std::make_unique<RequestReleasePayload>(state->inference_request_);
if (err == nullptr) {
err = TRITONSERVER_InferenceRequestSetReleaseCallback(
irequest, InferRequestComplete, nullptr /* request_release_userp */);
irequest, InferRequestComplete,
request_release_payload.get() /* request_release_userp */);
}
if (err == nullptr) {
err = TRITONSERVER_InferenceRequestSetResponseCallback(
@@ -922,16 +933,14 @@ ModelInferHandler::Execute(InferHandler::State* state)
// COMPLETE or CANCELLED. Recording the state and the irequest
// to handle gRPC stream cancellation.
if (err == nullptr) {
state->context_->InsertInflightState(state, irequest);
state->context_->InsertInflightState(state);
// The payload will be cleaned in request release callback.
request_release_payload.release();
} else {
// If error go immediately to COMPLETE.
LOG_VERBOSE(1) << "[request id: " << request_id << "] "
<< "Infer failed: " << TRITONSERVER_ErrorMessage(err);

LOG_TRITONSERVER_ERROR(
TRITONSERVER_InferenceRequestDelete(irequest),
"deleting GRPC inference request");

::grpc::Status status;
GrpcStatusUtil::Create(&status, err);
TRITONSERVER_ErrorDelete(err);
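The Execute() changes above follow a hand-off pattern: the payload starts out owned by a std::unique_ptr, ownership is released to the future release callback only after the request is successfully recorded as in-flight, and on any error the unique_ptr destroys the payload, so the old explicit TRITONSERVER_InferenceRequestDelete on the error path is no longer needed. A hedged sketch of that flow, where Handle, Payload, TryEnqueue, and PretendReleaseCallback are placeholders rather than Triton API:

#include <memory>

struct Handle {};

struct Payload {
  std::shared_ptr<Handle> handle;  // plays the role of RequestReleasePayload
};

// Placeholder for "the request was successfully handed to Triton core".
bool TryEnqueue(Handle*) { return true; }

// Placeholder for the release callback that eventually deletes the payload.
void PretendReleaseCallback(Payload* p) { delete p; }

void ExecuteSketch(const std::shared_ptr<Handle>& state_ref)
{
  auto payload = std::make_unique<Payload>(Payload{state_ref});

  if (TryEnqueue(state_ref.get())) {
    // Success: the payload must outlive this scope because the release
    // callback deletes it later, so give up unique_ptr ownership here.
    Payload* raw = payload.release();
    PretendReleaseCallback(raw);  // in the real code Triton invokes this much later
  }
  // Error path (TryEnqueue false): `payload` is destroyed at scope exit,
  // dropping its reference; the handle stays valid as long as state_ref lives.
}

int main()
{
  auto state_ref = std::make_shared<Handle>();
  ExecuteSketch(state_ref);
  return 0;
}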
31 changes: 20 additions & 11 deletions src/grpc/infer_handler.h
@@ -88,6 +88,17 @@ class Barrier {
size_t generation_;
};

// Simple structure that carries the userp payload needed for
// request release callback.
struct RequestReleasePayload final {
explicit RequestReleasePayload(
const std::shared_ptr<TRITONSERVER_InferenceRequest>& inference_request)
: inference_request_(inference_request){};

private:
std::shared_ptr<TRITONSERVER_InferenceRequest> inference_request_ = nullptr;
};

//
// ResponseQueue
//
@@ -715,15 +726,9 @@ class InferHandlerState {
// Inserts the state to a set tracking active requests
// within the server core. Should only be called when
// the request was successfully enqueued on Triton.
void InsertInflightState(
InferHandlerStateType* state, TRITONSERVER_InferenceRequest* irequest)
void InsertInflightState(InferHandlerStateType* state)
{
std::lock_guard<std::recursive_mutex> lock(mu_);
// The irequest_ptr_ will get populated when it is
// marked as active which means the request has been
// successfully enqueued to Triton core using
// TRITONSERVER_ServerInferAsync.
state->irequest_ptr_ = irequest;
inflight_states_.insert(state);
}

@@ -748,7 +753,7 @@ class InferHandlerState {
if (state->step_ != Steps::CANCELLED &&
state->step_ != Steps::COMPLETE) {
LOG_VERBOSE(1) << "Issuing cancellation for " << state->unique_id_;
if (state->irequest_ptr_ == nullptr) {
if (state->inference_request_.get() == nullptr) {
// The context might be holding some states that have
// not been issued to Triton core. Need to skip calling
// issuing cancellation for such requests.
@@ -758,7 +763,8 @@ class InferHandlerState {
// Assuming if RequestComplete callback is run asynchronously
// before this point.
TRITONSERVER_Error* err = nullptr;
err = TRITONSERVER_InferenceRequestCancel(state->irequest_ptr_);
err = TRITONSERVER_InferenceRequestCancel(
state->inference_request_.get());
// TODO: Add request id to the message
if (err != nullptr) {
LOG_INFO << "Failed to cancel the request: "
@@ -1023,7 +1029,6 @@ class InferHandlerState {
unique_id_ = NEXT_UNIQUE_ID;
context_ = context;
step_ = start_step;
irequest_ptr_ = nullptr;
cb_count_ = 0;
is_decoupled_ = false;
complete_ = false;
@@ -1042,6 +1047,7 @@ class InferHandlerState {
void Release()
{
context_ = nullptr;
inference_request_.reset();
ClearTraceTimestamps();
}

@@ -1077,7 +1083,10 @@ class InferHandlerState {
Steps step_;
std::recursive_mutex step_mtx_;

TRITONSERVER_InferenceRequest* irequest_ptr_;
// Shared pointer to the inference request object. The lifetime of
// inference request object is extended till all the responses from
// the request are processed and the request is released.
std::shared_ptr<TRITONSERVER_InferenceRequest> inference_request_;

#ifdef TRITON_ENABLE_TRACING
std::shared_ptr<TraceManager::Trace> trace_;
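The header changes replace the raw irequest_ptr_ with a shared_ptr member, which is what makes the cancellation path safe: previously the cached raw pointer could already refer to a freed request by the time TRITONSERVER_InferenceRequestCancel ran, which appears to be the possible segmentation fault the PR title refers to. Below is a small, self-contained sketch of the before/after ownership, using illustrative stand-in types (Handle, StateWithRawPtr, StateWithSharedPtr) rather than the actual handler classes.

#include <memory>

// Stand-in for the Triton request handle; illustrative only.
struct Handle {
  bool cancelled = false;
};

// Before: the state cached a raw pointer. If the release callback freed the
// request first, this pointer could dangle and cancelling through it could crash.
struct StateWithRawPtr {
  Handle* irequest_ptr = nullptr;
};

// After: the state shares ownership, so the handle object is guaranteed to be
// alive for as long as the state (and its cancellation path) might touch it.
struct StateWithSharedPtr {
  std::shared_ptr<Handle> inference_request;

  void Cancel()
  {
    if (inference_request) {  // mirrors the nullptr check in the diff
      inference_request->cancelled = true;
    }
  }

  void Release() { inference_request.reset(); }  // mirrors State::Release()
};

int main()
{
  StateWithSharedPtr state;
  state.inference_request = std::make_shared<Handle>();

  // Even if every other reference has already been dropped by the release
  // callback, the state's own reference keeps the handle valid here.
  state.Cancel();
  state.Release();  // last reference released; handle freed exactly once
  return 0;
}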
20 changes: 14 additions & 6 deletions src/grpc/stream_infer_handler.cc
@@ -265,6 +265,12 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
}

if (err == nullptr) {
state->inference_request_ = {
irequest, [](TRITONSERVER_InferenceRequest* request) {
LOG_TRITONSERVER_ERROR(
TRITONSERVER_InferenceRequestDelete(request),
"deleting gRPC inference request");
}};
err = SetInferenceRequestMetadata(irequest, request, state->parameters_);
}

@@ -285,9 +291,13 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
tritonserver_, shm_manager_, request, std::move(serialized_data),
response_queue_, &state->alloc_payload_);
}

auto request_release_payload =
std::make_unique<RequestReleasePayload>(state->inference_request_);
if (err == nullptr) {
err = TRITONSERVER_InferenceRequestSetReleaseCallback(
irequest, InferRequestComplete, nullptr /* request_release_userp */);
irequest, InferRequestComplete,
request_release_payload.get() /* request_release_userp */);
}
if (err == nullptr) {
err = TRITONSERVER_InferenceRequestSetResponseCallback(
@@ -317,7 +327,9 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
// WRITEREADY or WRITTEN or CANCELLED. Recording the state and the
// irequest to handle gRPC stream cancellation.
if (err == nullptr) {
state->context_->InsertInflightState(state, irequest);
state->context_->InsertInflightState(state);
// The payload will be cleaned in request release callback.
request_release_payload.release();
} else {
// If there was an error then enqueue the error response and show
// it to be ready for writing.
@@ -337,10 +349,6 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
LOG_VERBOSE(1) << "[request id: " << log_request_id << "] "
<< "Infer failed: " << TRITONSERVER_ErrorMessage(err);

LOG_TRITONSERVER_ERROR(
TRITONSERVER_InferenceRequestDelete(irequest),
"deleting GRPC inference request");

::grpc::Status status;
GrpcStatusUtil::Create(&status, err);
TRITONSERVER_ErrorDelete(err);