Add timed wait during UNLOAD while the model becomes UNAVAILABLE in SageMaker #5423

Merged
10 changes: 9 additions & 1 deletion docker/sagemaker/serve
@@ -26,13 +26,21 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 SAGEMAKER_SINGLE_MODEL_REPO=/opt/ml/model/
-SAGEMAKER_MULTI_MODEL_REPO=/opt/ml/models/
+
+# Note: in Triton on SageMaker, each model url is registered as a separate repository,
+# e.g., /opt/ml/models/<hash>/model. Specifying the MME model repo path as /opt/ml/models
+# causes Triton to treat it as an additional empty repository and changes the state of all
+# models to UNAVAILABLE in the model repository:
+# https://github.com/triton-inference-server/core/blob/main/src/model_repository_manager.cc#L914,L922
+# On Triton this will be a dummy path, as it is mandatory to specify a model repo when starting Triton.
+SAGEMAKER_MULTI_MODEL_REPO=/tmp/sagemaker
+
 SAGEMAKER_MODEL_REPO=${SAGEMAKER_SINGLE_MODEL_REPO}
 is_mme_mode=false

 if [ -n "$SAGEMAKER_MULTI_MODEL" ]; then
     if [ "$SAGEMAKER_MULTI_MODEL" == "true" ]; then
         mkdir -p ${SAGEMAKER_MULTI_MODEL_REPO}
         SAGEMAKER_MODEL_REPO=${SAGEMAKER_MULTI_MODEL_REPO}
         is_mme_mode=true
         echo "Triton is running in SageMaker MME mode."
176 changes: 145 additions & 31 deletions src/sagemaker_server.cc
@@ -628,69 +628,184 @@ SagemakerAPIServer::SageMakerMMEHandleInfer(
}
}

TRITONSERVER_Error*
SagemakerAPIServer::SageMakerMMECheckUnloadedModelIsUnavailable(
    const char* model_name, bool* is_model_unavailable)
{
  /* Use the RepositoryIndex API to check whether the model state has become
     UNAVAILABLE, i.e. the model is no longer in the process of being
     UNLOADED. In that case, the reason field should read 'unloaded'. */
  TRITONSERVER_Message* server_model_index_message_ = nullptr;

  uint32_t ready_flag = 0;  // A value of 1 would return only the 'ready'
                            // models from the index. In this case, we need
                            // all models.
  TRITONSERVER_ServerModelIndex(
      server_.get(), ready_flag, &server_model_index_message_);

  std::shared_ptr<TRITONSERVER_Message> managed_msg(
      server_model_index_message_,
      [](TRITONSERVER_Message* msg) { TRITONSERVER_MessageDelete(msg); });

Review thread on managed_msg:

    Reviewer: Generally I'd recommend using unique_ptr rather than shared_ptr
    unless there's a specific reason shared_ptr is needed (and there's not,
    here).

    Author (Contributor): My concern was whether this is required due to
    multiple evhtp threads calling on the unload method. I can check more on
    this and modify at a later point in another PR, noted it.
    @GuanLuo do you have any suggestion?

    Contributor: The returned message is going to be unique to the caller, so
    there is no sharing of ownership here. unique_ptr will be better, except
    that syntax-wise it is a bit messy to attach a custom deleter (in C++11).
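A minimal sketch of the unique_ptr variant discussed in the thread above (illustrative only; the merged PR keeps shared_ptr). The deleter type must be spelled out in the template arguments, which is the C++11 syntactic overhead noted above; unique_ptr simply discards TRITONSERVER_MessageDelete's TRITONSERVER_Error* return value:

  // Owns the index message; TRITONSERVER_MessageDelete runs when it goes out
  // of scope.
  std::unique_ptr<TRITONSERVER_Message, decltype(&TRITONSERVER_MessageDelete)>
      managed_msg(server_model_index_message_, TRITONSERVER_MessageDelete);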

const char* index_buffer;
size_t index_byte_size;

RETURN_IF_ERR(TRITONSERVER_MessageSerializeToJson(
server_model_index_message_, &index_buffer, &index_byte_size));

  /* Read the index into a JSON object */
  triton::common::TritonJson::Value server_model_index_json;
  server_model_index_json.Parse(index_buffer, index_byte_size);
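For reference, the repository index serialized above is a JSON array with one entry per model; the loop below matches on name and then inspects state and reason. An illustrative payload for a model that has just finished unloading (name and version are made up):

[
  {
    "name": "resnet_v1",
    "version": "1",
    "state": "UNAVAILABLE",
    "reason": "unloaded"
  }
]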

const char* name;
const char* state;
const char* reason;
const char* version;

size_t name_len;
size_t state_len;
size_t reason_len;
size_t version_len;

for (size_t id = 0; id < server_model_index_json.ArraySize(); ++id) {
triton::common::TritonJson::Value index_json;
server_model_index_json.IndexAsObject(id, &index_json);

RETURN_IF_ERR(index_json.MemberAsString("name", &name, &name_len));

if (std::string(name) == std::string(model_name)) {
RETURN_IF_ERR(index_json.MemberAsString("state", &state, &state_len));

      if (std::string(state) == UNLOAD_EXPECTED_STATE_) {
RETURN_IF_ERR(
index_json.MemberAsString("reason", &reason, &reason_len));

if (std::string(reason) == UNLOAD_EXPECTED_REASON_) {
*is_model_unavailable = true;

RETURN_IF_ERR(
index_json.MemberAsString("version", &version, &version_len));

LOG_VERBOSE(1) << "Discovered model: " << name
<< ", version: " << version << " in state: " << state
<< "for the reason: " << reason;

break;
}
}
}
}

return nullptr;
}

void
SagemakerAPIServer::SageMakerMMEUnloadModel(
evhtp_request_t* req, const char* model_name)
{
-  std::lock_guard<std::mutex> lock(mutex_);

if (sagemaker_models_list_.find(model_name) == sagemaker_models_list_.end()) {
LOG_VERBOSE(1) << "Model " << model_name << " is not loaded." << std::endl;
evhtp_send_reply(req, EVHTP_RES_NOTFOUND); /* 404*/
return;
}

+  /* Extract targetModel to log the associated archive */
+  const char* targetModel =
+      evhtp_kv_find(req->headers_in, "X-Amzn-SageMaker-Target-Model");
+
+  LOG_INFO << "Unloading SageMaker TargetModel: " << targetModel << std::endl;
+
+  auto start_time = std::chrono::high_resolution_clock::now();
+
   /* Always unload dependents as well - this is required to unload dependents in
    * ensemble */
-  triton::common::TritonJson::Value request_parameters(
-      triton::common::TritonJson::ValueType::OBJECT);
-  triton::common::TritonJson::Value unload_parameter(
-      request_parameters, triton::common::TritonJson::ValueType::OBJECT);
-  unload_parameter.AddBool("unload_dependents", true);
-  request_parameters.Add("parameters", std::move(unload_parameter));
-
-  const char* buffer;
-  size_t byte_size;
-
-  triton::common::TritonJson::WriteBuffer json_buffer_;
-  json_buffer_.Clear();
-  request_parameters.Write(&json_buffer_);
-
-  byte_size = json_buffer_.Size();
-  buffer = json_buffer_.Base();
-
-  evbuffer_add(req->buffer_in, buffer, byte_size);
-
-  /* Extract targetModel to log the associated archive */
-  const char* targetModel =
-      evhtp_kv_find(req->headers_in, "X-Amzn-SageMaker-Target-Model");
-
-  LOG_INFO << "Unloading SageMaker TargetModel: " << targetModel << std::endl;
-
-  HandleRepositoryControl(req, "", model_name, "unload");
+  TRITONSERVER_Error* unload_err = nullptr;
+  unload_err =
+      TRITONSERVER_ServerUnloadModelAndDependents(server_.get(), model_name);
+  if (unload_err != nullptr) {
+    EVBufferAddErrorJson(req->buffer_out, unload_err);
+    evhtp_send_reply(req, EVHTP_RES_BADREQ);
+
+    LOG_ERROR
+        << "Error when unloading SageMaker Model with dependents for model: "
+        << model_name << std::endl;
+
+    TRITONSERVER_ErrorDelete(unload_err);
+    return;
+  }

Review thread on the LOG_ERROR message:

    Reviewer: typo: SagMaker
    Author (Contributor): Thanks, fixed

Review thread on the error block:

    Reviewer: Isn't it a problem that this block falls through to the rest of
    the function?
    Author (Contributor): Thanks, this was addressed as part of the previous
    update, so now the function should return early.

+  /* Note: the model status check is repo-specific and therefore must be run
+   * before unregistering the repo, else the model information is lost */
+  bool is_model_unavailable = false;
+  int64_t unload_time_in_secs = 0;
+
+  /* Wait for the model to be completely unloaded. SageMaker waits a maximum
+     of 360 seconds for the UNLOAD request to time out. Set a limit of 350
+     seconds for the Triton unload; this runs only if the UNLOAD call above
+     has succeeded. */
+  if (unload_err == nullptr) {
+    LOG_VERBOSE(1) << "Using Model Repository Index during UNLOAD to check "
+                      "for status of model: "
+                   << model_name;
+    while (is_model_unavailable == false &&
+           unload_time_in_secs < UNLOAD_TIMEOUT_SECS_) {
+      LOG_VERBOSE(1) << "In the loop to wait for model to be unavailable";
+      unload_err = SageMakerMMECheckUnloadedModelIsUnavailable(
+          model_name, &is_model_unavailable);
+      if (unload_err != nullptr) {
+        LOG_ERROR << "Error when checking for model unavailability: "
+                  << TRITONSERVER_ErrorMessage(unload_err);
+        break;
+      }
+      std::this_thread::sleep_for(
+          std::chrono::milliseconds(UNLOAD_SLEEP_MILLISECONDS_));
+
+      auto end_time = std::chrono::high_resolution_clock::now();
+
+      unload_time_in_secs = std::chrono::duration_cast<std::chrono::seconds>(
+                                end_time - start_time)
+                                .count();
+    }
+    LOG_INFO << "UNLOAD for model " << model_name << " completed in "
+             << unload_time_in_secs << " seconds.";
+    TRITONSERVER_ErrorDelete(unload_err);
+  }
+
+  if (is_model_unavailable == false &&
+      unload_time_in_secs >= UNLOAD_TIMEOUT_SECS_) {
+    LOG_ERROR << "Error: UNLOAD did not complete within the expected "
+              << UNLOAD_TIMEOUT_SECS_
+              << " seconds. This may result in a SageMaker UNLOAD timeout.";
+  }

   std::string repo_parent_path = sagemaker_models_list_.at(model_name);

-  TRITONSERVER_Error* unload_err = TRITONSERVER_ServerUnregisterModelRepository(
+  TRITONSERVER_Error* unregister_err = nullptr;
+
+  unregister_err = TRITONSERVER_ServerUnregisterModelRepository(
       server_.get(), repo_parent_path.c_str());

-  if (unload_err != nullptr) {
+  if (unregister_err != nullptr) {
     EVBufferAddErrorJson(req->buffer_out, unregister_err);
     evhtp_send_reply(req, EVHTP_RES_BADREQ);
     LOG_ERROR << "Unable to unregister model repository for path: "
               << repo_parent_path << std::endl;
-    TRITONSERVER_ErrorDelete(unload_err);
+  } else {
+    evhtp_send_reply(req, EVHTP_RES_OK);
   }

+  TRITONSERVER_ErrorDelete(unregister_err);
+
+  std::lock_guard<std::mutex> lock(models_list_mutex_);
   sagemaker_models_list_.erase(model_name);
 }
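With this change, the DELETE handler now blocks until the model is reported UNAVAILABLE or the 350-second budget is exhausted. For local testing, the path can be exercised roughly as follows, assuming the SageMaker frontend is listening on its default port 8080 (model name is illustrative):

curl -X DELETE http://localhost:8080/models/resnet_v1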

void
SagemakerAPIServer::SageMakerMMEGetModel(
evhtp_request_t* req, const char* model_name)
{
-  std::lock_guard<std::mutex> lock(mutex_);
+  std::lock_guard<std::mutex> lock(models_list_mutex_);

if (sagemaker_models_list_.find(model_name) == sagemaker_models_list_.end()) {
evhtp_send_reply(req, EVHTP_RES_NOTFOUND); /* 404*/
@@ -721,7 +836,7 @@ SagemakerAPIServer::SageMakerMMEGetModel(
void
SagemakerAPIServer::SageMakerMMEListModel(evhtp_request_t* req)
{
-  std::lock_guard<std::mutex> lock(mutex_);
+  std::lock_guard<std::mutex> lock(models_list_mutex_);

triton::common::TritonJson::Value sagemaker_list_json(
triton::common::TritonJson::ValueType::OBJECT);
@@ -866,8 +981,8 @@ SagemakerAPIServer::SageMakerMMELoadModel(
   if (config_fstream.is_open()) {
     ensemble_config_content << config_fstream.rdbuf();
   } else {
-    continue;  // A valid config.pbtxt does not exit at this path, or cannot
-               // be read
+    continue;  // A valid config.pbtxt does not exist at this path, or
+               // cannot be read
   }

/* Compare matched string with `platform: "ensemble"` or
Expand Down Expand Up @@ -972,7 +1087,7 @@ SagemakerAPIServer::SageMakerMMELoadModel(
} else if (err != nullptr) {
SageMakerMMEHandleOOMError(req, err);
} else {
-    std::lock_guard<std::mutex> lock(mutex_);
+    std::lock_guard<std::mutex> lock(models_list_mutex_);

sagemaker_models_list_.emplace(model_name, repo_parent_path);
evhtp_send_reply(req, EVHTP_RES_OK);
Expand All @@ -995,5 +1110,4 @@ SagemakerAPIServer::SageMakerMMELoadModel(

return;
}

}} // namespace triton::server
14 changes: 12 additions & 2 deletions src/sagemaker_server.h
@@ -1,4 +1,4 @@
-// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
@@ -26,6 +26,7 @@
#pragma once

 #include <sys/stat.h>
+
 #include <fstream>
#include <mutex>

@@ -105,6 +106,9 @@ class SagemakerAPIServer : public HTTPAPIServer {

void SageMakerMMEUnloadModel(evhtp_request_t* req, const char* model_name);

+  TRITONSERVER_Error* SageMakerMMECheckUnloadedModelIsUnavailable(
+      const char* model_name, bool* is_model_unavailable);
+
void SageMakerMMEListModel(evhtp_request_t* req);

void SageMakerMMEGetModel(evhtp_request_t* req, const char* model_name);
@@ -155,7 +159,13 @@ class SagemakerAPIServer : public HTTPAPIServer {
std::unordered_map<std::string, std::string> sagemaker_models_list_;

   /* Mutex to handle concurrent updates */
-  std::mutex mutex_;
+  std::mutex models_list_mutex_;
+
+  /* Constants */
+  const uint32_t UNLOAD_TIMEOUT_SECS_ = 350;
+  const uint32_t UNLOAD_SLEEP_MILLISECONDS_ = 500;
+  const std::string UNLOAD_EXPECTED_STATE_ = "UNAVAILABLE";
+  const std::string UNLOAD_EXPECTED_REASON_ = "unloaded";
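At this cadence the wait loop makes at most about 700 repository-index checks (350 s at one check per 500 ms) before giving up, which keeps the UNLOAD handler inside SageMaker's 360-second limit with some headroom.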
};

}} // namespace triton::server