Commit
Add timed wait during UNLOAD while the model becomes UNAVAILABLE in SageMaker (#5423)

* Add timed wait during UNLOAD while the model becomes UNAVAILABLE in SageMaker

* Directly use C API to UNLOAD model in SM

* Address comments and bug fixes

* Add logging for model server index

* Change MME model repo

* Address comments and use chrono seconds, don't repeat error assignment

* Address minor comments

* Fix typo in log

* Update minor comment
nikhil-sk authored Mar 7, 2023
1 parent b601bd8 commit f1aedd4
Showing 3 changed files with 167 additions and 34 deletions.
10 changes: 9 additions & 1 deletion docker/sagemaker/serve
@@ -26,13 +26,21 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

SAGEMAKER_SINGLE_MODEL_REPO=/opt/ml/model/

# Note: in Triton on SageMaker, each model URL is registered as a separate repository,
# e.g., /opt/ml/models/<hash>/model. Specifying /opt/ml/models as the MME model repo path
# would cause Triton to treat it as an additional, empty repository and change
# the state of all models in the model repository to UNAVAILABLE:
# https://github.com/triton-inference-server/core/blob/main/src/model_repository_manager.cc#L914,L922
# The path below is therefore a dummy path, used only because Triton requires a model
# repository to be specified at startup.
SAGEMAKER_MULTI_MODEL_REPO=/tmp/sagemaker

SAGEMAKER_MODEL_REPO=${SAGEMAKER_SINGLE_MODEL_REPO}
is_mme_mode=false

if [ -n "$SAGEMAKER_MULTI_MODEL" ]; then
if [ "$SAGEMAKER_MULTI_MODEL" == "true" ]; then
mkdir -p ${SAGEMAKER_MULTI_MODEL_REPO}
SAGEMAKER_MODEL_REPO=${SAGEMAKER_MULTI_MODEL_REPO}
is_mme_mode=true
echo "Triton is running in SageMaker MME mode."
177 changes: 146 additions & 31 deletions src/sagemaker_server.cc
Expand Up @@ -628,69 +628,185 @@ SagemakerAPIServer::SageMakerMMEHandleInfer(
}
}

TRITONSERVER_Error*
SagemakerAPIServer::SageMakerMMECheckUnloadedModelIsUnavailable(
const char* model_name, bool* is_model_unavailable)
{
/* Use the RepositoryIndex API to check if the model state has become
UNAVAILABLE i.e. model is no longer in the 'in-the-process-of' being
UNLOADED. Consequently, the reason field should be 'unloaded'.*/
TRITONSERVER_Message* server_model_index_message = nullptr;
uint32_t ready_flag = 0;  // set to 1 if only the 'ready' models are
// required from the index; in this case we need all models.
RETURN_IF_ERR(TRITONSERVER_ServerModelIndex(
server_.get(), ready_flag, &server_model_index_message));

std::shared_ptr<TRITONSERVER_Message> shared_ptr_msg(
server_model_index_message,
[](TRITONSERVER_Message* msg) { TRITONSERVER_MessageDelete(msg); });

const char* index_buffer;
size_t index_byte_size;

RETURN_IF_ERR(TRITONSERVER_MessageSerializeToJson(
server_model_index_message, &index_buffer, &index_byte_size));

/* Read the index into a JSON buffer */
triton::common::TritonJson::Value server_model_index_json;
RETURN_IF_ERR(server_model_index_json.Parse(index_buffer, index_byte_size));

const char* name;
const char* state;
const char* reason;
const char* version;

size_t name_len;
size_t state_len;
size_t reason_len;
size_t version_len;

for (size_t id = 0; id < server_model_index_json.ArraySize(); ++id) {
triton::common::TritonJson::Value index_json;
server_model_index_json.IndexAsObject(id, &index_json);

RETURN_IF_ERR(index_json.MemberAsString("name", &name, &name_len));

if (std::string(name) == std::string(model_name)) {
RETURN_IF_ERR(index_json.MemberAsString("state", &state, &state_len));

if (std::string(state) == UNLOAD_EXPECTED_STATE_) {
RETURN_IF_ERR(
index_json.MemberAsString("reason", &reason, &reason_len));

if (std::string(reason) == UNLOAD_EXPECTED_REASON_) {
*is_model_unavailable = true;

RETURN_IF_ERR(
index_json.MemberAsString("version", &version, &version_len));

LOG_VERBOSE(1) << "Discovered model: " << name
<< ", version: " << version << " in state: " << state
<< "for the reason: " << reason;

break;
}
}
}
}

return nullptr;
}
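For reference, the helper above walks the server's repository index looking for the target model in state UNAVAILABLE with reason "unloaded". Below is a small, hypothetical sketch (not part of this change) that fetches and prints that same index through the C API; with the flags argument left at 0 the index also lists models that are not ready, which is what makes the check possible.

#include <iostream>
#include <string>

#include "triton/core/tritonserver.h"

// Sketch: serialize and print the repository index. Entries are typically
// objects carrying "name", "version", "state" and "reason" fields, e.g. state
// "UNAVAILABLE" with reason "unloaded" once a model has finished unloading.
void
PrintModelIndex(TRITONSERVER_Server* server)
{
  TRITONSERVER_Message* index_msg = nullptr;
  TRITONSERVER_Error* err = TRITONSERVER_ServerModelIndex(
      server, 0 /* flags: 0 = include all models */, &index_msg);
  if (err != nullptr) {
    std::cerr << TRITONSERVER_ErrorMessage(err) << std::endl;
    TRITONSERVER_ErrorDelete(err);
    return;
  }

  const char* buffer = nullptr;
  size_t byte_size = 0;
  err = TRITONSERVER_MessageSerializeToJson(index_msg, &buffer, &byte_size);
  if (err == nullptr) {
    std::cout << std::string(buffer, byte_size) << std::endl;
  } else {
    TRITONSERVER_ErrorDelete(err);
  }

  TRITONSERVER_MessageDelete(index_msg);
}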

void
SagemakerAPIServer::SageMakerMMEUnloadModel(
evhtp_request_t* req, const char* model_name)
{

if (sagemaker_models_list_.find(model_name) == sagemaker_models_list_.end()) {
LOG_VERBOSE(1) << "Model " << model_name << " is not loaded." << std::endl;
evhtp_send_reply(req, EVHTP_RES_NOTFOUND); /* 404*/
return;
}

/* Extract targetModel to log the associated archive */
const char* targetModel =
evhtp_kv_find(req->headers_in, "X-Amzn-SageMaker-Target-Model");

LOG_INFO << "Unloading SageMaker TargetModel: " << targetModel << std::endl;

auto start_time = std::chrono::high_resolution_clock::now();

/* Always unload dependents as well - this is required to unload dependents in
* ensemble */
TRITONSERVER_Error* unload_err = nullptr;
unload_err =
TRITONSERVER_ServerUnloadModelAndDependents(server_.get(), model_name);

if (unload_err != nullptr) {
EVBufferAddErrorJson(req->buffer_out, unload_err);
evhtp_send_reply(req, EVHTP_RES_BADREQ);

LOG_ERROR
<< "Error when unloading SageMaker Model with dependents for model: "
<< model_name << std::endl;

TRITONSERVER_ErrorDelete(unload_err);
return;
}

/*Note: Model status check is repo-specific and therefore must be run before
* unregistering the repo, else the model information is lost*/
bool is_model_unavailable = false;
int64_t unload_time_in_secs = 0;

/* Wait for the model to be completely unloaded. SageMaker waits a maximum
of 360 seconds for the UNLOAD request to timeout. Setting a limit of 350
seconds for Triton unload. This should be run only if above UNLOAD call has
succeeded.*/
if (unload_err == nullptr) {
LOG_VERBOSE(1) << "Using Model Repository Index during UNLOAD to check for "
"status of model: "
<< model_name;
while (is_model_unavailable == false &&
unload_time_in_secs < UNLOAD_TIMEOUT_SECS_) {
LOG_VERBOSE(1) << "In the loop to wait for model to be unavailable";
unload_err = SageMakerMMECheckUnloadedModelIsUnavailable(
model_name, &is_model_unavailable);
if (unload_err != nullptr) {
LOG_ERROR << "Error: Received non-zero exit code on checking for "
"model unavailability. "
<< TRITONSERVER_ErrorMessage(unload_err);
break;
}
std::this_thread::sleep_for(
std::chrono::milliseconds(UNLOAD_SLEEP_MILLISECONDS_));

auto end_time = std::chrono::high_resolution_clock::now();

unload_time_in_secs = std::chrono::duration_cast<std::chrono::seconds>(
end_time - start_time)
.count();
}
LOG_INFO << "UNLOAD for model " << model_name << " completed in "
<< unload_time_in_secs << " seconds.";
TRITONSERVER_ErrorDelete(unload_err);
}

if ((is_model_unavailable == false) &&
(unload_time_in_secs >= UNLOAD_TIMEOUT_SECS_)) {
LOG_ERROR << "Error: UNLOAD did not complete within expected "
<< UNLOAD_TIMEOUT_SECS_
<< " seconds. This may "
"result in SageMaker UNLOAD timeout.";
}

std::string repo_parent_path = sagemaker_models_list_.at(model_name);

TRITONSERVER_Error* unregister_err = nullptr;

unregister_err = TRITONSERVER_ServerUnregisterModelRepository(
server_.get(), repo_parent_path.c_str());

if (unregister_err != nullptr) {
EVBufferAddErrorJson(req->buffer_out, unregister_err);
evhtp_send_reply(req, EVHTP_RES_BADREQ);
LOG_ERROR << "Unable to unregister model repository for path: "
<< repo_parent_path << std::endl;
} else {
evhtp_send_reply(req, EVHTP_RES_OK);
}

TRITONSERVER_ErrorDelete(unregister_err);

std::lock_guard<std::mutex> lock(models_list_mutex_);
sagemaker_models_list_.erase(model_name);
}

void
SagemakerAPIServer::SageMakerMMEGetModel(
evhtp_request_t* req, const char* model_name)
{
std::lock_guard<std::mutex> lock(models_list_mutex_);

if (sagemaker_models_list_.find(model_name) == sagemaker_models_list_.end()) {
evhtp_send_reply(req, EVHTP_RES_NOTFOUND); /* 404*/
@@ -721,7 +837,7 @@ SagemakerAPIServer::SageMakerMMEGetModel(
void
SagemakerAPIServer::SageMakerMMEListModel(evhtp_request_t* req)
{
std::lock_guard<std::mutex> lock(models_list_mutex_);

triton::common::TritonJson::Value sagemaker_list_json(
triton::common::TritonJson::ValueType::OBJECT);
@@ -866,8 +982,8 @@ SagemakerAPIServer::SageMakerMMELoadModel(
if (config_fstream.is_open()) {
ensemble_config_content << config_fstream.rdbuf();
} else {
continue; // A valid config.pbtxt does not exist at this path, or
// cannot be read
}

/* Compare matched string with `platform: "ensemble"` or
@@ -972,7 +1088,7 @@ SagemakerAPIServer::SageMakerMMELoadModel(
} else if (err != nullptr) {
SageMakerMMEHandleOOMError(req, err);
} else {
std::lock_guard<std::mutex> lock(models_list_mutex_);

sagemaker_models_list_.emplace(model_name, repo_parent_path);
evhtp_send_reply(req, EVHTP_RES_OK);
@@ -995,5 +1111,4 @@ SagemakerAPIServer::SageMakerMMELoadModel(

return;
}

}} // namespace triton::server
14 changes: 12 additions & 2 deletions src/sagemaker_server.h
@@ -1,4 +1,4 @@
// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
@@ -26,6 +26,7 @@
#pragma once

#include <sys/stat.h>

#include <fstream>
#include <mutex>

@@ -105,6 +106,9 @@ class SagemakerAPIServer : public HTTPAPIServer {

void SageMakerMMEUnloadModel(evhtp_request_t* req, const char* model_name);

TRITONSERVER_Error* SageMakerMMECheckUnloadedModelIsUnavailable(
const char* model_name, bool* is_model_unavailable);

void SageMakerMMEListModel(evhtp_request_t* req);

void SageMakerMMEGetModel(evhtp_request_t* req, const char* model_name);
@@ -155,7 +159,13 @@ class SagemakerAPIServer : public HTTPAPIServer {
std::unordered_map<std::string, std::string> sagemaker_models_list_;

/* Mutex to handle concurrent updates */
std::mutex models_list_mutex_;

/* Constants */
const uint32_t UNLOAD_TIMEOUT_SECS_ = 350;
const uint32_t UNLOAD_SLEEP_MILLISECONDS_ = 500;
const std::string UNLOAD_EXPECTED_STATE_ = "UNAVAILABLE";
const std::string UNLOAD_EXPECTED_REASON_ = "unloaded";
};

}} // namespace triton::server
