Support for HTTP generate request cancellation and segfault fix #6591

Merged 22 commits on Nov 21, 2023
Changes from 13 commits
12 changes: 12 additions & 0 deletions qa/L0_http/generate_endpoint_test.py
@@ -356,6 +356,18 @@ def test_complex_schema(self):
             "attempt to access JSON non-string as string", r.json()["error"]
         )
 
+    def test_close_connection_during_streaming(self):
+        # verify the responses are streamed as soon as they are generated
+        text = "hello world"
+        rep_count = 3
+        inputs = {"PROMPT": [text], "STREAM": True, "REPETITION": rep_count, "DELAY": 2}
+        res = self.generate_stream(self._model_name, inputs, stream=True)
+        # close connection while the responses are being generated
+        res.close()
+        # check server healthiness
+        health_url = "http://localhost:8000/v2/health/live"
+        requests.get(health_url).raise_for_status()
+
     def test_parameters(self):
         # Test reserved nested object for parameters
         text = "hello world"
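For reference, what this new test exercises can be sketched with a plain HTTP client (a sketch assuming a local server on port 8000 serving the mock_llm model; the endpoint path follows Triton's generate extension, while the test itself goes through a generate_stream() helper):

import requests

url = "http://localhost:8000/v2/models/mock_llm/generate_stream"
payload = {"PROMPT": "hello world", "STREAM": True, "REPETITION": 3, "DELAY": 2}

# stream=True keeps the socket open so chunks arrive as they are produced
res = requests.post(url, json=payload, stream=True)

# Drop the connection while the model is still generating responses
res.close()

# Before this fix the server could segfault here; now the in-flight request
# is cancelled and the liveness probe still succeeds
requests.get("http://localhost:8000/v2/health/live").raise_for_status()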
5 changes: 4 additions & 1 deletion qa/L0_http/generate_models/mock_llm/1/model.py
@@ -87,7 +87,10 @@ def exec_decoupled(self, requests):
             for _ in range(rep_count):
                 if delay is not None:
                     time.sleep(delay)
-                sender.send(response)
+                if not sender.is_cancelled():
+                    sender.send(response)
+                else:
+                    break
             sender.send(
                 None
                 if not fail_last
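For context, the cancellation check above sits inside a decoupled-model response loop. A minimal sketch of such a model (not the full mock_llm source; the output tensor name and dtype are illustrative):

import time

import numpy as np
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def execute(self, requests):
        for request in requests:
            sender = request.get_response_sender()
            for _ in range(3):  # REPETITION
                time.sleep(2)  # DELAY, simulates slow generation
                # Stop generating once the server has cancelled the request,
                # e.g. because the HTTP client closed the connection
                if sender.is_cancelled():
                    break
                out = pb_utils.Tensor("TEXT", np.array(["hello"], dtype=np.object_))
                sender.send(pb_utils.InferenceResponse(output_tensors=[out]))
            # Always terminate the stream with the FINAL flag
            sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
        return None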
2 changes: 1 addition & 1 deletion qa/L0_http/test.sh
@@ -651,7 +651,7 @@ fi
 ## Python Unit Tests
 TEST_RESULT_FILE='test_results.txt'
 PYTHON_TEST=generate_endpoint_test.py
-EXPECTED_NUM_TESTS=13
+EXPECTED_NUM_TESTS=14
 set +e
 python $PYTHON_TEST >$CLIENT_LOG 2>&1
 if [ $? -ne 0 ]; then
107 changes: 86 additions & 21 deletions src/http_server.cc
@@ -3509,7 +3509,7 @@ HTTPAPIServer::HandleInfer(
   // HTTP request paused when creating inference request. Resume it on exit if
   // this function returns early due to error. Otherwise resumed in callback.
   bool connection_paused = true;
-  auto infer_request = CreateInferRequest(req);
+  auto infer_request = CreateInferRequest(req, irequest);
   infer_request->trace_ = trace;
 
   const char* request_id = "<id_unknown>";
@@ -3585,14 +3585,22 @@
 }
 
 void
-HTTPAPIServer::OKReplyCallback(evthr_t* thr, void* arg, void* shared)
+HTTPAPIServer::InferRequestClass::OKReplyCallback(
+    evthr_t* thr, void* arg, void* shared)
 {
   HTTPAPIServer::InferRequestClass* infer_request =
       reinterpret_cast<HTTPAPIServer::InferRequestClass*>(arg);
 
   evhtp_request_t* request = infer_request->EvHtpRequest();
-  evhtp_send_reply(request, EVHTP_RES_OK);
-  evhtp_request_resume(request);
+
+  if (InferRequestClass::active_requests_.count(request) == 0) {
+    if (infer_request->triton_request_ != nullptr) {
+      TRITONSERVER_InferenceRequestCancel(infer_request->triton_request_);
+    }
+  } else {
+    evhtp_send_reply(request, EVHTP_RES_OK);
+    evhtp_request_resume(request);
+  }
 
 #ifdef TRITON_ENABLE_TRACING
   if (infer_request->trace_ != nullptr) {
@@ -3607,14 +3615,22 @@ HTTPAPIServer::OKReplyCallback(evthr_t* thr, void* arg, void* shared)
 }
 
 void
-HTTPAPIServer::BADReplyCallback(evthr_t* thr, void* arg, void* shared)
+HTTPAPIServer::InferRequestClass::BADReplyCallback(
+    evthr_t* thr, void* arg, void* shared)
 {
   HTTPAPIServer::InferRequestClass* infer_request =
       reinterpret_cast<HTTPAPIServer::InferRequestClass*>(arg);
 
   evhtp_request_t* request = infer_request->EvHtpRequest();
-  evhtp_send_reply(request, EVHTP_RES_BADREQ);
-  evhtp_request_resume(request);
+
+  if (InferRequestClass::active_requests_.count(request) == 0) {
+    if (infer_request->triton_request_ != nullptr) {
+      TRITONSERVER_InferenceRequestCancel(infer_request->triton_request_);
+    }
+  } else {
+    evhtp_send_reply(request, EVHTP_RES_BADREQ);
+    evhtp_request_resume(request);
+  }
 
 #ifdef TRITON_ENABLE_TRACING
   if (infer_request->trace_ != nullptr) {
@@ -3628,15 +3644,32 @@ HTTPAPIServer::BADReplyCallback(evthr_t* thr, void* arg, void* shared)
   delete infer_request;
 }
 
+std::unordered_set<evhtp_request*>
+    HTTPAPIServer::InferRequestClass::active_requests_ = {};
+
+evhtp_res
+HTTPAPIServer::InferRequestClass::RequestFiniHook(
+    evhtp_request* request, void* arg)
+{
+  InferRequestClass::active_requests_.erase(request);
+  return EVHTP_RES_OK;
+}
+
 HTTPAPIServer::InferRequestClass::InferRequestClass(
     TRITONSERVER_Server* server, evhtp_request_t* req,
-    DataCompressor::Type response_compression_type)
+    DataCompressor::Type response_compression_type,
+    TRITONSERVER_InferenceRequest* triton_request)
     : server_(server), req_(req),
-      response_compression_type_(response_compression_type), response_count_(0)
+      response_compression_type_(response_compression_type), response_count_(0),
+      triton_request_(triton_request)
 {
   evhtp_connection_t* htpconn = evhtp_request_get_connection(req);
   thread_ = htpconn->thread;
   evhtp_request_pause(req);
+  evhtp_request_set_hook(
+      req, evhtp_hook_on_request_fini, (evhtp_hook)(void*)RequestFiniHook,
+      reinterpret_cast<void*>(this));
+  InferRequestClass::active_requests_.emplace(req);
 }
 
 void
@@ -3702,11 +3735,15 @@ HTTPAPIServer::InferRequestClass::InferResponseComplete(
 #endif // TRITON_ENABLE_TRACING
 
   if (err == nullptr) {
-    evthr_defer(infer_request->thread_, OKReplyCallback, infer_request);
+    evthr_defer(
+        infer_request->thread_, InferRequestClass::OKReplyCallback,
+        infer_request);
   } else {
     EVBufferAddErrorJson(infer_request->req_->buffer_out, err);
     TRITONSERVER_ErrorDelete(err);
-    evthr_defer(infer_request->thread_, BADReplyCallback, infer_request);
+    evthr_defer(
+        infer_request->thread_, InferRequestClass::BADReplyCallback,
+        infer_request);
   }
 
   LOG_TRITONSERVER_ERROR(
@@ -4054,8 +4091,9 @@ HTTPAPIServer::GenerateRequestClass::InferResponseComplete(
   // First response starts the chunked response, the response code is set here
   // so user should check response body in case of error at later time.
   if (infer_request->IncrementResponseCount() == 0) {
-    infer_request->StartResponse(
-        (err == nullptr) ? EVHTP_RES_OK : EVHTP_RES_BADREQ);
+    infer_request->response_code_ =
+        (err == nullptr) ? EVHTP_RES_OK : EVHTP_RES_BADREQ;
+    evthr_defer(infer_request->thread_, StartResponse, infer_request);
   }
 
 #ifdef TRITON_ENABLE_TRACING
@@ -4078,15 +4116,27 @@
 }
 
 void
-HTTPAPIServer::GenerateRequestClass::StartResponse(evhtp_res code)
+HTTPAPIServer::GenerateRequestClass::StartResponse(
+    evthr_t* thr, void* arg, void* shared)
 {
-  if (streaming_) {
-    AddContentTypeHeader(req_, "text/event-stream; charset=utf-8");
+  auto infer_request =
+      reinterpret_cast<HTTPAPIServer::GenerateRequestClass*>(arg);
+  auto req = infer_request->EvHtpRequest();
+
+  if (InferRequestClass::active_requests_.count(req) == 0) {
+    if (infer_request->triton_request_ != nullptr) {
+      TRITONSERVER_InferenceRequestCancel(infer_request->triton_request_);
+    }
+    return;
+  }
+
+  if (infer_request->streaming_) {
+    AddContentTypeHeader(req, "text/event-stream; charset=utf-8");
   } else {
-    AddContentTypeHeader(req_, "application/json");
+    AddContentTypeHeader(req, "application/json");
   }
-  evhtp_send_reply_chunk_start(req_, code);
-  evhtp_request_resume(req_);
+  evhtp_send_reply_chunk_start(req, infer_request->response_code_);
+  evhtp_request_resume(req);
 }
 
 void
@@ -4095,6 +4145,14 @@ HTTPAPIServer::GenerateRequestClass::ChunkResponseCallback(
 {
   auto infer_request =
       reinterpret_cast<HTTPAPIServer::GenerateRequestClass*>(arg);
+
+  if (InferRequestClass::active_requests_.count(infer_request->req_) == 0) {
+    if (infer_request->triton_request_ != nullptr) {
+      TRITONSERVER_InferenceRequestCancel(infer_request->triton_request_);
+    }
+    return;
+  }
+
   infer_request->SendChunkResponse(false /* end */);
 }

@@ -4105,8 +4163,15 @@ HTTPAPIServer::GenerateRequestClass::EndResponseCallback(
   auto infer_request =
       reinterpret_cast<HTTPAPIServer::GenerateRequestClass*>(arg);
 
-  infer_request->SendChunkResponse(true /* end */);
-  evhtp_send_reply_chunk_end(infer_request->EvHtpRequest());
+  if (InferRequestClass::active_requests_.count(infer_request->req_) == 0) {
+    if (infer_request->triton_request_ != nullptr) {
+      TRITONSERVER_InferenceRequestCancel(infer_request->triton_request_);
+    }
+  } else {
+    infer_request->SendChunkResponse(true /* end */);
+    evhtp_send_reply_chunk_end(infer_request->EvHtpRequest());
+  }
 
   delete infer_request;
 }
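The crux of the segfault fix is the active_requests_ bookkeeping visible above: every deferred reply path first checks that the evhtp request is still alive before touching it, and cancels the Triton request otherwise. The underlying libevhtp pattern, reduced to a standalone sketch (names here are illustrative, and error handling plus the actual reply logic are elided):

#include <evhtp/evhtp.h>
#include <unordered_set>

// Requests whose connections are still open. Manipulated only on the
// event thread, so no lock is needed (mirroring the http_server.h comment).
static std::unordered_set<evhtp_request_t*> live_requests;

static evhtp_res
OnRequestFini(evhtp_request_t* req, void* arg)
{
  // libevhtp fires this hook when the request is torn down, e.g. because
  // the client closed the connection mid-stream.
  live_requests.erase(req);
  return EVHTP_RES_OK;
}

static void
Handle(evhtp_request_t* req, void* arg)
{
  live_requests.emplace(req);
  evhtp_request_set_hook(
      req, evhtp_hook_on_request_fini, (evhtp_hook)(void*)OnRequestFini,
      nullptr);
  // Kick off asynchronous work here. Any deferred callback must check
  // live_requests.count(req) before evhtp_send_reply()/evhtp_request_resume(),
  // and instead cancel the backend work when the request is gone.
}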
40 changes: 26 additions & 14 deletions src/http_server.h
@@ -35,6 +35,7 @@
 #include <string>
 #include <thread>
 #include <unordered_map>
+#include <unordered_set>
 
 #include "common.h"
 #include "data_compressor.h"
@@ -222,7 +223,8 @@ class HTTPAPIServer : public HTTPServer {
     // buffer in HTTPServer code.
     explicit InferRequestClass(
         TRITONSERVER_Server* server, evhtp_request_t* req,
-        DataCompressor::Type response_compression_type);
+        DataCompressor::Type response_compression_type,
+        TRITONSERVER_InferenceRequest* triton_request);
     virtual ~InferRequestClass() = default;
 
     evhtp_request_t* EvHtpRequest() const { return req_; }
@@ -253,6 +255,9 @@ class HTTPAPIServer : public HTTPServer {
     // lifetime of the request.
     std::list<std::vector<char>> serialized_data_;
 
+    static void OKReplyCallback(evthr_t* thr, void* arg, void* shared);
+    static void BADReplyCallback(evthr_t* thr, void* arg, void* shared);
+
    protected:
     TRITONSERVER_Server* server_;
     evhtp_request_t* req_;
@@ -262,6 +267,18 @@ class HTTPAPIServer : public HTTPServer {
 
     // Counter to keep track of number of responses generated.
     std::atomic<uint32_t> response_count_;
+
+    // Event hook called before request deletion
+    static evhtp_res RequestFiniHook(evhtp_request* req, void* arg);
+
+    // Set to keep track of active requests.
+    // To maintain thread safety, must only be manipulated on event thread
+    static std::unordered_set<evhtp_request*> active_requests_;
+
+    // Pointer to associated Triton request, this class does not own the
+    // request and must not reference it after a successful
+    // TRITONSERVER_ServerInferAsync (except for cancellation).
+    TRITONSERVER_InferenceRequest* triton_request_{nullptr};
   };
 
   class GenerateRequestClass : public InferRequestClass {
@@ -272,9 +289,10 @@ class HTTPAPIServer : public HTTPServer {
         const MappingSchema* request_schema,
         const MappingSchema* response_schema, bool streaming,
         TRITONSERVER_InferenceRequest* triton_request)
-        : InferRequestClass(server, req, response_compression_type),
+        : InferRequestClass(
+              server, req, response_compression_type, triton_request),
           request_schema_(request_schema), response_schema_(response_schema),
-          streaming_(streaming), triton_request_(triton_request)
+          streaming_(streaming)
     {
     }
     virtual ~GenerateRequestClass();
@@ -293,7 +311,7 @@ class HTTPAPIServer : public HTTPServer {
     TRITONSERVER_Error* FinalizeResponse(
         TRITONSERVER_InferenceResponse* response) override;
     void AddErrorJson(TRITONSERVER_Error* error);
-    void StartResponse(evhtp_res code);
+    static void StartResponse(evthr_t* thr, void* arg, void* shared);
 
     // [DLIS-5551] currently always performs basic conversion, only maps schema
     // of EXACT_MAPPING kind. MAPPING_SCHEMA and upcoming kinds are for
@@ -339,10 +357,6 @@ class HTTPAPIServer : public HTTPServer {
     const MappingSchema* request_schema_{nullptr};
     const MappingSchema* response_schema_{nullptr};
     const bool streaming_{false};
-    // Pointer to associated Triton request, this class does not own the
-    // request and must not reference it after a successful
-    // TRITONSERVER_ServerInferAsync.
-    TRITONSERVER_InferenceRequest* triton_request_{nullptr};
     // Placeholder to completing response, this class does not own
     // the response.
     TRITONSERVER_InferenceResponse* triton_response_{nullptr};
@@ -352,6 +366,8 @@ class HTTPAPIServer : public HTTPServer {
     std::mutex res_mtx_;
     std::queue<evbuffer*> pending_http_responses_;
     bool end_{false};
+    // starting response code
+    evhtp_res response_code_;
   };
 
  protected:
@@ -366,10 +382,10 @@ class HTTPAPIServer : public HTTPServer {
   virtual void Handle(evhtp_request_t* req) override;
   // [FIXME] extract to "infer" class
   virtual std::unique_ptr<InferRequestClass> CreateInferRequest(
-      evhtp_request_t* req)
+      evhtp_request_t* req, TRITONSERVER_InferenceRequest* triton_request)
   {
     return std::unique_ptr<InferRequestClass>(new InferRequestClass(
-        server_.get(), req, GetResponseCompressionType(req)));
+        server_.get(), req, GetResponseCompressionType(req), triton_request));
   }
 
   // Helper function to retrieve infer request header in the form specified by
@@ -495,10 +511,6 @@ class HTTPAPIServer : public HTTPServer {
       triton::common::TritonJson::Value& request_json,
       TRITONSERVER_InferenceRequest* irequest);
 
-
-  static void OKReplyCallback(evthr_t* thr, void* arg, void* shared);
-  static void BADReplyCallback(evthr_t* thr, void* arg, void* shared);
-
   std::shared_ptr<TRITONSERVER_Server> server_;
 
   // Storing server metadata as it is consistent during server running
2 changes: 1 addition & 1 deletion src/sagemaker_server.cc
@@ -566,7 +566,7 @@ SagemakerAPIServer::SageMakerMMEHandleInfer(
   if (err == nullptr) {
     connection_paused = true;
 
-    auto infer_request = CreateInferRequest(req);
+    auto infer_request = CreateInferRequest(req, irequest);
 #ifdef TRITON_ENABLE_TRACING
     infer_request->trace_ = trace;
 #endif // TRITON_ENABLE_TRACING
11 changes: 7 additions & 4 deletions src/sagemaker_server.h
@@ -51,8 +51,10 @@ class SagemakerAPIServer : public HTTPAPIServer {
  public:
   explicit SagemakeInferRequestClass(
       TRITONSERVER_Server* server, evhtp_request_t* req,
-      DataCompressor::Type response_compression_type)
-      : InferRequestClass(server, req, response_compression_type)
+      DataCompressor::Type response_compression_type,
+      TRITONSERVER_InferenceRequest* triton_request)
+      : InferRequestClass(
+            server, req, response_compression_type, triton_request)
   {
   }
   using InferRequestClass::InferResponseComplete;
@@ -121,10 +123,11 @@ class SagemakerAPIServer : public HTTPAPIServer {
   static void BADReplyCallback507(evthr_t* thr, void* arg, void* shared);
 
   std::unique_ptr<InferRequestClass> CreateInferRequest(
-      evhtp_request_t* req) override
+      evhtp_request_t* req,
+      TRITONSERVER_InferenceRequest* triton_request) override
   {
     return std::unique_ptr<InferRequestClass>(new SagemakeInferRequestClass(
-        server_.get(), req, GetResponseCompressionType(req)));
+        server_.get(), req, GetResponseCompressionType(req), triton_request));
   }
   TRITONSERVER_Error* GetInferenceHeaderLength(
       evhtp_request_t* req, int32_t content_length,