Skip to content

fix: MessageQueueShm head index boundary check #405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 7, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Compare index with size set at MessageQueueu constructor.
  • Loading branch information
yinggeh committed Jul 2, 2025
commit fd994ab9af6a4e84826b5c300e379644c20269f0
27 changes: 17 additions & 10 deletions src/message_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ namespace bi = boost::interprocess;

/// Struct holding the representation of a message queue inside the shared
/// memory.
/// \param size Total size of the message queue.
/// \param size Total size of the message queue. Considered invalid after
/// MessageQueue::LoadFromSharedMemory. Check DLIS-8378 for additional details.
/// \param mutex Handle of the mutex variable protecting index.
/// \param index Used element index.
/// \param sem_empty Semaphore object counting the number of empty buffer slots.
Expand Down Expand Up @@ -118,13 +119,15 @@ class MessageQueue {
int head_idx = Head();
// Additional check to avoid out of bounds read/write. Check DLIS-8378 for
// additional details.
if (head_idx < 0 || static_cast<size_t>(head_idx) >= Size()) {
constexpr const char* error_msg =
"Message queue head index out of bounds";
if (head_idx < 0 || static_cast<uint32_t>(head_idx) >= Size()) {
std::string error_msg =
"internal error: message queue head index out of bounds. Expects "
"positive integer less than the size of message queue " +
std::to_string(Size()) + " but got " + std::to_string(head_idx);
#ifdef TRITON_PB_STUB
LOG_ERROR << error_msg;
#else
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, error_msg);
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, error_msg.c_str());
#endif
return;
}
Expand Down Expand Up @@ -166,13 +169,15 @@ class MessageQueue {
int head_idx = Head();
// Additional check to avoid out of bounds read/write. Check DLIS-8378 for
// additional details.
if (head_idx < 0 || static_cast<size_t>(head_idx) >= Size()) {
constexpr const char* error_msg =
"Message queue head index out of bounds";
if (head_idx < 0 || static_cast<uint32_t>(head_idx) >= Size()) {
std::string error_msg =
"internal error: message queue head index out of bounds. Expects "
"positive integer less than the size of message queue " +
std::to_string(Size()) + " but got " + std::to_string(head_idx);
#ifdef TRITON_PB_STUB
LOG_ERROR << error_msg;
#else
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, error_msg);
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, error_msg.c_str());
#endif
return;
}
Expand Down Expand Up @@ -275,7 +280,7 @@ class MessageQueue {
}

private:
std::size_t& Size() { return mq_shm_ptr_->size; }
uint32_t Size() { return size_; }
const bi::interprocess_mutex& Mutex() { return mq_shm_ptr_->mutex; }
bi::interprocess_mutex* MutexMutable() { return &(mq_shm_ptr_->mutex); }
int& Head() { return mq_shm_ptr_->head; }
Expand Down Expand Up @@ -304,6 +309,7 @@ class MessageQueue {
MessageQueueShm* mq_shm_ptr_;
T* mq_buffer_shm_ptr_;
bi::managed_external_buffer::handle_t mq_handle_;
uint32_t size_;

/// Create/load a Message queue.
/// \param mq_shm Message queue representation in shared memory.
Expand All @@ -315,6 +321,7 @@ class MessageQueue {
mq_buffer_shm_ptr_ = mq_buffer_shm_.data_.get();
mq_shm_ptr_ = mq_shm_.data_.get();
mq_handle_ = mq_shm_.handle_;
size_ = mq_shm_ptr_->size;
}
};
}}} // namespace triton::backend::python
2 changes: 1 addition & 1 deletion src/pb_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ namespace bi = boost::interprocess;
TRITONSERVER_ErrorMessage(pb2_exception.what())); \
} \
} \
while (false)
} while (false)

#define THROW_IF_TRITON_ERROR(X) \
do { \
Expand Down
Loading