Skip to content

fix: MessageQueueShm head index boundary check #405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ set(
src/response_sender.h
src/pb_stub.h
src/pb_stub.cc
src/pb_stub_log.h
src/pb_stub_log.cc
src/pb_response_iterator.h
src/pb_response_iterator.cc
src/pb_cancel.cc
Expand Down
47 changes: 42 additions & 5 deletions src/message_queue.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -32,14 +32,19 @@
#include <boost/thread/thread_time.hpp>
#include <cstddef>

#include "pb_utils.h"
#include "shm_manager.h"
#ifdef TRITON_PB_STUB
#include "pb_stub_log.h"
#endif

namespace triton { namespace backend { namespace python {
namespace bi = boost::interprocess;

/// Struct holding the representation of a message queue inside the shared
/// memory.
/// \param size Total size of the message queue.
/// \param size Total size of the message queue. Considered invalid after
/// MessageQueue::LoadFromSharedMemory. Check DLIS-8378 for additional details.
/// \param mutex Handle of the mutex variable protecting index.
/// \param index Used element index.
/// \param sem_empty Semaphore object counting the number of empty buffer slots.
Expand Down Expand Up @@ -110,7 +115,22 @@ class MessageQueue {

{
bi::scoped_lock<bi::interprocess_mutex> lock{*MutexMutable()};
Buffer()[Head()] = message;
int head_idx = Head();
// Additional check to avoid out of bounds read/write. Check DLIS-8378 for
// additional details.
if (head_idx < 0 || static_cast<uint32_t>(head_idx) >= Size()) {
std::string error_msg =
"internal error: message queue head index out of bounds. Expects "
"positive integer less than the size of message queue " +
std::to_string(Size()) + " but got " + std::to_string(head_idx);
#ifdef TRITON_PB_STUB
LOG_ERROR << error_msg;
#else
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, error_msg.c_str());
#endif
return;
}
Buffer()[head_idx] = message;
HeadIncrement();
}
SemFullMutable()->post();
Expand Down Expand Up @@ -145,7 +165,22 @@ class MessageQueue {
}
success = true;

Buffer()[Head()] = message;
int head_idx = Head();
// Additional check to avoid out of bounds read/write. Check DLIS-8378 for
// additional details.
if (head_idx < 0 || static_cast<uint32_t>(head_idx) >= Size()) {
std::string error_msg =
"internal error: message queue head index out of bounds. Expects "
"positive integer less than the size of message queue " +
std::to_string(Size()) + " but got " + std::to_string(head_idx);
#ifdef TRITON_PB_STUB
LOG_ERROR << error_msg;
#else
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, error_msg.c_str());
#endif
return;
}
Buffer()[head_idx] = message;
HeadIncrement();
}
SemFullMutable()->post();
Expand Down Expand Up @@ -244,7 +279,7 @@ class MessageQueue {
}

private:
std::size_t& Size() { return mq_shm_ptr_->size; }
uint32_t Size() { return size_; }
const bi::interprocess_mutex& Mutex() { return mq_shm_ptr_->mutex; }
bi::interprocess_mutex* MutexMutable() { return &(mq_shm_ptr_->mutex); }
int& Head() { return mq_shm_ptr_->head; }
Expand Down Expand Up @@ -273,6 +308,7 @@ class MessageQueue {
MessageQueueShm* mq_shm_ptr_;
T* mq_buffer_shm_ptr_;
bi::managed_external_buffer::handle_t mq_handle_;
uint32_t size_;

/// Create/load a Message queue.
/// \param mq_shm Message queue representation in shared memory.
Expand All @@ -284,6 +320,7 @@ class MessageQueue {
mq_buffer_shm_ptr_ = mq_buffer_shm_.data_.get();
mq_shm_ptr_ = mq_shm_.data_.get();
mq_handle_ = mq_shm_.handle_;
size_ = mq_shm_ptr_->size;
}
};
}}} // namespace triton::backend::python
1 change: 1 addition & 0 deletions src/pb_bls_cancel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "pb_bls_cancel.h"

#include "pb_stub.h"
#include "pb_stub_log.h"

