Skip to content

Commit

Permalink
Address comments and use chrono seconds, don't repeat error assignment
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhil-sk committed Mar 6, 2023
1 parent b07bff9 commit f1daba0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 28 deletions.
62 changes: 35 additions & 27 deletions src/sagemaker_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -632,8 +632,6 @@ TRITONSERVER_Error*
SagemakerAPIServer::SageMakerMMECheckUnloadedModelIsUnavailable(
const char* model_name, bool* is_model_unavailable)
{
LOG_VERBOSE(1) << "Inside function to check if model is unavailable";

/* Use the RepositoryIndex API to check if the model state has become
UNAVAILABLE i.e. model is no longer in the 'in-the-process-of' being
UNLOADED. Consequently, the reason field should be 'unloaded'.*/
Expand All @@ -651,9 +649,8 @@ SagemakerAPIServer::SageMakerMMECheckUnloadedModelIsUnavailable(
const char* index_buffer;
size_t index_byte_size;

/*Handle the return from this function correctly*/
TRITONSERVER_MessageSerializeToJson(
server_model_index_message_, &index_buffer, &index_byte_size);
RETURN_IF_ERR(TRITONSERVER_MessageSerializeToJson(
server_model_index_message_, &index_buffer, &index_byte_size));

/* Read into json buffer*/
triton::common::TritonJson::Value server_model_index_json;
Expand All @@ -662,9 +659,12 @@ SagemakerAPIServer::SageMakerMMECheckUnloadedModelIsUnavailable(
const char* name;
const char* state;
const char* reason;
const char* version;

size_t name_len;
size_t state_len;
size_t reason_len;
size_t version_len;

for (size_t id = 0; id < server_model_index_json.ArraySize(); ++id) {
triton::common::TritonJson::Value index_json;
Expand All @@ -678,8 +678,17 @@ SagemakerAPIServer::SageMakerMMECheckUnloadedModelIsUnavailable(
if (std::string(state) == UNLOAD_EXPECTED_STATE_) {
RETURN_IF_ERR(
index_json.MemberAsString("reason", &reason, &reason_len));

if (std::string(reason) == UNLOAD_EXPECTED_REASON_) {
*is_model_unavailable = true;

RETURN_IF_ERR(
index_json.MemberAsString("version", &version, &version_len));

LOG_VERBOSE(1) << "Discovered model: " << name
<< ", version: " << version << " in state: " << state
<< "for the reason: " << reason;

break;
}
}
Expand Down Expand Up @@ -726,8 +735,8 @@ SagemakerAPIServer::SageMakerMMEUnloadModel(

/*Note: Model status check is repo-specific and therefore must be run before
* unregistering the repo, else the model information is lost*/
bool* is_model_unavailable = new bool(false);
float unload_time_in_secs = 0;
bool is_model_unavailable = false;
int64_t unload_time_in_secs = 0;

/* Wait for the model to be completely unloaded. SageMaker waits a maximum
of 360 seconds for the UNLOAD request to timeout. Setting a limit of 350
Expand All @@ -737,11 +746,11 @@ SagemakerAPIServer::SageMakerMMEUnloadModel(
LOG_VERBOSE(1) << "Using Model Repository Index during UNLOAD to check for "
"status of model: "
<< model_name;
while (*is_model_unavailable == false &&
while (is_model_unavailable == false &&
unload_time_in_secs < UNLOAD_TIMEOUT_SECS_) {
LOG_VERBOSE(1) << "In the loop to wait for model to be unavailable";
unload_err = SageMakerMMECheckUnloadedModelIsUnavailable(
model_name, is_model_unavailable);
model_name, &is_model_unavailable);
if (unload_err != nullptr) {
LOG_ERROR << "Error: Received non-zero exit code on checking for "
"model unavailability. "
Expand All @@ -753,33 +762,31 @@ SagemakerAPIServer::SageMakerMMEUnloadModel(

auto end_time = std::chrono::high_resolution_clock::now();

unload_time_in_secs =
std::chrono::duration_cast<std::chrono::milliseconds>(
end_time - start_time)
.count() /
1000.0;
unload_time_in_secs = std::chrono::duration_cast<std::chrono::seconds>(
end_time - start_time)
.count();
}
LOG_INFO << "UNLOAD for model " << model_name << " took "
LOG_INFO << "UNLOAD for model " << model_name << " completed in "
<< unload_time_in_secs << " seconds.";
TRITONSERVER_ErrorDelete(unload_err);
}

if (*is_model_unavailable == false &&
if (is_model_unavailable == false &&
unload_time_in_secs >= UNLOAD_TIMEOUT_SECS_) {
LOG_ERROR << "Error: UNLOAD did not complete within expected "
<< UNLOAD_TIMEOUT_SECS_
<< " seconds. This may "
"result in SageMaker UNLOAD timeout.";
}

/* Delete allocate mem for bool */
delete is_model_unavailable;

std::string repo_parent_path = sagemaker_models_list_.at(model_name);

unload_err = TRITONSERVER_ServerUnregisterModelRepository(
TRITONSERVER_Error* unregister_err = nullptr;

unregister_err = TRITONSERVER_ServerUnregisterModelRepository(
server_.get(), repo_parent_path.c_str());

if (unload_err != nullptr) {
if (unregister_err != nullptr) {
EVBufferAddErrorJson(req->buffer_out, unload_err);
evhtp_send_reply(req, EVHTP_RES_BADREQ);
LOG_ERROR << "Unable to unregister model repository for path: "
Expand All @@ -788,16 +795,17 @@ SagemakerAPIServer::SageMakerMMEUnloadModel(
evhtp_send_reply(req, EVHTP_RES_OK);
}

TRITONSERVER_ErrorDelete(unload_err);
std::lock_guard<std::mutex> lock(mutex_);
TRITONSERVER_ErrorDelete(unregister_err);

std::lock_guard<std::mutex> lock(models_list_mutex_);
sagemaker_models_list_.erase(model_name);
}

void
SagemakerAPIServer::SageMakerMMEGetModel(
evhtp_request_t* req, const char* model_name)
{
std::lock_guard<std::mutex> lock(mutex_);
std::lock_guard<std::mutex> lock(models_list_mutex_);

if (sagemaker_models_list_.find(model_name) == sagemaker_models_list_.end()) {
evhtp_send_reply(req, EVHTP_RES_NOTFOUND); /* 404*/
Expand Down Expand Up @@ -828,7 +836,7 @@ SagemakerAPIServer::SageMakerMMEGetModel(
void
SagemakerAPIServer::SageMakerMMEListModel(evhtp_request_t* req)
{
std::lock_guard<std::mutex> lock(mutex_);
std::lock_guard<std::mutex> lock(models_list_mutex_);

triton::common::TritonJson::Value sagemaker_list_json(
triton::common::TritonJson::ValueType::OBJECT);
Expand Down Expand Up @@ -973,7 +981,7 @@ SagemakerAPIServer::SageMakerMMELoadModel(
if (config_fstream.is_open()) {
ensemble_config_content << config_fstream.rdbuf();
} else {
continue; // A valid config.pbtxt does not exit at this path, or
continue; // A valid config.pbtxt does not exist at this path, or
// cannot be read
}

Expand Down Expand Up @@ -1079,7 +1087,7 @@ SagemakerAPIServer::SageMakerMMELoadModel(
} else if (err != nullptr) {
SageMakerMMEHandleOOMError(req, err);
} else {
std::lock_guard<std::mutex> lock(mutex_);
std::lock_guard<std::mutex> lock(models_list_mutex_);

sagemaker_models_list_.emplace(model_name, repo_parent_path);
evhtp_send_reply(req, EVHTP_RES_OK);
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class SagemakerAPIServer : public HTTPAPIServer {
std::unordered_map<std::string, std::string> sagemaker_models_list_;

/* Mutex to handle concurrent updates */
std::mutex mutex_;
std::mutex models_list_mutex_;

/* Constants */
const uint32_t UNLOAD_TIMEOUT_SECS_ = 350;
Expand Down

0 comments on commit f1daba0

Please sign in to comment.