Add timed wait during UNLOAD while the model becomes UNAVAILABLE in SageMaker #5423

Merged

Conversation

nikhil-sk
Contributor

This PR updates the behavior of the SageMaker UNLOAD function to use the repository index and verify, to the best of Triton's ability, that the model has been completely UNLOADED, i.e. the model state is UNAVAILABLE and the reason is "unloaded".

This (at least in the case of the python backend) ensures that the function returns only after the associated python workers for the model have been killed.

while (!is_model_unavailable && unload_time_in_secs < UNLOAD_TIMEOUT_SECS_) {
  is_model_unavailable = SageMakerMMEUnloadModelCheckStatus(model_name);
  sleep(1);
  unload_time_in_secs += 1;
}


I believe this does not account for the time taken by the SageMakerMMEUnloadModelCheckStatus(model_name) call itself. Is it small enough to be neglected?

Contributor Author

That's correct. In SM-Triton there's one model repo per model, so the model repository index is expected to contain only one model (except in the case of ensembles, where it may be roughly 5-8), so the call is expected to return within a few milliseconds.


IMHO, we can measure the time for that as well and report a more accurate number. It may not matter much in normal cases, but if something weird happens in that function, we will be logging an incorrect unload time.


Related issue: sleep is not exact, either.

Rather than doing unload_time_in_secs += 1, we should measure the elapsed time since the start of the loop after each sleep. That would solve both problems.
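
A minimal sketch of that suggestion, reusing the names from the snippet above (the clock choice and exact structure are illustrative, not necessarily the PR's final code):

#include <chrono>
#include <thread>

// Re-derive the elapsed time from the clock after every sleep instead of
// accumulating a counter, so neither the status-check cost nor sleep jitter
// skews the reported unload time.
const auto start_time = std::chrono::steady_clock::now();
uint32_t unload_time_in_secs = 0;
while (!is_model_unavailable && unload_time_in_secs < UNLOAD_TIMEOUT_SECS_) {
  is_model_unavailable = SageMakerMMEUnloadModelCheckStatus(model_name);
  std::this_thread::sleep_for(std::chrono::seconds(1));
  unload_time_in_secs = std::chrono::duration_cast<std::chrono::seconds>(
                            std::chrono::steady_clock::now() - start_time)
                            .count();
}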

Contributor Author

Thank you for the comments, I've addressed this by using an elapsed-time calculation.

Contributor

so the model repository index is expected to contain only one model.

Note that the index API polls all registered model repos, so the cost is actually proportional to the number of "visible" models, although that is probably still negligible with respect to the unload time.

@@ -628,6 +628,68 @@ SagemakerAPIServer::SageMakerMMEHandleInfer(
}
}

bool
SagemakerAPIServer::SageMakerMMEUnloadModelCheckStatus(const char* model_name)


I don't love the name if it just returns a bool. Hard to tell without looking at the code what the bool represents. Does true mean it is unloaded? Does true mean the model is available, so not yet unloaded? I have no idea since I have not read the implementation code yet.

Contributor Author

Makes sense, I have changed the function name and am no longer using a bool return type.
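
For reference, a sketch of what the renamed helper's signature could look like, inferred from the call site quoted later in this thread (the exact declaration may differ):

// Reports failures through the returned error object and writes the unload
// status to an out-parameter instead of returning a bare bool.
TRITONSERVER_Error* SageMakerMMECheckUnloadedModelIsUnavailable(
    const char* model_name, bool* is_model_unavailable);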

server_model_index_json.IndexAsObject(id, &index_json);

index_json.MemberAsString("name", &name, &name_len);
index_json.MemberAsString("version", &version, &version_len);


Do you need to match on a particular version? Or with sagemaker API, is there only ever one version for a given model name?

Contributor Author

That's correct, we expect only one version for a given model name.

index_json.MemberAsString("name", &name, &name_len);
index_json.MemberAsString("version", &version, &version_len);
index_json.MemberAsString("state", &state, &state_len);
index_json.MemberAsString("reason", &reason, &reason_len);


Are we guaranteed that these calls will succeed? We're ignoring a return value that could be success or error, according to https://github.com/triton-inference-server/common/blob/main/include/triton/common/triton_json.h#L799-L840

Contributor Author

Good point - I've updated the return type to an error type instead, which will be logged in case of a failure to parse.
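
A sketch of what checking those return values could look like, assuming the JSON accessors return a TRITONSERVER_Error* (as triton_json.h can be configured to do) and that the first failure is propagated to the caller for logging:

// Stop at the first parse failure instead of silently ignoring it.
TRITONSERVER_Error* err = index_json.MemberAsString("name", &name, &name_len);
if (err == nullptr) {
  err = index_json.MemberAsString("version", &version, &version_len);
}
if (err == nullptr) {
  err = index_json.MemberAsString("state", &state, &state_len);
}
if (err == nullptr) {
  err = index_json.MemberAsString("reason", &reason, &reason_len);
}
if (err != nullptr) {
  return err;  // the caller logs the failure
}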

for Triton unload.*/
while (!is_model_unavailable && unload_time_in_secs < UNLOAD_TIMEOUT_SECS_) {
is_model_unavailable = SageMakerMMEUnloadModelCheckStatus(model_name);
sleep(1);


The sleep amount could also be a constant, like UNLOAD_TIMEOUT_SECS_ is. You could also use std::this_thread::sleep_for(std::chrono::milliseconds(UNLOAD_SLEEP_MILLISECONDS_)); instead of sleep to get more granularity, if desired (especially if you frequently need only one sleep but the unload usually completes much more quickly than one second).

Contributor Author

Thanks, used std::this_thread::sleep_for(std::chrono::milliseconds(UNLOAD_SLEEP_MILLISECONDS_)) to sleep for a shorter duration.

for Triton unload.*/
while (!is_model_unavailable && unload_time_in_secs < UNLOAD_TIMEOUT_SECS_) {
is_model_unavailable = SageMakerMMEUnloadModelCheckStatus(model_name);
sleep(1);


Is it ok to block the thread with sleep like this? Or do we need to do some kind of shenanigans with event loop to check again after a timeout without blocking the thread?

Contributor Author

Currently, this seems to be the straightforward way to block the evhtp request thread. Since SM unload requests can wait a maximum of 350 seconds, I believe this is alright. In testing, we don't see other model load/unload requests being affected by this sleep.

}

std::lock_guard<std::mutex> lock(mutex_);


Just a comment, but it can be nice to name a mutex based on what it is intended to protect. models_list_mutex_ would make it clearer what it is needed for.

Contributor Author

Thanks, done

@@ -725,8 +734,8 @@ SagemakerAPIServer::SageMakerMMEUnloadModel(

/*Note: Model status check is repo-specific and therefore must be run before
* unregistering the repo, else the model information is lost*/
bool is_model_unavailable = false;
uint32_t unload_time_in_secs = 0;
bool* is_model_unavailable = new bool(false);


There's no need to use new here. Instead, this line could just be

bool is_model_unavailable = false;

and below, you would call

unload_err = SageMakerMMECheckUnloadedModelIsUnavailable(
          model_name, &is_model_unavailable);

Note that we passed the address of is_model_unavailable. So, we're passing a pointer to the variable on the stack. This is fine, as long as the pointer does not live past the end of the function that declared the variable.

Contributor Author

Thanks, done

Comment on lines 765 to 769
unload_time_in_secs =
std::chrono::duration_cast<std::chrono::milliseconds>(
end_time - start_time)
.count() /
1000.0;


Here, you should use duration_cast to seconds, and not explicitly divide by 1000.0:

unload_time_in_secs =
          std::chrono::duration_cast<std::chrono::seconds>(
              end_time - start_time).count();

Note that in both cases (existing code and my suggested version), it is truncating rather than rounding. For this use case, that is what is desired, since we ultimately want to know once we've passed the timeout threshold.

Yet another way to write this would be to store the timeout and elapsed-time variables as std::chrono::duration variables rather than int32_t variables, which would then automatically do the duration_cast for you when it's not lossy. See https://en.cppreference.com/w/cpp/chrono/duration/duration_cast. Up to personal taste, though stronger typing can be helpful for avoiding bugs with unit conversions.
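
A sketch of that strongly-typed variant; the 350-second figure mirrors the SageMaker limit mentioned earlier in this thread, and the constant names are illustrative rather than the PR's:

#include <chrono>

// Keeping both values as std::chrono::seconds makes the subtraction and the
// comparison unit-safe; duration_cast still truncates, which is what we want
// when checking against the timeout.
constexpr std::chrono::seconds kUnloadTimeout{350};
const auto start_time = std::chrono::steady_clock::now();
// ... issue the unload and poll the repository index ...
const auto unload_time = std::chrono::duration_cast<std::chrono::seconds>(
    std::chrono::steady_clock::now() - start_time);
if (unload_time >= kUnloadTimeout) {
  // timed out waiting for the model to become UNAVAILABLE
}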

Contributor Author

Thanks, done - my original intention with /1000 was to report the approximate wait down to milliseconds, but that is unnecessary, so I've used a cast to seconds. The duration.count() method returns the count as type int64_t. I might consider changing this in a future update; thanks for the suggestion to use the chrono::duration type when declaring the timeout values...

} else {
evhtp_send_reply(req, EVHTP_RES_OK);
}

TRITONSERVER_ErrorDelete(unload_err);


I think this is a segfault waiting to happen, since you're returning nullptr at the end of the TRITONSERVER_ServerUnregisterModelRepository.

TRITONSERVER_ErrorDelete does unconditional delete on the argument: https://github.com/triton-inference-server/core/blob/main/src/tritonserver.cc#L712-L717

Contributor

I think delete on nullptr will not do any harm, but I see that unload_err is reused for multiple API calls and there is a chance of a memory leak, i.e. unload_err may be set on line 743 and the pointer overwritten on line 779.


You're right, sorry. delete on nullptr is actually fine. From https://en.cppreference.com/w/cpp/language/delete:

If expression evaluates to a null pointer value, no destructors are called, and the deallocation function may or may not be called (it's unspecified), but the default deallocation functions are guaranteed to do nothing when passed a null pointer.

Contributor Author

@nikhil-sk nikhil-sk Mar 6, 2023


Thank you, updated the code to not re-use the unload_err pointer.
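
A sketch of the shape of that fix, with each API call owning its own error pointer that is deleted exactly once (repo_path and the surrounding handling are placeholders):

TRITONSERVER_Error* unload_err =
    TRITONSERVER_ServerUnloadModelAndDependents(server_.get(), model_name);
// ... wait for the model to become UNAVAILABLE, log unload_err if set ...
TRITONSERVER_ErrorDelete(unload_err);

TRITONSERVER_Error* unregister_err =
    TRITONSERVER_ServerUnregisterModelRepository(server_.get(), repo_path);
// ... log unregister_err if set ...
TRITONSERVER_ErrorDelete(unregister_err);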

Contributor

@GuanLuo GuanLuo left a comment


minor comment

const char* buffer;
size_t byte_size;
LOG_ERROR
<< "Error when unloading SagMaker Model with dependents for model: "


typo: SagMaker

Contributor Author

Thanks, fixed

TRITONSERVER_ServerModelIndex(
server_.get(), ready_flag, &server_model_index_message_);

std::shared_ptr<TRITONSERVER_Message> managed_msg(


Generally I'd recommend using unique_ptr rather than shared_ptr unless there's a specific reason shared_ptr is needed (and there's not, here).

Contributor Author

My concern was whether this is required because multiple evhtp threads call the unload method. I can check more on this and modify it at a later point in another PR; noted...

@GuanLuo do you have any suggestion?

Contributor

The returned message is going to be unique to the caller, so there is no sharing of ownership here. unique_ptr would be better, except that syntax-wise it is a bit messy to attach a custom deleter (in C++11).
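
For illustration only (the thread leaves the shared_ptr in place for a follow-up PR), attaching a custom deleter to a unique_ptr could look like the sketch below:

// Deleter that releases the index message; the error returned by
// TRITONSERVER_MessageDelete is ignored in this sketch.
auto message_deleter = [](TRITONSERVER_Message* msg) {
  TRITONSERVER_ErrorDelete(TRITONSERVER_MessageDelete(msg));
};
std::unique_ptr<TRITONSERVER_Message, decltype(message_deleter)> managed_msg(
    server_model_index_message_, message_deleter);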

json_buffer_.Clear();
request_parameters.Write(&json_buffer_);
TRITONSERVER_ErrorDelete(unload_err);
}


Isn't it a problem that this block falls through to the rest of the function?

Contributor Author

Thanks, this was addressed as part of the previous update, so now the function should return early.

Contributor

@GuanLuo GuanLuo left a comment

Will kick off a CI for sanity check

@GuanLuo GuanLuo merged commit f1aedd4 into triton-inference-server:main Mar 7, 2023
nikhil-sk added a commit to nikhil-sk/server that referenced this pull request Mar 15, 2023
Add timed wait during UNLOAD while the model becomes UNAVAILABLE in SageMaker (triton-inference-server#5423)

* Add timed wait during UNLOAD while the model becomes UNAVAILABLE in SageMaker

* Directly use C API to UNLOAD model in SM

* Address comments and bug fixes

* Add logging for model server index

* Change MME model repo

* Address comments and use chrono seconds, don't repeat error assignment

* Address minor comments

* Fix typo in log

* Update minor comment