From 71243fb1184bea499288eed263408cee9afce7a5 Mon Sep 17 00:00:00 2001 From: GuanLuo <41310872+GuanLuo@users.noreply.github.com> Date: Wed, 1 Mar 2023 12:50:00 -0800 Subject: [PATCH] GRPC endpoint clean up. Add missing protocol in doc (#5428) * GRPC endpoint clean up. Add missing protocol in doc * Remove GRPC prefix in grpc_server.h/.cc --- docs/protocol/README.md | 1 + src/grpc_server.cc | 756 ++++++++++++++++++++++------------------ src/grpc_server.h | 102 +++--- src/main.cc | 93 ++--- 4 files changed, 503 insertions(+), 449 deletions(-) diff --git a/docs/protocol/README.md b/docs/protocol/README.md index ccc080519b..d28dc2b06a 100644 --- a/docs/protocol/README.md +++ b/docs/protocol/README.md @@ -43,6 +43,7 @@ plus several extensions that are defined in the following documents: - [Shared-memory extension](./extension_shared_memory.md) - [Statistics extension](./extension_statistics.md) - [Trace extension](./extension_trace.md) +- [Logging extension](./extension_logging.md) For the GRPC protocol, the [protobuf specification](https://github.com/triton-inference-server/common/blob/main/protobuf/grpc_service.proto) diff --git a/src/grpc_server.cc b/src/grpc_server.cc index 757fd0fada..97bbd0628a 100644 --- a/src/grpc_server.cc +++ b/src/grpc_server.cc @@ -61,7 +61,7 @@ #define REGISTER_GRPC_INFER_THREAD_COUNT 2 -namespace triton { namespace server { +namespace triton { namespace server { namespace grpc { namespace { // Unique IDs are only needed when debugging. They only appear in @@ -111,45 +111,45 @@ class Barrier { // class GrpcStatusUtil { public: - static void Create(grpc::Status* status, TRITONSERVER_Error* err); - static grpc::StatusCode CodeToStatus(TRITONSERVER_Error_Code code); + static void Create(::grpc::Status* status, TRITONSERVER_Error* err); + static ::grpc::StatusCode CodeToStatus(TRITONSERVER_Error_Code code); }; void -GrpcStatusUtil::Create(grpc::Status* status, TRITONSERVER_Error* err) +GrpcStatusUtil::Create(::grpc::Status* status, TRITONSERVER_Error* err) { if (err == nullptr) { - *status = grpc::Status::OK; + *status = ::grpc::Status::OK; } else { - *status = grpc::Status( + *status = ::grpc::Status( GrpcStatusUtil::CodeToStatus(TRITONSERVER_ErrorCode(err)), TRITONSERVER_ErrorMessage(err)); } } -grpc::StatusCode +::grpc::StatusCode GrpcStatusUtil::CodeToStatus(TRITONSERVER_Error_Code code) { // GRPC status codes: // https://github.com/grpc/grpc/blob/master/include/grpc/impl/codegen/status.h switch (code) { case TRITONSERVER_ERROR_UNKNOWN: - return grpc::StatusCode::UNKNOWN; + return ::grpc::StatusCode::UNKNOWN; case TRITONSERVER_ERROR_INTERNAL: - return grpc::StatusCode::INTERNAL; + return ::grpc::StatusCode::INTERNAL; case TRITONSERVER_ERROR_NOT_FOUND: - return grpc::StatusCode::NOT_FOUND; + return ::grpc::StatusCode::NOT_FOUND; case TRITONSERVER_ERROR_INVALID_ARG: - return grpc::StatusCode::INVALID_ARGUMENT; + return ::grpc::StatusCode::INVALID_ARGUMENT; case TRITONSERVER_ERROR_UNAVAILABLE: - return grpc::StatusCode::UNAVAILABLE; + return ::grpc::StatusCode::UNAVAILABLE; case TRITONSERVER_ERROR_UNSUPPORTED: - return grpc::StatusCode::UNIMPLEMENTED; + return ::grpc::StatusCode::UNIMPLEMENTED; case TRITONSERVER_ERROR_ALREADY_EXISTS: - return grpc::StatusCode::ALREADY_EXISTS; + return ::grpc::StatusCode::ALREADY_EXISTS; } - return grpc::StatusCode::UNKNOWN; + return ::grpc::StatusCode::UNKNOWN; } // The step of processing that the state is in. 
Every state must @@ -206,18 +206,18 @@ operator<<(std::ostream& out, const Steps& step) //========================================================================= template -class CommonCallData : public GRPCServer::ICallData { +class CommonCallData : public Server::ICallData { public: using StandardRegisterFunc = std::function; + ::grpc::ServerContext*, RequestType*, ResponderType*, void*)>; using StandardCallbackFunc = - std::function; + std::function; CommonCallData( const std::string& name, const uint64_t id, const StandardRegisterFunc OnRegister, const StandardCallbackFunc OnExecute, const bool async, - grpc::ServerCompletionQueue* cq) + ::grpc::ServerCompletionQueue* cq) : name_(name), id_(id), OnRegister_(OnRegister), OnExecute_(OnExecute), async_(async), cq_(cq), responder_(&ctx_), step_(Steps::START) { @@ -248,15 +248,15 @@ class CommonCallData : public GRPCServer::ICallData { const StandardRegisterFunc OnRegister_; const StandardCallbackFunc OnExecute_; const bool async_; - grpc::ServerCompletionQueue* cq_; + ::grpc::ServerCompletionQueue* cq_; - grpc::ServerContext ctx_; - grpc::Alarm alarm_; + ::grpc::ServerContext ctx_; + ::grpc::Alarm alarm_; ResponderType responder_; RequestType request_; ResponseType response_; - grpc::Status status_; + ::grpc::Status status_; std::thread async_thread_; @@ -345,7 +345,7 @@ CommonCallData::WriteResponse() // // A common handler for all non-inference requests. // -class CommonHandler : public GRPCServer::HandlerBase { +class CommonHandler : public Server::HandlerBase { public: CommonHandler( const std::string& name, @@ -353,21 +353,42 @@ class CommonHandler : public GRPCServer::HandlerBase { const std::shared_ptr& shm_manager, TraceManager* trace_manager, inference::GRPCInferenceService::AsyncService* service, - grpc::health::v1::Health::AsyncService* health_service, - grpc::ServerCompletionQueue* cq); + ::grpc::health::v1::Health::AsyncService* health_service, + ::grpc::ServerCompletionQueue* cq); // Descriptive name of of the handler. const std::string& Name() const { return name_; } // Start handling requests. - void Start(); + void Start() override; // Stop handling requests. 
- void Stop(); + void Stop() override; private: void SetUpAllRequests(); + // [FIXME] turn into generated code + void RegisterServerLive(); + void RegisterServerReady(); + void RegisterHealthCheck(); + void RegisterModelReady(); + void RegisterServerMetadata(); + void RegisterModelMetadata(); + void RegisterModelConfig(); + void RegisterModelStatistics(); + void RegisterTrace(); + void RegisterLogging(); + void RegisterSystemSharedMemoryStatus(); + void RegisterSystemSharedMemoryRegister(); + void RegisterSystemSharedMemoryUnregister(); + void RegisterCudaSharedMemoryStatus(); + void RegisterCudaSharedMemoryRegister(); + void RegisterCudaSharedMemoryUnregister(); + void RegisterRepositoryIndex(); + void RegisterRepositoryModelLoad(); + void RegisterRepositoryModelUnload(); + const std::string name_; std::shared_ptr<TRITONSERVER_Server> tritonserver_; @@ -375,8 +396,8 @@ class CommonHandler : public GRPCServer::HandlerBase { TraceManager* trace_manager_; inference::GRPCInferenceService::AsyncService* service_; - grpc::health::v1::Health::AsyncService* health_service_; - grpc::ServerCompletionQueue* cq_; + ::grpc::health::v1::Health::AsyncService* health_service_; + ::grpc::ServerCompletionQueue* cq_; std::unique_ptr<std::thread> thread_; }; @@ -386,8 +407,8 @@ CommonHandler::CommonHandler( const std::shared_ptr<SharedMemoryManager>& shm_manager, TraceManager* trace_manager, inference::GRPCInferenceService::AsyncService* service, - grpc::health::v1::Health::AsyncService* health_service, - grpc::ServerCompletionQueue* cq) + ::grpc::health::v1::Health::AsyncService* health_service, + ::grpc::ServerCompletionQueue* cq) : name_(name), tritonserver_(tritonserver), shm_manager_(shm_manager), trace_manager_(trace_manager), service_(service), health_service_(health_service), cq_(cq) @@ -409,8 +430,7 @@ CommonHandler::Start() bool ok; while (cq_->Next(&tag, &ok)) { - GRPCServer::ICallData* call_data = - static_cast<GRPCServer::ICallData*>(tag); + Server::ICallData* call_data = static_cast<Server::ICallData*>(tag); if (!call_data->Process(ok)) { LOG_VERBOSE(1) << "Done for " << call_data->Name() << ", " << call_data->Id(); @@ -438,7 +458,7 @@ CommonHandler::SetUpAllRequests() { // Define all the RPCs to be handled by this handler below // - // The format of each RPC specification is : + // Within each Register function, the format of the RPC specification is: // 1. An OnRegister function: This will be called when the // server is ready to receive the requests for this RPC. // 2. An OnExecute function: This will be called when the @@ -446,13 +466,52 @@ // 3. Create a CommonCallData object with the above callback // functions - // - // ServerLive - // + // health (GRPC standard) + RegisterHealthCheck(); + // health (Triton) + RegisterServerLive(); + RegisterServerReady(); + RegisterModelReady(); + + // Metadata + RegisterServerMetadata(); + RegisterModelMetadata(); + + // model config + RegisterModelConfig(); + + // shared memory + // system.. + RegisterSystemSharedMemoryStatus(); + RegisterSystemSharedMemoryRegister(); + RegisterSystemSharedMemoryUnregister(); + // cuda..
+ RegisterCudaSharedMemoryStatus(); + RegisterCudaSharedMemoryRegister(); + RegisterCudaSharedMemoryUnregister(); + + // model repository + RegisterRepositoryIndex(); + RegisterRepositoryModelLoad(); + RegisterRepositoryModelUnload(); + + // statistics + RegisterModelStatistics(); + + // trace + RegisterTrace(); + + // logging + RegisterLogging(); +} + +void +CommonHandler::RegisterServerLive() +{ auto OnRegisterServerLive = [this]( - grpc::ServerContext* ctx, inference::ServerLiveRequest* request, - grpc::ServerAsyncResponseWriter* + ::grpc::ServerContext* ctx, inference::ServerLiveRequest* request, + ::grpc::ServerAsyncResponseWriter* responder, void* tag) { this->service_->RequestServerLive( @@ -462,7 +521,7 @@ CommonHandler::SetUpAllRequests() auto OnExecuteServerLive = [this]( inference::ServerLiveRequest& request, inference::ServerLiveResponse* response, - grpc::Status* status) { + ::grpc::Status* status) { bool live = false; TRITONSERVER_Error* err = TRITONSERVER_ServerIsLive(tritonserver_.get(), &live); @@ -474,18 +533,19 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter, + ::grpc::ServerAsyncResponseWriter, inference::ServerLiveRequest, inference::ServerLiveResponse>( "ServerLive", 0, OnRegisterServerLive, OnExecuteServerLive, false /* async */, cq_); +} - // - // ServerReady - // +void +CommonHandler::RegisterServerReady() +{ auto OnRegisterServerReady = [this]( - grpc::ServerContext* ctx, inference::ServerReadyRequest* request, - grpc::ServerAsyncResponseWriter* + ::grpc::ServerContext* ctx, inference::ServerReadyRequest* request, + ::grpc::ServerAsyncResponseWriter* responder, void* tag) { this->service_->RequestServerReady( @@ -495,7 +555,7 @@ CommonHandler::SetUpAllRequests() auto OnExecuteServerReady = [this]( inference::ServerReadyRequest& request, inference::ServerReadyResponse* response, - grpc::Status* status) { + ::grpc::Status* status) { bool ready = false; TRITONSERVER_Error* err = TRITONSERVER_ServerIsReady(tritonserver_.get(), &ready); @@ -507,41 +567,43 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter, + ::grpc::ServerAsyncResponseWriter, inference::ServerReadyRequest, inference::ServerReadyResponse>( "ServerReady", 0, OnRegisterServerReady, OnExecuteServerReady, false /* async */, cq_); +} - // - // Health Check - // - auto OnHealthRegisterCheck = +void +CommonHandler::RegisterHealthCheck() +{ + auto OnRegisterHealthCheck = [this]( - grpc::ServerContext* ctx, - grpc::health::v1::HealthCheckRequest* request, - grpc::ServerAsyncResponseWriter< - grpc::health::v1::HealthCheckResponse>* responder, + ::grpc::ServerContext* ctx, + ::grpc::health::v1::HealthCheckRequest* request, + ::grpc::ServerAsyncResponseWriter< + ::grpc::health::v1::HealthCheckResponse>* responder, void* tag) { this->health_service_->RequestCheck( ctx, request, responder, this->cq_, this->cq_, tag); }; - auto OnHealthExecuteCheck = [this]( - grpc::health::v1::HealthCheckRequest& request, - grpc::health::v1::HealthCheckResponse* + auto OnExecuteHealthCheck = [this]( + ::grpc::health::v1::HealthCheckRequest& + request, + ::grpc::health::v1::HealthCheckResponse* response, - grpc::Status* status) { + ::grpc::Status* status) { bool live = false; TRITONSERVER_Error* err = TRITONSERVER_ServerIsReady(tritonserver_.get(), &live); auto serving_status = - grpc::health::v1::HealthCheckResponse_ServingStatus_UNKNOWN; + ::grpc::health::v1::HealthCheckResponse_ServingStatus_UNKNOWN; if (err == nullptr) { serving_status = - 
live - ? grpc::health::v1::HealthCheckResponse_ServingStatus_SERVING - : grpc::health::v1::HealthCheckResponse_ServingStatus_NOT_SERVING; + live ? ::grpc::health::v1::HealthCheckResponse_ServingStatus_SERVING + : ::grpc::health::v1:: + HealthCheckResponse_ServingStatus_NOT_SERVING; } response->set_status(serving_status); @@ -550,19 +612,21 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter, - grpc::health::v1::HealthCheckRequest, - grpc::health::v1::HealthCheckResponse>( - "Check", 0, OnHealthRegisterCheck, OnHealthExecuteCheck, + ::grpc::ServerAsyncResponseWriter< + ::grpc::health::v1::HealthCheckResponse>, + ::grpc::health::v1::HealthCheckRequest, + ::grpc::health::v1::HealthCheckResponse>( + "Check", 0, OnRegisterHealthCheck, OnExecuteHealthCheck, false /* async */, cq_); +} - // - // ModelReady - // +void +CommonHandler::RegisterModelReady() +{ auto OnRegisterModelReady = [this]( - grpc::ServerContext* ctx, inference::ModelReadyRequest* request, - grpc::ServerAsyncResponseWriter* + ::grpc::ServerContext* ctx, inference::ModelReadyRequest* request, + ::grpc::ServerAsyncResponseWriter* responder, void* tag) { this->service_->RequestModelReady( @@ -572,7 +636,7 @@ CommonHandler::SetUpAllRequests() auto OnExecuteModelReady = [this]( inference::ModelReadyRequest& request, inference::ModelReadyResponse* response, - grpc::Status* status) { + ::grpc::Status* status) { bool is_ready = false; int64_t requested_model_version; auto err = @@ -590,18 +654,19 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter, + ::grpc::ServerAsyncResponseWriter, inference::ModelReadyRequest, inference::ModelReadyResponse>( "ModelReady", 0, OnRegisterModelReady, OnExecuteModelReady, false /* async */, cq_); +} - // - // ServerMetadata - // +void +CommonHandler::RegisterServerMetadata() +{ auto OnRegisterServerMetadata = [this]( - grpc::ServerContext* ctx, inference::ServerMetadataRequest* request, - grpc::ServerAsyncResponseWriter* + ::grpc::ServerContext* ctx, inference::ServerMetadataRequest* request, + ::grpc::ServerAsyncResponseWriter* responder, void* tag) { this->service_->RequestServerMetadata( @@ -611,7 +676,7 @@ CommonHandler::SetUpAllRequests() auto OnExecuteServerMetadata = [this]( inference::ServerMetadataRequest& request, - inference::ServerMetadataResponse* response, grpc::Status* status) { + inference::ServerMetadataResponse* response, ::grpc::Status* status) { TRITONSERVER_Message* server_metadata_message = nullptr; TRITONSERVER_Error* err = TRITONSERVER_ServerMetadata( tritonserver_.get(), &server_metadata_message); @@ -665,18 +730,19 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter, + ::grpc::ServerAsyncResponseWriter, inference::ServerMetadataRequest, inference::ServerMetadataResponse>( "ServerMetadata", 0, OnRegisterServerMetadata, OnExecuteServerMetadata, false /* async */, cq_); +} - // - // ModelMetadata - // +void +CommonHandler::RegisterModelMetadata() +{ auto OnRegisterModelMetadata = [this]( - grpc::ServerContext* ctx, inference::ModelMetadataRequest* request, - grpc::ServerAsyncResponseWriter* + ::grpc::ServerContext* ctx, inference::ModelMetadataRequest* request, + ::grpc::ServerAsyncResponseWriter* responder, void* tag) { this->service_->RequestModelMetadata( @@ -686,7 +752,7 @@ CommonHandler::SetUpAllRequests() auto OnExecuteModelMetadata = [this]( inference::ModelMetadataRequest& request, inference::ModelMetadataResponse* response, - grpc::Status* status) { 
+ ::grpc::Status* status) { int64_t requested_model_version; auto err = GetModelVersionFromString(request.version(), &requested_model_version); @@ -830,18 +896,19 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter, + ::grpc::ServerAsyncResponseWriter, inference::ModelMetadataRequest, inference::ModelMetadataResponse>( "ModelMetadata", 0, OnRegisterModelMetadata, OnExecuteModelMetadata, false /* async */, cq_); +} - // - // ModelConfig - // +void +CommonHandler::RegisterModelConfig() +{ auto OnRegisterModelConfig = [this]( - grpc::ServerContext* ctx, inference::ModelConfigRequest* request, - grpc::ServerAsyncResponseWriter* + ::grpc::ServerContext* ctx, inference::ModelConfigRequest* request, + ::grpc::ServerAsyncResponseWriter* responder, void* tag) { this->service_->RequestModelConfig( @@ -851,7 +918,7 @@ CommonHandler::SetUpAllRequests() auto OnExecuteModelConfig = [this]( inference::ModelConfigRequest& request, inference::ModelConfigResponse* response, - grpc::Status* status) { + ::grpc::Status* status) { int64_t requested_model_version; auto err = GetModelVersionFromString(request.version(), &requested_model_version); @@ -880,18 +947,20 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter, + ::grpc::ServerAsyncResponseWriter, inference::ModelConfigRequest, inference::ModelConfigResponse>( "ModelConfig", 0, OnRegisterModelConfig, OnExecuteModelConfig, false /* async */, cq_); +} - // - // ModelStatistics - // +void +CommonHandler::RegisterModelStatistics() +{ auto OnRegisterModelStatistics = [this]( - grpc::ServerContext* ctx, inference::ModelStatisticsRequest* request, - grpc::ServerAsyncResponseWriter* + ::grpc::ServerContext* ctx, + inference::ModelStatisticsRequest* request, + ::grpc::ServerAsyncResponseWriter* responder, void* tag) { this->service_->RequestModelStatistics( @@ -903,7 +972,7 @@ CommonHandler::SetUpAllRequests() request, inference::ModelStatisticsResponse* response, - grpc::Status* status) { + ::grpc::Status* status) { #ifdef TRITON_ENABLE_STATS triton::common::TritonJson::Value model_stats_json; @@ -1175,18 +1244,19 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter, + ::grpc::ServerAsyncResponseWriter, inference::ModelStatisticsRequest, inference::ModelStatisticsResponse>( "ModelStatistics", 0, OnRegisterModelStatistics, OnExecuteModelStatistics, false /* async */, cq_); +} - // - // Trace - // +void +CommonHandler::RegisterTrace() +{ auto OnRegisterTrace = [this]( - grpc::ServerContext* ctx, inference::TraceSettingRequest* request, - grpc::ServerAsyncResponseWriter* + ::grpc::ServerContext* ctx, inference::TraceSettingRequest* request, + ::grpc::ServerAsyncResponseWriter* responder, void* tag) { this->service_->RequestTraceSetting( @@ -1196,7 +1266,7 @@ CommonHandler::SetUpAllRequests() auto OnExecuteTrace = [this]( inference::TraceSettingRequest& request, inference::TraceSettingResponse* response, - grpc::Status* status) { + ::grpc::Status* status) { #ifdef TRITON_ENABLE_TRACING TRITONSERVER_Error* err = nullptr; TRITONSERVER_InferenceTraceLevel level = TRITONSERVER_TRACE_LEVEL_DISABLED; @@ -1385,18 +1455,18 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter, + ::grpc::ServerAsyncResponseWriter, inference::TraceSettingRequest, inference::TraceSettingResponse>( "Trace", 0, OnRegisterTrace, OnExecuteTrace, false /* async */, cq_); +} - // - // Log Settings - // - +void 
+CommonHandler::RegisterLogging() +{ auto OnRegisterLogging = [this]( - grpc::ServerContext* ctx, inference::LogSettingsRequest* request, - grpc::ServerAsyncResponseWriter* + ::grpc::ServerContext* ctx, inference::LogSettingsRequest* request, + ::grpc::ServerAsyncResponseWriter* responder, void* tag) { this->service_->RequestLogSettings( @@ -1406,7 +1476,7 @@ CommonHandler::SetUpAllRequests() auto OnExecuteLogging = [this]( inference::LogSettingsRequest& request, inference::LogSettingsResponse* response, - grpc::Status* status) { + ::grpc::Status* status) { #ifdef TRITON_ENABLE_LOGGING TRITONSERVER_Error* err = nullptr; @@ -1591,20 +1661,20 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter, + ::grpc::ServerAsyncResponseWriter, inference::LogSettingsRequest, inference::LogSettingsResponse>( "Logging", 0, OnRegisterLogging, OnExecuteLogging, false /* async */, cq_); +} - - // - // SystemSharedMemoryStatus - // +void +CommonHandler::RegisterSystemSharedMemoryStatus() +{ auto OnRegisterSystemSharedMemoryStatus = [this]( - grpc::ServerContext* ctx, + ::grpc::ServerContext* ctx, inference::SystemSharedMemoryStatusRequest* request, - grpc::ServerAsyncResponseWriter< + ::grpc::ServerAsyncResponseWriter< inference::SystemSharedMemoryStatusResponse>* responder, void* tag) { this->service_->RequestSystemSharedMemoryStatus( @@ -1615,7 +1685,7 @@ CommonHandler::SetUpAllRequests() [this]( inference::SystemSharedMemoryStatusRequest& request, inference::SystemSharedMemoryStatusResponse* response, - grpc::Status* status) { + ::grpc::Status* status) { triton::common::TritonJson::Value shm_status_json( triton::common::TritonJson::ValueType::ARRAY); TRITONSERVER_Error* err = shm_manager_->GetStatus( @@ -1661,22 +1731,22 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter< + ::grpc::ServerAsyncResponseWriter< inference::SystemSharedMemoryStatusResponse>, inference::SystemSharedMemoryStatusRequest, inference::SystemSharedMemoryStatusResponse>( "SystemSharedMemoryStatus", 0, OnRegisterSystemSharedMemoryStatus, OnExecuteSystemSharedMemoryStatus, false /* async */, cq_); +} - - // - // SystemSharedMemoryRegister - // +void +CommonHandler::RegisterSystemSharedMemoryRegister() +{ auto OnRegisterSystemSharedMemoryRegister = [this]( - grpc::ServerContext* ctx, + ::grpc::ServerContext* ctx, inference::SystemSharedMemoryRegisterRequest* request, - grpc::ServerAsyncResponseWriter< + ::grpc::ServerAsyncResponseWriter< inference::SystemSharedMemoryRegisterResponse>* responder, void* tag) { this->service_->RequestSystemSharedMemoryRegister( @@ -1687,7 +1757,7 @@ CommonHandler::SetUpAllRequests() [this]( inference::SystemSharedMemoryRegisterRequest& request, inference::SystemSharedMemoryRegisterResponse* response, - grpc::Status* status) { + ::grpc::Status* status) { TRITONSERVER_Error* err = shm_manager_->RegisterSystemSharedMemory( request.name(), request.key(), request.offset(), request.byte_size()); @@ -1697,22 +1767,22 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter< + ::grpc::ServerAsyncResponseWriter< inference::SystemSharedMemoryRegisterResponse>, inference::SystemSharedMemoryRegisterRequest, inference::SystemSharedMemoryRegisterResponse>( "SystemSharedMemoryRegister", 0, OnRegisterSystemSharedMemoryRegister, OnExecuteSystemSharedMemoryRegister, false /* async */, cq_); +} - - // - // SystemSharedMemoryUnregister - // +void +CommonHandler::RegisterSystemSharedMemoryUnregister() +{ auto 
OnRegisterSystemSharedMemoryUnregister = [this]( - grpc::ServerContext* ctx, + ::grpc::ServerContext* ctx, inference::SystemSharedMemoryUnregisterRequest* request, - grpc::ServerAsyncResponseWriter< + ::grpc::ServerAsyncResponseWriter< inference::SystemSharedMemoryUnregisterResponse>* responder, void* tag) { this->service_->RequestSystemSharedMemoryUnregister( @@ -1723,7 +1793,7 @@ CommonHandler::SetUpAllRequests() [this]( inference::SystemSharedMemoryUnregisterRequest& request, inference::SystemSharedMemoryUnregisterResponse* response, - grpc::Status* status) { + ::grpc::Status* status) { TRITONSERVER_Error* err = nullptr; if (request.name().empty()) { err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_CPU); @@ -1737,22 +1807,22 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter< + ::grpc::ServerAsyncResponseWriter< inference::SystemSharedMemoryUnregisterResponse>, inference::SystemSharedMemoryUnregisterRequest, inference::SystemSharedMemoryUnregisterResponse>( "SystemSharedMemoryUnregister", 0, OnRegisterSystemSharedMemoryUnregister, OnExecuteSystemSharedMemoryUnregister, false /* async */, cq_); +} - - // - // CudaSharedMemoryStatus - // +void +CommonHandler::RegisterCudaSharedMemoryStatus() +{ auto OnRegisterCudaSharedMemoryStatus = [this]( - grpc::ServerContext* ctx, + ::grpc::ServerContext* ctx, inference::CudaSharedMemoryStatusRequest* request, - grpc::ServerAsyncResponseWriter< + ::grpc::ServerAsyncResponseWriter< inference::CudaSharedMemoryStatusResponse>* responder, void* tag) { this->service_->RequestCudaSharedMemoryStatus( @@ -1762,7 +1832,7 @@ CommonHandler::SetUpAllRequests() [this]( inference::CudaSharedMemoryStatusRequest& request, inference::CudaSharedMemoryStatusResponse* response, - grpc::Status* status) { + ::grpc::Status* status) { triton::common::TritonJson::Value shm_status_json( triton::common::TritonJson::ValueType::ARRAY); TRITONSERVER_Error* err = shm_manager_->GetStatus( @@ -1800,22 +1870,22 @@ CommonHandler::SetUpAllRequests() TRITONSERVER_ErrorDelete(err); }; new CommonCallData< - grpc::ServerAsyncResponseWriter< + ::grpc::ServerAsyncResponseWriter< inference::CudaSharedMemoryStatusResponse>, inference::CudaSharedMemoryStatusRequest, inference::CudaSharedMemoryStatusResponse>( "CudaSharedMemoryStatus", 0, OnRegisterCudaSharedMemoryStatus, OnExecuteCudaSharedMemoryStatus, false /* async */, cq_); +} - - // - // CudaSharedMemoryRegister - // +void +CommonHandler::RegisterCudaSharedMemoryRegister() +{ auto OnRegisterCudaSharedMemoryRegister = [this]( - grpc::ServerContext* ctx, + ::grpc::ServerContext* ctx, inference::CudaSharedMemoryRegisterRequest* request, - grpc::ServerAsyncResponseWriter< + ::grpc::ServerAsyncResponseWriter< inference::CudaSharedMemoryRegisterResponse>* responder, void* tag) { this->service_->RequestCudaSharedMemoryRegister( @@ -1826,7 +1896,7 @@ CommonHandler::SetUpAllRequests() [this]( inference::CudaSharedMemoryRegisterRequest& request, inference::CudaSharedMemoryRegisterResponse* response, - grpc::Status* status) { + ::grpc::Status* status) { TRITONSERVER_Error* err = nullptr; #ifdef TRITON_ENABLE_GPU err = shm_manager_->RegisterCUDASharedMemory( @@ -1848,21 +1918,22 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter< + ::grpc::ServerAsyncResponseWriter< inference::CudaSharedMemoryRegisterResponse>, inference::CudaSharedMemoryRegisterRequest, inference::CudaSharedMemoryRegisterResponse>( "CudaSharedMemoryRegister", 0, 
OnRegisterCudaSharedMemoryRegister, OnExecuteCudaSharedMemoryRegister, false /* async */, cq_); +} - // - // CudaSharedMemoryUnregister - // +void +CommonHandler::RegisterCudaSharedMemoryUnregister() +{ auto OnRegisterCudaSharedMemoryUnregister = [this]( - grpc::ServerContext* ctx, + ::grpc::ServerContext* ctx, inference::CudaSharedMemoryUnregisterRequest* request, - grpc::ServerAsyncResponseWriter< + ::grpc::ServerAsyncResponseWriter< inference::CudaSharedMemoryUnregisterResponse>* responder, void* tag) { this->service_->RequestCudaSharedMemoryUnregister( @@ -1873,7 +1944,7 @@ CommonHandler::SetUpAllRequests() [this]( inference::CudaSharedMemoryUnregisterRequest& request, inference::CudaSharedMemoryUnregisterResponse* response, - grpc::Status* status) { + ::grpc::Status* status) { TRITONSERVER_Error* err = nullptr; if (request.name().empty()) { err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_GPU); @@ -1887,20 +1958,22 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter< + ::grpc::ServerAsyncResponseWriter< inference::CudaSharedMemoryUnregisterResponse>, inference::CudaSharedMemoryUnregisterRequest, inference::CudaSharedMemoryUnregisterResponse>( "CudaSharedMemoryUnregister", 0, OnRegisterCudaSharedMemoryUnregister, OnExecuteCudaSharedMemoryUnregister, false /* async */, cq_); +} - // - // RepositoryIndex - // +void +CommonHandler::RegisterRepositoryIndex() +{ auto OnRegisterRepositoryIndex = [this]( - grpc::ServerContext* ctx, inference::RepositoryIndexRequest* request, - grpc::ServerAsyncResponseWriter* + ::grpc::ServerContext* ctx, + inference::RepositoryIndexRequest* request, + ::grpc::ServerAsyncResponseWriter* responder, void* tag) { this->service_->RequestRepositoryIndex( @@ -1910,7 +1983,8 @@ CommonHandler::SetUpAllRequests() auto OnExecuteRepositoryIndex = [this]( inference::RepositoryIndexRequest& request, - inference::RepositoryIndexResponse* response, grpc::Status* status) { + inference::RepositoryIndexResponse* response, + ::grpc::Status* status) { TRITONSERVER_Error* err = nullptr; if (request.repository_name().empty()) { uint32_t flags = 0; @@ -1986,19 +2060,20 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter, + ::grpc::ServerAsyncResponseWriter, inference::RepositoryIndexRequest, inference::RepositoryIndexResponse>( "RepositoryIndex", 0, OnRegisterRepositoryIndex, OnExecuteRepositoryIndex, false /* async */, cq_); +} - // - // RepositoryModelLoad - // +void +CommonHandler::RegisterRepositoryModelLoad() +{ auto OnRegisterRepositoryModelLoad = [this]( - grpc::ServerContext* ctx, + ::grpc::ServerContext* ctx, inference::RepositoryModelLoadRequest* request, - grpc::ServerAsyncResponseWriter< + ::grpc::ServerAsyncResponseWriter< inference::RepositoryModelLoadResponse>* responder, void* tag) { this->service_->RequestRepositoryModelLoad( @@ -2009,7 +2084,7 @@ CommonHandler::SetUpAllRequests() [this]( inference::RepositoryModelLoadRequest& request, inference::RepositoryModelLoadResponse* response, - grpc::Status* status) { + ::grpc::Status* status) { TRITONSERVER_Error* err = nullptr; if (request.repository_name().empty()) { std::vector params; @@ -2094,20 +2169,21 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter, + ::grpc::ServerAsyncResponseWriter, inference::RepositoryModelLoadRequest, inference::RepositoryModelLoadResponse>( "RepositoryModelLoad", 0, OnRegisterRepositoryModelLoad, OnExecuteRepositoryModelLoad, true /* async */, cq_); +} - // 
- // RepositoryModelUnload - // +void +CommonHandler::RegisterRepositoryModelUnload() +{ auto OnRegisterRepositoryModelUnload = [this]( - grpc::ServerContext* ctx, + ::grpc::ServerContext* ctx, inference::RepositoryModelUnloadRequest* request, - grpc::ServerAsyncResponseWriter< + ::grpc::ServerAsyncResponseWriter< inference::RepositoryModelUnloadResponse>* responder, void* tag) { this->service_->RequestRepositoryModelUnload( @@ -2118,7 +2194,7 @@ CommonHandler::SetUpAllRequests() [this]( inference::RepositoryModelUnloadRequest& request, inference::RepositoryModelUnloadResponse* response, - grpc::Status* status) { + ::grpc::Status* status) { TRITONSERVER_Error* err = nullptr; if (request.repository_name().empty()) { // Check if the dependent models should be removed @@ -2159,7 +2235,8 @@ CommonHandler::SetUpAllRequests() }; new CommonCallData< - grpc::ServerAsyncResponseWriter, + ::grpc::ServerAsyncResponseWriter< + inference::RepositoryModelUnloadResponse>, inference::RepositoryModelUnloadRequest, inference::RepositoryModelUnloadResponse>( "RepositoryModelUnload", 0, OnRegisterRepositoryModelUnload, @@ -2375,11 +2452,11 @@ class InferHandlerState { // transaction (e.g. a stream). struct Context { explicit Context( - grpc::ServerCompletionQueue* cq, const uint64_t unique_id = 0) + ::grpc::ServerCompletionQueue* cq, const uint64_t unique_id = 0) : cq_(cq), unique_id_(unique_id), ongoing_requests_(0), step_(Steps::START), finish_ok_(true), ongoing_write_(false) { - ctx_.reset(new grpc::ServerContext()); + ctx_.reset(new ::grpc::ServerContext()); responder_.reset(new ServerResponderType(ctx_.get())); } @@ -2497,7 +2574,7 @@ class InferHandlerState { } // The grpc completion queue associated with the RPC. - grpc::ServerCompletionQueue* cq_; + ::grpc::ServerCompletionQueue* cq_; // Unique ID for the context. Used only for debugging so will // always be 0 in non-debug builds. @@ -2506,7 +2583,7 @@ class InferHandlerState { // Context for the rpc, allowing to tweak aspects of it such as // the use of compression, authentication, as well as to send // metadata back to the client. - std::unique_ptr ctx_; + std::unique_ptr<::grpc::ServerContext> ctx_; std::unique_ptr responder_; // The states associated with this context that are currently @@ -2609,7 +2686,7 @@ class InferHandlerState { RequestType request_; std::shared_ptr> response_queue_; - grpc::Alarm alarm_; + ::grpc::Alarm alarm_; // For testing and debugging int delay_response_ms_; @@ -2625,12 +2702,12 @@ class InferHandlerState { template < typename ServiceType, typename ServerResponderType, typename RequestType, typename ResponseType> -class InferHandler : public GRPCServer::HandlerBase { +class InferHandler : public Server::HandlerBase { public: InferHandler( const std::string& name, const std::shared_ptr& tritonserver, - ServiceType* service, grpc::ServerCompletionQueue* cq, + ServiceType* service, ::grpc::ServerCompletionQueue* cq, size_t max_state_bucket_count); virtual ~InferHandler(); @@ -2638,10 +2715,10 @@ class InferHandler : public GRPCServer::HandlerBase { const std::string& Name() const { return name_; } // Start handling requests. - void Start(); + void Start() override; // Stop handling requests. 
- void Stop(); + void Stop() override; protected: using State = @@ -2694,7 +2771,7 @@ class InferHandler : public GRPCServer::HandlerBase { std::shared_ptr tritonserver_; ServiceType* service_; - grpc::ServerCompletionQueue* cq_; + ::grpc::ServerCompletionQueue* cq_; std::unique_ptr thread_; // Mutex to serialize State allocation @@ -2713,7 +2790,7 @@ InferHandler:: InferHandler( const std::string& name, const std::shared_ptr& tritonserver, - ServiceType* service, grpc::ServerCompletionQueue* cq, + ServiceType* service, ::grpc::ServerCompletionQueue* cq, size_t max_state_bucket_count) : name_(name), tritonserver_(tritonserver), service_(service), cq_(cq), max_state_bucket_count_(max_state_bucket_count) @@ -3765,7 +3842,7 @@ InferResponseCompleteCommon( class ModelInferHandler : public InferHandler< inference::GRPCInferenceService::AsyncService, - grpc::ServerAsyncResponseWriter, + ::grpc::ServerAsyncResponseWriter, inference::ModelInferRequest, inference::ModelInferResponse> { public: ModelInferHandler( @@ -3774,7 +3851,7 @@ class ModelInferHandler TraceManager* trace_manager, const std::shared_ptr& shm_manager, inference::GRPCInferenceService::AsyncService* service, - grpc::ServerCompletionQueue* cq, size_t max_state_bucket_count, + ::grpc::ServerCompletionQueue* cq, size_t max_state_bucket_count, grpc_compression_level compression_level) : InferHandler(name, tritonserver, service, cq, max_state_bucket_count), trace_manager_(trace_manager), shm_manager_(shm_manager), @@ -3972,7 +4049,7 @@ ModelInferHandler::Process(InferHandler::State* state, bool rpc_ok) TRITONSERVER_InferenceRequestDelete(irequest), "deleting GRPC inference request"); - grpc::Status status; + ::grpc::Status status; GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); @@ -4057,7 +4134,7 @@ ModelInferHandler::InferResponseComplete( response->Clear(); } - grpc::Status status; + ::grpc::Status status; GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); @@ -4158,7 +4235,7 @@ StreamOutputBufferAttributes( class ModelStreamInferHandler : public InferHandler< inference::GRPCInferenceService::AsyncService, - grpc::ServerAsyncReaderWriter< + ::grpc::ServerAsyncReaderWriter< inference::ModelStreamInferResponse, inference::ModelInferRequest>, inference::ModelInferRequest, inference::ModelStreamInferResponse> { @@ -4169,7 +4246,7 @@ class ModelStreamInferHandler TraceManager* trace_manager, const std::shared_ptr& shm_manager, inference::GRPCInferenceService::AsyncService* service, - grpc::ServerCompletionQueue* cq, size_t max_state_bucket_count, + ::grpc::ServerCompletionQueue* cq, size_t max_state_bucket_count, grpc_compression_level compression_level) : InferHandler(name, tritonserver, service, cq, max_state_bucket_count), trace_manager_(trace_manager), shm_manager_(shm_manager), @@ -4283,8 +4360,8 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok) state->context_->step_ = Steps::COMPLETE; state->step_ = Steps::COMPLETE; state->context_->responder_->Finish( - state->context_->finish_ok_ ? grpc::Status::OK - : grpc::Status::CANCELLED, + state->context_->finish_ok_ ? 
::grpc::Status::OK + : ::grpc::Status::CANCELLED, state); } else { state->step_ = Steps::FINISH; @@ -4409,7 +4486,7 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok) TRITONSERVER_InferenceRequestDelete(irequest), "deleting GRPC inference request"); - grpc::Status status; + ::grpc::Status status; GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); response->set_error_message(status.error_message()); @@ -4592,8 +4669,8 @@ ModelStreamInferHandler::Finish(InferHandler::State* state) state->context_->step_ = Steps::COMPLETE; state->step_ = Steps::COMPLETE; state->context_->responder_->Finish( - state->context_->finish_ok_ ? grpc::Status::OK - : grpc::Status::CANCELLED, + state->context_->finish_ok_ ? ::grpc::Status::OK + : ::grpc::Status::CANCELLED, state); } else { state->step_ = Steps::FINISH; @@ -4655,7 +4732,7 @@ ModelStreamInferHandler::StreamInferResponseComplete( } if (err != nullptr) { - grpc::Status status; + ::grpc::Status status; GrpcStatusUtil::Create(&status, err); response->mutable_infer_response()->Clear(); response->set_error_message(status.error_message()); @@ -4710,159 +4787,157 @@ ReadFile(const std::string& filename, std::string& data) } // namespace // -// GRPCServer +// Server // -GRPCServer::GRPCServer( - const std::shared_ptr& server, +Server::Server( + const std::shared_ptr& tritonserver, triton::server::TraceManager* trace_manager, const std::shared_ptr& shm_manager, - const std::string& server_addr, const bool reuse_port, bool use_ssl, - const SslOptions& ssl_options, const int infer_allocation_pool_size, - grpc_compression_level compression_level, - const KeepAliveOptions& keepalive_options) - : server_(server), trace_manager_(trace_manager), shm_manager_(shm_manager), - server_addr_(server_addr), reuse_port_(reuse_port), use_ssl_(use_ssl), - ssl_options_(ssl_options), - infer_allocation_pool_size_(infer_allocation_pool_size), - compression_level_(compression_level), - keepalive_options_(keepalive_options), running_(false) + const Options& options) + : tritonserver_(tritonserver), trace_manager_(trace_manager), + shm_manager_(shm_manager), server_addr_( + options.socket_.address_ + ":" + + std::to_string(options.socket_.port_)) { + std::shared_ptr<::grpc::ServerCredentials> credentials; + const auto& ssl_options = options.ssl_; + if (ssl_options.use_ssl_) { + std::string key; + std::string cert; + std::string root; + ReadFile(ssl_options.server_cert_, cert); + ReadFile(ssl_options.server_key_, key); + ReadFile(ssl_options.root_cert_, root); + ::grpc::SslServerCredentialsOptions::PemKeyCertPair keycert = {key, cert}; + ::grpc::SslServerCredentialsOptions sslOpts; + sslOpts.pem_root_certs = root; + sslOpts.pem_key_cert_pairs.push_back(keycert); + if (ssl_options.use_mutual_auth_) { + sslOpts.client_certificate_request = + GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY; + } + credentials = ::grpc::SslServerCredentials(sslOpts); + } else { + credentials = ::grpc::InsecureServerCredentials(); + } + + builder_.AddListeningPort(server_addr_, credentials, &bound_port_); + builder_.SetMaxMessageSize(MAX_GRPC_MESSAGE_SIZE); + builder_.RegisterService(&service_); + builder_.RegisterService(&health_service_); + builder_.AddChannelArgument( + GRPC_ARG_ALLOW_REUSEPORT, options.socket_.reuse_port_); + + { + // GRPC KeepAlive Docs: + // https://grpc.github.io/grpc/cpp/md_doc_keepalive.html NOTE: In order to + // work properly, the client-side settings should be in agreement with + // server-side settings. 
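// For illustration only: a client whose channel arguments agree with the
// server-side keepalive arguments registered below might be configured as in
// this sketch. The endpoint and the concrete values are assumptions (they
// mirror the defaults in KeepAliveOptions), not part of this change:
//
//   ::grpc::ChannelArguments args;
//   args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, 7200000);
//   args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 20000);
//   args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 0);
//   auto channel = ::grpc::CreateCustomChannel(
//       "localhost:8001", ::grpc::InsecureChannelCredentials(), args);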
+ const auto& keepalive_options = options.keep_alive_; + builder_.AddChannelArgument( + GRPC_ARG_KEEPALIVE_TIME_MS, keepalive_options.keepalive_time_ms_); + builder_.AddChannelArgument( + GRPC_ARG_KEEPALIVE_TIMEOUT_MS, keepalive_options.keepalive_timeout_ms_); + builder_.AddChannelArgument( + GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, + keepalive_options.keepalive_permit_without_calls_); + builder_.AddChannelArgument( + GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, + keepalive_options.http2_max_pings_without_data_); + builder_.AddChannelArgument( + GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS, + keepalive_options.http2_min_recv_ping_interval_without_data_ms_); + builder_.AddChannelArgument( + GRPC_ARG_HTTP2_MAX_PING_STRIKES, + keepalive_options.http2_max_ping_strikes_); + + LOG_VERBOSE(1) << "=== GRPC KeepAlive Options ==="; + LOG_VERBOSE(1) << "keepalive_time_ms: " + << keepalive_options.keepalive_time_ms_; + LOG_VERBOSE(1) << "keepalive_timeout_ms: " + << keepalive_options.keepalive_timeout_ms_; + LOG_VERBOSE(1) << "keepalive_permit_without_calls: " + << keepalive_options.keepalive_permit_without_calls_; + LOG_VERBOSE(1) << "http2_max_pings_without_data: " + << keepalive_options.http2_max_pings_without_data_; + LOG_VERBOSE(1) + << "http2_min_recv_ping_interval_without_data_ms: " + << keepalive_options.http2_min_recv_ping_interval_without_data_ms_; + LOG_VERBOSE(1) << "http2_max_ping_strikes: " + << keepalive_options.http2_max_ping_strikes_; + LOG_VERBOSE(1) << "=============================="; + } + + common_cq_ = builder_.AddCompletionQueue(); + model_infer_cq_ = builder_.AddCompletionQueue(); + model_stream_infer_cq_ = builder_.AddCompletionQueue(); + + // A common Handler for other non-inference requests + common_handler_.reset(new CommonHandler( + "CommonHandler", tritonserver_, shm_manager_, trace_manager_, &service_, + &health_service_, common_cq_.get())); + + // Handler for model inference requests. + for (int i = 0; i < REGISTER_GRPC_INFER_THREAD_COUNT; ++i) { + model_infer_handlers_.emplace_back(new ModelInferHandler( + "ModelInferHandler", tritonserver_, trace_manager_, shm_manager_, + &service_, model_infer_cq_.get(), + options.infer_allocation_pool_size_ /* max_state_bucket_count */, + options.infer_compression_level_)); + } + + // Handler for streaming inference requests. 
Keeps one handler for streaming + // to avoid possible concurrent writes which is not allowed + model_stream_infer_handlers_.emplace_back(new ModelStreamInferHandler( + "ModelStreamInferHandler", tritonserver_, trace_manager_, shm_manager_, + &service_, model_stream_infer_cq_.get(), + options.infer_allocation_pool_size_ /* max_state_bucket_count */, + options.infer_compression_level_)); } -GRPCServer::~GRPCServer() +Server::~Server() { IGNORE_ERR(Stop()); } TRITONSERVER_Error* -GRPCServer::Create( - const std::shared_ptr& server, +Server::Create( + const std::shared_ptr& tritonserver, triton::server::TraceManager* trace_manager, - const std::shared_ptr& shm_manager, int32_t port, - const bool reuse_port, std::string address, bool use_ssl, - const SslOptions& ssl_options, int infer_allocation_pool_size, - grpc_compression_level compression_level, - const KeepAliveOptions& keepalive_options, - std::unique_ptr* grpc_server) + const std::shared_ptr& shm_manager, + const Options& server_options, std::unique_ptr* server) { - const std::string addr = address + ":" + std::to_string(port); - grpc_server->reset(new GRPCServer( - server, trace_manager, shm_manager, addr, reuse_port, use_ssl, - ssl_options, infer_allocation_pool_size, compression_level, - keepalive_options)); + const std::string addr = server_options.socket_.address_ + ":" + + std::to_string(server_options.socket_.port_); + server->reset( + new Server(tritonserver, trace_manager, shm_manager, server_options)); return nullptr; // success } TRITONSERVER_Error* -GRPCServer::Start() +Server::Start() { if (running_) { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_ALREADY_EXISTS, "GRPC server is already running."); } - std::shared_ptr credentials; - if (use_ssl_) { - std::string key; - std::string cert; - std::string root; - ReadFile(ssl_options_.server_cert, cert); - ReadFile(ssl_options_.server_key, key); - ReadFile(ssl_options_.root_cert, root); - grpc::SslServerCredentialsOptions::PemKeyCertPair keycert = {key, cert}; - grpc::SslServerCredentialsOptions sslOpts; - sslOpts.pem_root_certs = root; - sslOpts.pem_key_cert_pairs.push_back(keycert); - if (ssl_options_.use_mutual_auth) { - sslOpts.client_certificate_request = - GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY; - } - credentials = grpc::SslServerCredentials(sslOpts); - } else { - credentials = grpc::InsecureServerCredentials(); - } - - int bound_port = 0; - grpc_builder_.AddListeningPort(server_addr_, credentials, &bound_port); - grpc_builder_.SetMaxMessageSize(MAX_GRPC_MESSAGE_SIZE); - grpc_builder_.RegisterService(&service_); - grpc_builder_.RegisterService(&health_service_); - // GRPC KeepAlive Docs: https://grpc.github.io/grpc/cpp/md_doc_keepalive.html - // NOTE: In order to work properly, the client-side settings should - // be in agreement with server-side settings. 
- grpc_builder_.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, reuse_port_); - grpc_builder_.AddChannelArgument( - GRPC_ARG_KEEPALIVE_TIME_MS, keepalive_options_.keepalive_time_ms); - grpc_builder_.AddChannelArgument( - GRPC_ARG_KEEPALIVE_TIMEOUT_MS, keepalive_options_.keepalive_timeout_ms); - grpc_builder_.AddChannelArgument( - GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, - keepalive_options_.keepalive_permit_without_calls); - grpc_builder_.AddChannelArgument( - GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, - keepalive_options_.http2_max_pings_without_data); - grpc_builder_.AddChannelArgument( - GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS, - keepalive_options_.http2_min_recv_ping_interval_without_data_ms); - grpc_builder_.AddChannelArgument( - GRPC_ARG_HTTP2_MAX_PING_STRIKES, - keepalive_options_.http2_max_ping_strikes); - - LOG_VERBOSE(1) << "=== GRPC KeepAlive Options ==="; - LOG_VERBOSE(1) << "keepalive_time_ms: " - << keepalive_options_.keepalive_time_ms; - LOG_VERBOSE(1) << "keepalive_timeout_ms: " - << keepalive_options_.keepalive_timeout_ms; - LOG_VERBOSE(1) << "keepalive_permit_without_calls: " - << keepalive_options_.keepalive_permit_without_calls; - LOG_VERBOSE(1) << "http2_max_pings_without_data: " - << keepalive_options_.http2_max_pings_without_data; - LOG_VERBOSE(1) - << "http2_min_recv_ping_interval_without_data_ms: " - << keepalive_options_.http2_min_recv_ping_interval_without_data_ms; - LOG_VERBOSE(1) << "http2_max_ping_strikes: " - << keepalive_options_.http2_max_ping_strikes; - LOG_VERBOSE(1) << "=============================="; - - common_cq_ = grpc_builder_.AddCompletionQueue(); - model_infer_cq_ = grpc_builder_.AddCompletionQueue(); - model_stream_infer_cq_ = grpc_builder_.AddCompletionQueue(); - grpc_server_ = grpc_builder_.BuildAndStart(); + server_ = builder_.BuildAndStart(); // Check if binding port failed - if (bound_port == 0) { + if (bound_port_ == 0) { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNAVAILABLE, (std::string("Socket '") + server_addr_ + "' already in use ").c_str()); } - // A common Handler for other non-inference requests - CommonHandler* hcommon = new CommonHandler( - "CommonHandler", server_, shm_manager_, trace_manager_, &service_, - &health_service_, common_cq_.get()); - hcommon->Start(); - common_handler_.reset(hcommon); - - // Handler for model inference requests. - for (int i = 0; i < REGISTER_GRPC_INFER_THREAD_COUNT; ++i) { - ModelInferHandler* hmodelinfer = new ModelInferHandler( - "ModelInferHandler", server_, trace_manager_, shm_manager_, &service_, - model_infer_cq_.get(), - infer_allocation_pool_size_ /* max_state_bucket_count */, - compression_level_); - hmodelinfer->Start(); - model_infer_handlers_.emplace_back(hmodelinfer); + common_handler_->Start(); + for (auto& model_infer_handler : model_infer_handlers_) { + model_infer_handler->Start(); + } + for (auto& model_stream_infer_handler : model_stream_infer_handlers_) { + model_stream_infer_handler->Start(); } - - // Handler for streaming inference requests. 
Keeps one handler for streaming - // to avoid possible concurrent writes which is not allowed - ModelStreamInferHandler* hmodelstreaminfer = new ModelStreamInferHandler( - "ModelStreamInferHandler", server_, trace_manager_, shm_manager_, - &service_, model_stream_infer_cq_.get(), - infer_allocation_pool_size_ /* max_state_bucket_count */, - compression_level_); - hmodelstreaminfer->Start(); - model_stream_infer_handlers_.emplace_back(hmodelstreaminfer); running_ = true; LOG_INFO << "Started GRPCInferenceService at " << server_addr_; @@ -4870,7 +4945,7 @@ GRPCServer::Start() } TRITONSERVER_Error* -GRPCServer::Stop() +Server::Stop() { if (!running_) { return TRITONSERVER_ErrorNew( @@ -4878,7 +4953,7 @@ GRPCServer::Stop() } // Always shutdown the completion queue after the server. - grpc_server_->Shutdown(); + server_->Shutdown(); common_cq_->Shutdown(); model_infer_cq_->Shutdown(); @@ -4886,17 +4961,16 @@ GRPCServer::Stop() // Must stop all handlers explicitly to wait for all the handler // threads to join since they are referencing completion queue, etc. - dynamic_cast(common_handler_.get())->Stop(); - for (const auto& model_infer_handler : model_infer_handlers_) { - dynamic_cast(model_infer_handler.get())->Stop(); + common_handler_->Stop(); + for (auto& model_infer_handler : model_infer_handlers_) { + model_infer_handler->Stop(); } - for (const auto& model_stream_infer_handler : model_stream_infer_handlers_) { - dynamic_cast(model_stream_infer_handler.get()) - ->Stop(); + for (auto& model_stream_infer_handler : model_stream_infer_handlers_) { + model_stream_infer_handler->Stop(); } running_ = false; return nullptr; // success } -}} // namespace triton::server +}}} // namespace triton::server::grpc diff --git a/src/grpc_server.h b/src/grpc_server.h index 5060da07fb..a27d3d9206 100644 --- a/src/grpc_server.h +++ b/src/grpc_server.h @@ -32,50 +32,58 @@ #include "tracer.h" #include "triton/core/tritonserver.h" -namespace triton { namespace server { +namespace triton { namespace server { namespace grpc { + +struct SocketOptions { + std::string address_{"0.0.0.0"}; + int32_t port_{8001}; + bool reuse_port_{false}; +}; struct SslOptions { - explicit SslOptions() {} + // Whether SSL is used for communication + bool use_ssl_{false}; // File holding PEM-encoded server certificate - std::string server_cert; + std::string server_cert_{""}; // File holding PEM-encoded server key - std::string server_key; + std::string server_key_{""}; // File holding PEM-encoded root certificate - std::string root_cert; + std::string root_cert_{""}; // Whether to use Mutual Authentication - bool use_mutual_auth; + bool use_mutual_auth_{false}; }; // GRPC KeepAlive: https://grpc.github.io/grpc/cpp/md_doc_keepalive.html struct KeepAliveOptions { - explicit KeepAliveOptions() - : keepalive_time_ms(7200000), keepalive_timeout_ms(20000), - keepalive_permit_without_calls(false), http2_max_pings_without_data(2), - http2_min_recv_ping_interval_without_data_ms(300000), - http2_max_ping_strikes(2) - { - } - int keepalive_time_ms; - int keepalive_timeout_ms; - bool keepalive_permit_without_calls; - int http2_max_pings_without_data; - int http2_min_recv_ping_interval_without_data_ms; - int http2_max_ping_strikes; + int keepalive_time_ms_{7200000}; + int keepalive_timeout_ms_{20000}; + bool keepalive_permit_without_calls_{false}; + int http2_max_pings_without_data_{2}; + int http2_min_recv_ping_interval_without_data_ms_{300000}; + int http2_max_ping_strikes_{2}; +}; + +struct Options { + SocketOptions socket_; + SslOptions ssl_; + 
KeepAliveOptions keep_alive_; + grpc_compression_level infer_compression_level_{GRPC_COMPRESS_LEVEL_NONE}; + // The maximum number of inference request/response objects that + // remain allocated for reuse. As long as the number of in-flight + // requests doesn't exceed this value there will be no + // allocation/deallocation of request/response objects. + int infer_allocation_pool_size_{8}; }; -class GRPCServer { +class Server { public: static TRITONSERVER_Error* Create( - const std::shared_ptr& server, + const std::shared_ptr& tritonserver, triton::server::TraceManager* trace_manager, - const std::shared_ptr& shm_manager, int32_t port, - const bool reuse_port, std::string address, bool use_ssl, - const SslOptions& ssl_options, int infer_allocation_pool_size, - grpc_compression_level compression_level, - const KeepAliveOptions& keepalive_options, - std::unique_ptr* grpc_server); + const std::shared_ptr& shm_manager, + const Options& server_options, std::unique_ptr* server); - ~GRPCServer(); + ~Server(); TRITONSERVER_Error* Start(); TRITONSERVER_Error* Stop(); @@ -84,6 +92,8 @@ class GRPCServer { class HandlerBase { public: virtual ~HandlerBase() = default; + virtual void Start() = 0; + virtual void Stop() = 0; }; class ICallData { @@ -95,42 +105,34 @@ class GRPCServer { }; private: - GRPCServer( - const std::shared_ptr& server, + Server( + const std::shared_ptr& tritonserver, triton::server::TraceManager* trace_manager, const std::shared_ptr& shm_manager, - const std::string& server_addr, const bool reuse_port, bool use_ssl, - const SslOptions& ssl_options, const int infer_allocation_pool_size, - grpc_compression_level compression_level, - const KeepAliveOptions& keepalive_options); + const Options& server_options); - std::shared_ptr server_; + std::shared_ptr tritonserver_; TraceManager* trace_manager_; std::shared_ptr shm_manager_; const std::string server_addr_; - const bool reuse_port_; - const bool use_ssl_; - const SslOptions ssl_options_; - const int infer_allocation_pool_size_; - grpc_compression_level compression_level_; + ::grpc::ServerBuilder builder_; - const KeepAliveOptions keepalive_options_; + inference::GRPCInferenceService::AsyncService service_; + ::grpc::health::v1::Health::AsyncService health_service_; - std::unique_ptr common_cq_; - std::unique_ptr model_infer_cq_; - std::unique_ptr model_stream_infer_cq_; + std::unique_ptr<::grpc::Server> server_; - grpc::ServerBuilder grpc_builder_; - std::unique_ptr grpc_server_; + std::unique_ptr<::grpc::ServerCompletionQueue> common_cq_; + std::unique_ptr<::grpc::ServerCompletionQueue> model_infer_cq_; + std::unique_ptr<::grpc::ServerCompletionQueue> model_stream_infer_cq_; std::unique_ptr common_handler_; std::vector> model_infer_handlers_; std::vector> model_stream_infer_handlers_; - inference::GRPCInferenceService::AsyncService service_; - grpc::health::v1::Health::AsyncService health_service_; - bool running_; + int bound_port_{0}; + bool running_{false}; }; -}} // namespace triton::server +}}} // namespace triton::server::grpc diff --git a/src/main.cc b/src/main.cc index da95485a1f..75954ec27a 100644 --- a/src/main.cc +++ b/src/main.cc @@ -118,17 +118,10 @@ std::string vertex_ai_default_model_; #endif // TRITON_ENABLE_VERTEX_AI #ifdef TRITON_ENABLE_GRPC -std::unique_ptr grpc_service_; +// [FIXME] global variable should use different naming convention "g_xxx" +std::unique_ptr grpc_service_; bool allow_grpc_ = true; -int32_t grpc_port_ = 8001; -bool reuse_grpc_port_ = false; -std::string grpc_address_ = "0.0.0.0"; -bool 
grpc_use_ssl_ = false; -triton::server::SslOptions grpc_ssl_options_; -grpc_compression_level grpc_response_compression_level_ = - GRPC_COMPRESS_LEVEL_NONE; -// KeepAlive defaults: https://grpc.github.io/grpc/cpp/md_doc_keepalive.html -triton::server::KeepAliveOptions grpc_keepalive_options_; +triton::server::grpc::Options grpc_options_; #endif // TRITON_ENABLE_GRPC #ifdef TRITON_ENABLE_METRICS @@ -152,14 +145,6 @@ int32_t trace_count_ = -1; int32_t trace_log_frequency_ = 0; #endif // TRITON_ENABLE_TRACING -#if defined(TRITON_ENABLE_GRPC) -// The maximum number of inference request/response objects that -// remain allocated for reuse. As long as the number of in-flight -// requests doesn't exceed this value there will be no -// allocation/deallocation of request/response objects. -int grpc_infer_allocation_pool_size_ = 8; -#endif // TRITON_ENABLE_GRPC - #if defined(TRITON_ENABLE_HTTP) // The number of threads to initialize for the HTTP front-end. int http_thread_cnt_ = 8; @@ -682,7 +667,9 @@ CheckPortCollision() #endif // TRITON_ENABLE_HTTP #ifdef TRITON_ENABLE_GRPC if (allow_grpc_) { - ports.emplace_back("GRPC", grpc_address_, grpc_port_, false, -1, -1); + ports.emplace_back( + "GRPC", grpc_options_.socket_.address_, grpc_options_.socket_.port_, + false, -1, -1); } #endif // TRITON_ENABLE_GRPC #ifdef TRITON_ENABLE_METRICS @@ -745,16 +732,13 @@ CheckPortCollision() #ifdef TRITON_ENABLE_GRPC TRITONSERVER_Error* StartGrpcService( - std::unique_ptr* service, + std::unique_ptr* service, const std::shared_ptr& server, triton::server::TraceManager* trace_manager, const std::shared_ptr& shm_manager) { - TRITONSERVER_Error* err = triton::server::GRPCServer::Create( - server, trace_manager, shm_manager, grpc_port_, reuse_grpc_port_, - grpc_address_, grpc_use_ssl_, grpc_ssl_options_, - grpc_infer_allocation_pool_size_, grpc_response_compression_level_, - grpc_keepalive_options_, service); + TRITONSERVER_Error* err = triton::server::grpc::Server::Create( + server, trace_manager, shm_manager, grpc_options_, service); if (err == nullptr) { err = (*service)->Start(); } @@ -1429,13 +1413,7 @@ Parse(TRITONSERVER_ServerOptions** server_options, int argc, char** argv) #endif // TRITON_ENABLE_HTTP #if defined(TRITON_ENABLE_GRPC) - int32_t grpc_port = grpc_port_; - bool reuse_grpc_port = reuse_grpc_port_; - std::string grpc_address = grpc_address_; - int32_t grpc_use_ssl = grpc_use_ssl_; - int32_t grpc_infer_allocation_pool_size = grpc_infer_allocation_pool_size_; - grpc_compression_level grpc_response_compression_level = - grpc_response_compression_level_; + triton::server::grpc::Options lgrpc_options; #endif // TRITON_ENABLE_GRPC #if defined(TRITON_ENABLE_SAGEMAKER) @@ -1627,45 +1605,46 @@ Parse(TRITONSERVER_ServerOptions** server_options, int argc, char** argv) allow_grpc_ = ParseBoolOption(optarg); break; case OPTION_GRPC_PORT: - grpc_port = ParseIntOption(optarg); + lgrpc_options.socket_.port_ = ParseIntOption(optarg); break; case OPTION_REUSE_GRPC_PORT: - reuse_grpc_port = ParseIntOption(optarg); + lgrpc_options.socket_.reuse_port_ = ParseIntOption(optarg); break; case OPTION_GRPC_ADDRESS: - grpc_address = optarg; + lgrpc_options.socket_.address_ = optarg; break; case OPTION_GRPC_INFER_ALLOCATION_POOL_SIZE: - grpc_infer_allocation_pool_size = ParseIntOption(optarg); + lgrpc_options.infer_allocation_pool_size_ = ParseIntOption(optarg); break; case OPTION_GRPC_USE_SSL: - grpc_use_ssl = ParseBoolOption(optarg); + lgrpc_options.ssl_.use_ssl_ = ParseBoolOption(optarg); break; case OPTION_GRPC_USE_SSL_MUTUAL: 
- grpc_ssl_options_.use_mutual_auth = ParseBoolOption(optarg); - grpc_use_ssl = true; + lgrpc_options.ssl_.use_mutual_auth_ = ParseBoolOption(optarg); + // [FIXME] this implies use SSL, take priority over OPTION_GRPC_USE_SSL? + lgrpc_options.ssl_.use_ssl_ = true; break; case OPTION_GRPC_SERVER_CERT: - grpc_ssl_options_.server_cert = optarg; + lgrpc_options.ssl_.server_cert_ = optarg; break; case OPTION_GRPC_SERVER_KEY: - grpc_ssl_options_.server_key = optarg; + lgrpc_options.ssl_.server_key_ = optarg; break; case OPTION_GRPC_ROOT_CERT: - grpc_ssl_options_.root_cert = optarg; + lgrpc_options.ssl_.root_cert_ = optarg; break; case OPTION_GRPC_RESPONSE_COMPRESSION_LEVEL: { std::string mode_str(optarg); std::transform( mode_str.begin(), mode_str.end(), mode_str.begin(), ::tolower); if (mode_str == "none") { - grpc_response_compression_level = GRPC_COMPRESS_LEVEL_NONE; + lgrpc_options.infer_compression_level_ = GRPC_COMPRESS_LEVEL_NONE; } else if (mode_str == "low") { - grpc_response_compression_level = GRPC_COMPRESS_LEVEL_LOW; + lgrpc_options.infer_compression_level_ = GRPC_COMPRESS_LEVEL_LOW; } else if (mode_str == "medium") { - grpc_response_compression_level = GRPC_COMPRESS_LEVEL_MED; + lgrpc_options.infer_compression_level_ = GRPC_COMPRESS_LEVEL_MED; } else if (mode_str == "high") { - grpc_response_compression_level = GRPC_COMPRESS_LEVEL_HIGH; + lgrpc_options.infer_compression_level_ = GRPC_COMPRESS_LEVEL_HIGH; } else { std::cerr << "invalid argument for --grpc_infer_response_compression_level" @@ -1676,25 +1655,28 @@ Parse(TRITONSERVER_ServerOptions** server_options, int argc, char** argv) break; } case OPTION_GRPC_ARG_KEEPALIVE_TIME_MS: - grpc_keepalive_options_.keepalive_time_ms = ParseIntOption(optarg); + lgrpc_options.keep_alive_.keepalive_time_ms_ = ParseIntOption(optarg); break; case OPTION_GRPC_ARG_KEEPALIVE_TIMEOUT_MS: - grpc_keepalive_options_.keepalive_timeout_ms = ParseIntOption(optarg); + lgrpc_options.keep_alive_.keepalive_timeout_ms_ = + ParseIntOption(optarg); break; case OPTION_GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS: - grpc_keepalive_options_.keepalive_permit_without_calls = + lgrpc_options.keep_alive_.keepalive_permit_without_calls_ = ParseBoolOption(optarg); break; case OPTION_GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA: - grpc_keepalive_options_.http2_max_pings_without_data = + lgrpc_options.keep_alive_.http2_max_pings_without_data_ = ParseIntOption(optarg); break; case OPTION_GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS: - grpc_keepalive_options_.http2_min_recv_ping_interval_without_data_ms = + lgrpc_options.keep_alive_ + .http2_min_recv_ping_interval_without_data_ms_ = ParseIntOption(optarg); break; case OPTION_GRPC_ARG_HTTP2_MAX_PING_STRIKES: - grpc_keepalive_options_.http2_max_ping_strikes = ParseIntOption(optarg); + lgrpc_options.keep_alive_.http2_max_ping_strikes_ = + ParseIntOption(optarg); break; #endif // TRITON_ENABLE_GRPC @@ -1912,12 +1894,7 @@ Parse(TRITONSERVER_ServerOptions** server_options, int argc, char** argv) #if defined(TRITON_ENABLE_GRPC) - grpc_port_ = grpc_port; - reuse_grpc_port_ = reuse_grpc_port; - grpc_address_ = grpc_address; - grpc_infer_allocation_pool_size_ = grpc_infer_allocation_pool_size; - grpc_use_ssl_ = grpc_use_ssl; - grpc_response_compression_level_ = grpc_response_compression_level; + grpc_options_ = lgrpc_options; #endif // TRITON_ENABLE_GRPC #ifdef TRITON_ENABLE_METRICS
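// With the GRPC settings now consolidated into a single Options struct, the
// frontend startup shown earlier in this patch (StartGrpcService) reduces to
// a sketch like the following; server, trace_manager, and shm_manager come
// from the surrounding setup code, and error handling is elided:
//
//   std::unique_ptr<triton::server::grpc::Server> service;
//   TRITONSERVER_Error* err = triton::server::grpc::Server::Create(
//       server, trace_manager, shm_manager, grpc_options_, &service);
//   if (err == nullptr) {
//     err = service->Start();
//   }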