namespace triton { namespace backend { namespace python {

Expand Down
3 changes: 2 additions & 1 deletion src/pb_cancel.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -27,6 +27,7 @@
#include "pb_cancel.h"

#include "pb_stub.h"
#include "pb_stub_log.h"

namespace triton { namespace backend { namespace python {

Expand Down
133 changes: 1 addition & 132 deletions src/pb_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
#include "pb_preferred_memory.h"
#include "pb_response_iterator.h"
#include "pb_string.h"
#include "pb_stub_log.h"
#include "pb_utils.h"
#include "response_sender.h"
#include "scoped_defer.h"
Expand Down Expand Up @@ -1569,138 +1570,6 @@ Stub::ProcessBLSResponseDecoupled(std::unique_ptr<IPCMessage>& ipc_message)
}
}

std::unique_ptr<Logger> Logger::log_instance_;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved log related code to new files to break circular dependency.


std::unique_ptr<Logger>&
Logger::GetOrCreateInstance()
{
if (Logger::log_instance_.get() == nullptr) {
Logger::log_instance_ = std::make_unique<Logger>();
}

return Logger::log_instance_;
}

// Bound function, called from the python client
void
Logger::Log(const std::string& message, LogLevel level)
{
std::unique_ptr<Stub>& stub = Stub::GetOrCreateInstance();
py::object frame = py::module_::import("inspect").attr("currentframe");
py::object caller_frame = frame();
py::object info = py::module_::import("inspect").attr("getframeinfo");
py::object caller_info = info(caller_frame);
py::object filename_python = caller_info.attr("filename");
std::string filename = filename_python.cast<std::string>();
py::object lineno = caller_info.attr("lineno");
uint32_t line = lineno.cast<uint32_t>();

if (!stub->StubToParentServiceActive()) {
Logger::GetOrCreateInstance()->Log(filename, line, level, message);
} else {
std::unique_ptr<PbLog> log_msg(new PbLog(filename, line, message, level));
stub->EnqueueLogRequest(log_msg);
}
}

// Called internally (.e.g. LOG_ERROR << "Error"; )
void
Logger::Log(
const std::string& filename, uint32_t lineno, LogLevel level,
const std::string& message)
{
// If the log monitor service is not active yet, format
// and pass messages to cerr
if (!BackendLoggingActive()) {
std::string path(filename);
size_t pos = path.rfind(std::filesystem::path::preferred_separator);
if (pos != std::string::npos) {
path = path.substr(pos + 1, std::string::npos);
}
#ifdef _WIN32
std::stringstream ss;
SYSTEMTIME system_time;
GetSystemTime(&system_time);
ss << LeadingLogChar(level) << std::setfill('0') << std::setw(2)
<< system_time.wMonth << std::setw(2) << system_time.wDay << ' '
<< std::setw(2) << system_time.wHour << ':' << std::setw(2)
<< system_time.wMinute << ':' << std::setw(2) << system_time.wSecond
<< '.' << std::setw(6) << system_time.wMilliseconds * 1000 << ' '
<< static_cast<uint32_t>(GetCurrentProcessId()) << ' ' << path << ':'
<< lineno << "] ";
#else
std::stringstream ss;
struct timeval tv;
gettimeofday(&tv, NULL);
struct tm tm_time;
gmtime_r(((time_t*)&(tv.tv_sec)), &tm_time);
ss << LeadingLogChar(level) << std::setfill('0') << std::setw(2)
<< (tm_time.tm_mon + 1) << std::setw(2) << tm_time.tm_mday << " "
<< std::setw(2) << tm_time.tm_hour << ':' << std::setw(2)
<< tm_time.tm_min << ':' << std::setw(2) << tm_time.tm_sec << "."
<< std::setw(6) << tv.tv_usec << ' ' << static_cast<uint32_t>(getpid())
<< ' ' << path << ':' << lineno << "] ";
std::cerr << ss.str() << " " << message << std::endl;
#endif
} else {
// Ensure we do not create a stub instance before it has initialized
std::unique_ptr<Stub>& stub = Stub::GetOrCreateInstance();
std::unique_ptr<PbLog> log_msg(new PbLog(filename, lineno, message, level));
stub->EnqueueLogRequest(log_msg);
}
}

void
Logger::LogInfo(const std::string& message)
{
Logger::Log(message, LogLevel::kInfo);
}

void
Logger::LogWarn(const std::string& message)
{
Logger::Log(message, LogLevel::kWarning);
}

void
Logger::LogError(const std::string& message)
{
Logger::Log(message, LogLevel::kError);
}

void
Logger::LogVerbose(const std::string& message)
{
Logger::Log(message, LogLevel::kVerbose);
}

const std::string
Logger::LeadingLogChar(const LogLevel& level)
{
switch (level) {
case LogLevel::kWarning:
return "W";
case LogLevel::kError:
return "E";
case LogLevel::kInfo:
case LogLevel::kVerbose:
default:
return "I";
}
}

void
Logger::SetBackendLoggingActive(bool status)
{
backend_logging_active_ = status;
}

bool
Logger::BackendLoggingActive()
{
return backend_logging_active_;
}

PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
{
py::class_<PbError, std::shared_ptr<PbError>> triton_error(
Expand Down
Loading
Loading