Connects the rate limiter to the scheduling pipeline (#3388)
* Connect the rate limiting pipeline

* Some cleanups in rate limiter logic

* Clean up extra logic for handling the no-limiting case
tanmayv25 authored Sep 22, 2021
1 parent 864c73e commit 4a7bd92
Showing 10 changed files with 288 additions and 274 deletions.
6 changes: 1 addition & 5 deletions src/backends/backend/triton_model.cc
@@ -186,16 +186,12 @@ TritonModel::Create(

 Status
 TritonModel::AddInstance(
-    std::unique_ptr<TritonModelInstance>&& instance, const bool passive,
-    const inference::ModelRateLimiter& rate_limiter_config)
+    std::unique_ptr<TritonModelInstance>&& instance, const bool passive)
 {
   if (passive) {
     passive_instances_.emplace_back(std::move(instance));
   } else {
-    TritonModelInstance* raw_instance = instance.get();
     instances_.emplace_back(std::move(instance));
-    RETURN_IF_ERROR(server_->GetRateLimiter()->RegisterModelInstance(
-        raw_instance, rate_limiter_config));
   }

   return Status::Success;
3 changes: 1 addition & 2 deletions src/backends/backend/triton_model.h
@@ -73,8 +73,7 @@ class TritonModel : public InferenceBackend {
   void* State() { return state_; }
   void SetState(void* state) { state_ = state; }
   Status AddInstance(
-      std::unique_ptr<TritonModelInstance>&& instance, const bool passive,
-      const inference::ModelRateLimiter& rate_limiter_config);
+      std::unique_ptr<TritonModelInstance>&& instance, const bool passive);

  private:
   DISALLOW_COPY_AND_ASSIGN(TritonModel);
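For reference, these two hunks reduce AddInstance to a pure ownership transfer. A sketch of the resulting method, reconstructed from the hunks above rather than copied from the repository:

// Reconstructed from the hunks above; rate-limiter registration is now the
// caller's responsibility (see triton_model_instance.cc below), so
// AddInstance only stores the instance in the appropriate list.
Status
TritonModel::AddInstance(
    std::unique_ptr<TritonModelInstance>&& instance, const bool passive)
{
  if (passive) {
    passive_instances_.emplace_back(std::move(instance));
  } else {
    instances_.emplace_back(std::move(instance));
  }

  return Status::Success;
}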
7 changes: 3 additions & 4 deletions src/backends/backend/triton_model_instance.cc
@@ -258,8 +258,6 @@ TritonModelInstance::CreateInstance(
       model, name, index, kind, device_id, profile_names, passive, host_policy,
       host_policy_message, secondary_devices));

-  model->Server()->GetRateLimiter()->InitializePayloadQueues(
-      local_instance.get());
   TRITONBACKEND_ModelInstance* triton_instance =
       reinterpret_cast<TRITONBACKEND_ModelInstance*>(local_instance.get());

@@ -280,12 +278,13 @@ TritonModelInstance::CreateInstance(

   if (!passive) {
     RETURN_IF_ERROR(local_instance->GenerateWarmupData());
+    RETURN_IF_ERROR(model->Server()->GetRateLimiter()->RegisterModelInstance(
+        local_instance.get(), rate_limiter_config));
     local_instance->SetBackendThread(
         kind, device_id, device_blocking, device_to_thread_map);
   }

-  RETURN_IF_ERROR(model->AddInstance(
-      std::move(local_instance), passive, rate_limiter_config));
+  RETURN_IF_ERROR(model->AddInstance(std::move(local_instance), passive));

   return Status::Success;
 }
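Taken together with the previous two files, rate-limiter registration now happens here in CreateInstance, and the separate InitializePayloadQueues call is gone (registration presumably covers payload-queue setup). A condensed sketch of the resulting ordering, paraphrased from the hunks above with surrounding setup and error handling elided:

// Paraphrased from the diff; only the rate-limiter-related ordering is shown.
if (!passive) {
  RETURN_IF_ERROR(local_instance->GenerateWarmupData());
  // Non-passive instances are registered with the rate limiter before the
  // backend thread is started.
  RETURN_IF_ERROR(model->Server()->GetRateLimiter()->RegisterModelInstance(
      local_instance.get(), rate_limiter_config));
  local_instance->SetBackendThread(
      kind, device_id, device_blocking, device_to_thread_map);
}
// The ownership transfer no longer needs the rate-limiter config.
RETURN_IF_ERROR(model->AddInstance(std::move(local_instance), passive));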
2 changes: 1 addition & 1 deletion src/core/instance_queue.cc
@@ -48,7 +48,7 @@ InstanceQueue::Empty()
 }

 void
-InstanceQueue::Enqueue(std::shared_ptr<Payload>& payload)
+InstanceQueue::Enqueue(const std::shared_ptr<Payload>& payload)
 {
   payload_queue_.push_back(payload);
 }
2 changes: 1 addition & 1 deletion src/core/instance_queue.h
@@ -40,7 +40,7 @@ class InstanceQueue {

   size_t Size();
   bool Empty();
-  void Enqueue(std::shared_ptr<Payload>& payload);
+  void Enqueue(const std::shared_ptr<Payload>& payload);
   void Dequeue(
       std::shared_ptr<Payload>* payload,
       std::vector<std::shared_ptr<Payload>>* merged_payloads);
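This is an interface tightening rather than a behavior change: Enqueue only copies the shared_ptr into payload_queue_, so taking it by const reference documents that the caller's pointer is untouched and also lets temporaries bind. A hypothetical caller fragment (the queue variable and includes are assumed, not part of this commit):

// Hypothetical usage, assuming an existing InstanceQueue named `queue`.
std::shared_ptr<Payload> payload = std::make_shared<Payload>();
queue.Enqueue(payload);                      // copies the shared_ptr; `payload` stays valid
queue.Enqueue(std::make_shared<Payload>());  // a temporary now binds to the const reference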
20 changes: 16 additions & 4 deletions src/core/payload.cc
@@ -31,8 +31,8 @@ namespace nvidia { namespace inferenceserver {
 Payload::Payload()
     : op_type_(Operation::INFER_RUN),
       requests_(std::vector<std::unique_ptr<InferenceRequest>>()),
-      OnCallback_([]() {}), instance_(nullptr), state_(State::UNINITIALIZED),
-      queue_start_ns_(0)
+      OnCallback_([]() {}), OnSecondaryCallback_([]() {}), instance_(nullptr),
+      state_(State::UNINITIALIZED), queue_start_ns_(0)
 {
   exec_mu_.reset(new std::mutex());
 }
@@ -73,9 +73,9 @@ Payload::Reset(const Operation op_type, TritonModelInstance* instance)
   op_type_ = op_type;
   requests_.clear();
   OnCallback_ = []() {};
+  OnSecondaryCallback_ = []() {};
   instance_ = instance;
   state_ = State::UNINITIALIZED;
-  OnCallback_ = []() {};
   status_.reset(new std::promise<Status>());
   queue_start_ns_ = 0;
 }
@@ -86,9 +86,9 @@ Payload::Release()
   op_type_ = Operation::INFER_RUN;
   requests_.clear();
   OnCallback_ = []() {};
+  OnSecondaryCallback_ = []() {};
   instance_ = nullptr;
   state_ = State::RELEASED;
-  OnCallback_ = []() {};
   queue_start_ns_ = 0;
 }

@@ -129,6 +129,12 @@ Payload::SetInstance(TritonModelInstance* model_instance)
   instance_ = model_instance;
 }

+void
+Payload::SetSecondaryCallback(std::function<void()> OnSecondaryCallback)
+{
+  OnSecondaryCallback_ = OnSecondaryCallback;
+}
+
 void
 Payload::SetState(Payload::State state)
 {
@@ -147,6 +153,12 @@ Payload::Callback()
   OnCallback_();
 }

+void
+Payload::SecondaryCallback()
+{
+  OnSecondaryCallback_();
+}
+
 void
 Payload::Execute(bool* should_exit)
 {
3 changes: 3 additions & 0 deletions src/core/payload.h
@@ -66,6 +66,8 @@ class Payload {
   uint64_t QueueStartNs() { return queue_start_ns_; }
   void SetCallback(std::function<void()> OnCallback);
   void Callback();
+  void SetSecondaryCallback(std::function<void()> OnRelease);
+  void SecondaryCallback();
   void SetInstance(TritonModelInstance* model_instance);
   TritonModelInstance* GetInstance() { return instance_; }

@@ -79,6 +81,7 @@ class Payload {
   Operation op_type_;
   std::vector<std::unique_ptr<InferenceRequest>> requests_;
   std::function<void()> OnCallback_;
+  std::function<void()> OnSecondaryCallback_;
   TritonModelInstance* instance_;
   State state_;
   std::unique_ptr<std::promise<Status>> status_;
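The secondary callback lets a payload carry two completion hooks: the primary OnCallback_ plus a second, independently fired hook (the declaration's parameter name, OnRelease, suggests it is meant for release-time work such as handing resources back to the rate limiter). How the scheduler and rate limiter wire these up is not shown in this commit; a hypothetical sketch of the pattern, with placeholder lambda bodies:

// Hypothetical illustration only; the lambda bodies are placeholders, not
// code from this commit.
auto payload = std::make_shared<Payload>();
payload->Reset(Payload::Operation::INFER_RUN, nullptr /* instance */);

// Primary hook: what should happen once the payload has been handled.
payload->SetCallback([]() { /* e.g. wake the scheduler thread */ });

// Secondary hook: an extra, separately triggered action, e.g. returning
// rate-limiter resources once the instance is done with the payload.
payload->SetSecondaryCallback([]() { /* e.g. release limiter resources */ });

// Whoever owns the payload later fires whichever hook applies:
payload->Callback();
payload->SecondaryCallback();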