diff --git a/src/shared_modules/utils/tests/threadEventDispatcher_test.cpp b/src/shared_modules/utils/tests/threadEventDispatcher_test.cpp index 3cb84d56b6a..1975c76bd8f 100644 --- a/src/shared_modules/utils/tests/threadEventDispatcher_test.cpp +++ b/src/shared_modules/utils/tests/threadEventDispatcher_test.cpp @@ -21,7 +21,7 @@ void ThreadEventDispatcherTest::TearDown() { }; constexpr auto BULK_SIZE {50}; -TEST_F(ThreadEventDispatcherTest, Ctor) +TEST_F(ThreadEventDispatcherTest, ConstructorTestSingleThread) { static const std::vector MESSAGES_TO_SEND_LIST {120, 100}; @@ -60,7 +60,47 @@ TEST_F(ThreadEventDispatcherTest, Ctor) } } -TEST_F(ThreadEventDispatcherTest, CtorNoWorker) +TEST_F(ThreadEventDispatcherTest, ConstructorTestMultiThread) +{ + static const std::vector MESSAGES_TO_SEND_LIST {120, 100}; + static const auto NUM_THREADS = 4; + + for (auto MESSAGES_TO_SEND : MESSAGES_TO_SEND_LIST) + { + std::atomic counter {0}; + std::promise promise; + auto index {0}; + + ThreadEventDispatcher&)>, NUM_THREADS> dispatcher( + [&counter, &index, &MESSAGES_TO_SEND, &promise](std::queue& data) + { + counter += data.size(); + while (!data.empty()) + { + auto value = data.front(); + data.pop(); + EXPECT_EQ(std::to_string(index), value); + ++index; + } + + if (counter == MESSAGES_TO_SEND) + { + promise.set_value(); + } + }, + "test.db", + BULK_SIZE); + + for (int i = 0; i < MESSAGES_TO_SEND; ++i) + { + dispatcher.push(std::to_string(i)); + } + promise.get_future().wait_for(std::chrono::seconds(10)); + EXPECT_EQ(MESSAGES_TO_SEND, counter); + } +} + +TEST_F(ThreadEventDispatcherTest, CtorNoWorkerSingleThread) { static const std::vector MESSAGES_TO_SEND_LIST {120, 100}; @@ -101,7 +141,7 @@ TEST_F(ThreadEventDispatcherTest, CtorNoWorker) } } -TEST_F(ThreadEventDispatcherTest, CtorPopFeature) +TEST_F(ThreadEventDispatcherTest, CtorPopFeatureSingleThread) { constexpr auto MESSAGES_TO_SEND {1000}; @@ -144,4 +184,3 @@ TEST_F(ThreadEventDispatcherTest, CtorPopFeature) 
promise.get_future().wait_for(std::chrono::seconds(10)); EXPECT_EQ(MESSAGES_TO_SEND, counter); } - diff --git a/src/shared_modules/utils/threadEventDispatcher.hpp b/src/shared_modules/utils/threadEventDispatcher.hpp index 892ea2a51ed..cdfaaa7c494 100644 --- a/src/shared_modules/utils/threadEventDispatcher.hpp +++ b/src/shared_modules/utils/threadEventDispatcher.hpp @@ -25,6 +25,7 @@ template, typename TSafeQueueType = Utils::TSafeQueue>> class TThreadEventDispatcher @@ -39,7 +40,23 @@ class TThreadEventDispatcher , m_bulkSize {bulkSize} , m_queue {std::make_unique(TQueueType(dbPath))} { - m_thread = std::thread {&TThreadEventDispatcher::dispatch, this}; + m_threads.reserve(TNumberOfThreads); + + if constexpr (TNumberOfThreads == 1) + { + m_threads.push_back(std::thread { + &TThreadEventDispatcher::dispatch, this}); + } + else + { + static_assert(isSameType, "T and U are not the same type"); + for (unsigned int i = 0; i < TNumberOfThreads; ++i) + { + m_threads.push_back(std::thread { + &TThreadEventDispatcher::dispatch, + this}); + } + } } explicit TThreadEventDispatcher(const std::string& dbPath, @@ -61,12 +78,27 @@ class TThreadEventDispatcher void startWorker(Functor functor) { m_functor = std::move(functor); - m_thread = std::thread {&TThreadEventDispatcher::dispatch, this}; + m_threads.reserve(TNumberOfThreads); + + if constexpr (TNumberOfThreads == 1) + { + m_threads.push_back(std::thread { + &TThreadEventDispatcher::dispatch, this}); + } + else + { + for (unsigned int i = 0; i < TNumberOfThreads; ++i) + { + m_threads.push_back(std::thread { + &TThreadEventDispatcher::dispatch, + this}); + } + } } void push(const T& value) { - if constexpr (!std::is_same_v>, TSafeQueueType>) + if constexpr (!isTSafeMultiQueue) { if (m_running && (UNLIMITED_QUEUE_SIZE == m_maxQueueSize || m_queue->size() < m_maxQueueSize)) { @@ -76,14 +108,13 @@ class TThreadEventDispatcher else { // static assert to avoid compilation - static_assert(std::is_same_v>, TSafeQueueType>, - "This 
method is not supported for this queue type"); + static_assert(isTSafeMultiQueue, "This method is not supported for this queue type"); } } void push(std::string_view prefix, const T& value) { - if constexpr (std::is_same_v>, TSafeQueueType>) + if constexpr (isTSafeMultiQueue) { if (m_running && (UNLIMITED_QUEUE_SIZE == m_maxQueueSize || m_queue->size(prefix) < m_maxQueueSize)) { @@ -93,22 +124,20 @@ class TThreadEventDispatcher else { // static assert to avoid compilation - static_assert(std::is_same_v>, TSafeQueueType>, - "This method is not supported for this queue type"); + static_assert(isTSafeMultiQueue, "This method is not supported for this queue type"); } } void clear(std::string_view prefix = "") { - if constexpr (std::is_same_v>, TSafeQueueType>) + if constexpr (isTSafeMultiQueue) { m_queue->clear(prefix); } else { // static assert to avoid compilation - static_assert(std::is_same_v>, TSafeQueueType>, - "This method is not supported for this queue type"); + static_assert(isTSafeMultiQueue, "This method is not supported for this queue type"); } } @@ -116,7 +145,7 @@ class TThreadEventDispatcher { m_running = false; m_queue->cancel(); - joinThread(); + joinThreads(); } bool cancelled() const @@ -126,106 +155,214 @@ class TThreadEventDispatcher size_t size() const { - if constexpr (!std::is_same_v>, TSafeQueueType>) + if constexpr (!isTSafeMultiQueue) { return m_queue->size(); } else { - static_assert(std::is_same_v>, TSafeQueueType>, - "This method is not supported for this queue type"); + static_assert(isTSafeMultiQueue, "This method is not supported for this queue type"); } } size_t size(std::string_view prefix) const { - if constexpr (std::is_same_v>, TSafeQueueType>) + if constexpr (isTSafeMultiQueue) { return m_queue->size(prefix); } else { // static assert to avoid compilation - static_assert(std::is_same_v>, TSafeQueueType>, - "This method is not supported for this queue type"); + static_assert(isTSafeMultiQueue, "This method is not supported for 
this queue type"); } } void postpone(std::string_view prefix, const std::chrono::seconds& time) noexcept { - if constexpr (std::is_same_v>, TSafeQueueType>) + if constexpr (isTSafeMultiQueue) { m_queue->postpone(prefix, time); } else { // static assert to avoid compilation - static_assert(std::is_same_v>, TSafeQueueType>, - "This method is not supported for this queue type"); + static_assert(isTSafeMultiQueue, "This method is not supported for this queue type"); } } private: + /** + * @brief Check if the queue type is a `TSafeMultiQueue`. + * + */ + static constexpr bool isTSafeMultiQueue = + std::is_same_v>, TSafeQueueType>; + + /** + * @brief Check if the queue type is a `TSafeQueue`. + * + */ + static constexpr bool isTSafeQueue = std::is_same_v>, TSafeQueueType>; + + /** + * @brief Check if the queue value are the same type. This is crucial for the `multiAndUnordered` method. + * + */ + static constexpr bool isSameType = std::is_same_v; + + /** + * @brief Dispatch function to handle queue processing based on the number of threads. + * + * This function enters a loop that runs while the dispatcher is active. Depending on the number of threads, + * it either processes the queue in a single-threaded, ordered manner or in a multi-threaded, unordered manner. + * + * - In the single-threaded case, it uses the `singleAndOrdered` method. + * - In the multi-threaded case, it uses the `multiAndUnordered` method. + */ void dispatch() { + // Loop while the dispatcher is running while (m_running) { - try + // If only one thread is used, process the queue in a single-threaded, ordered manner + if constexpr (TNumberOfThreads == 1) { - if constexpr (std::is_same_v>, TSafeQueueType>) + singleAndOrdered(); + } + // If multiple threads are used, process the queue in a multi-threaded, unordered manner + else + { + multiAndUnordered(); + } + } + } + + /** + * @brief Processes the queue in a single-threaded, ordered manner. 
+ * + * This function checks the type of the queue and processes it accordingly. It supports `RocksDBQueue` and + * `RocksDBQueueCF` queue types. In case of an exception, it logs the error. + */ + void singleAndOrdered() + { + try + { + if constexpr (isTSafeQueue) + { + std::queue data = m_queue->getBulk(m_bulkSize); + const auto size = data.size(); + + if (!data.empty()) { - std::queue data = m_queue->getBulk(m_bulkSize); - const auto size = data.size(); - - if (!data.empty()) - { - m_functor(data); - m_queue->popBulk(size); - } + m_functor(data); + m_queue->popBulk(size); } - else if constexpr (std::is_same_v>, TSafeQueueType>) + } + else if constexpr (isTSafeMultiQueue) + { + std::pair data = m_queue->front(); + if (!data.second.empty()) { - std::pair data = m_queue->front(); - if (!data.second.empty()) - { - m_functor(data.first); - m_queue->pop(data.second); - } + m_functor(data.first); + m_queue->pop(data.second); } - else + } + else + { + // static assert to avoid compilation for unsupported queue types + static_assert(isTSafeQueue || isTSafeMultiQueue, "This method is not supported for this queue type"); + } + } + catch (const std::exception& ex) + { + // Log the error if an exception occurs + std::cerr << "Dispatch handler error: " << ex.what() << "\n"; + } + } + + /** + * @brief Processes the queue in a multi-threaded, unordered manner. + * + * This function handles queue elements based on the type of queue. It supports `RocksDBQueue` and `RocksDBQueueCF` + * queue types. In case of an exception, it catches the exception and re-inserts the elements back into the queue. 
+ */ + void multiAndUnordered() + { + static_assert(isSameType, "T and U are not the same type"); + std::queue data; // Declare data outside the try block to ensure scope in catch block + try + { + if constexpr (isTSafeQueue) + { + data = m_queue->getBulkAndPop(m_bulkSize); + const auto size = data.size(); + + if (!data.empty()) { - // static assert to avoid compilation - static_assert( - std::is_same_v>, TSafeQueueType> || - std::is_same_v>, TSafeQueueType>, - "This method is not supported for this queue type"); + m_functor(data); } } - catch (const std::exception& ex) + else if constexpr (isTSafeMultiQueue) { - std::cerr << "Dispatch handler error, " << ex.what() << "\n"; + auto dataPair = m_queue->front(); + if (!dataPair.second.empty()) + { + m_functor(dataPair.first); + m_queue->pop(dataPair.second); + } + } + else + { + // static assert to avoid compilation for unsupported queue types + static_assert(isTSafeQueue || isTSafeMultiQueue, "This method is not supported for this queue type"); + } + } + catch (const std::exception& ex) + { + // Reinsert elements in the queue in case of exception on the functor. + if constexpr (isTSafeQueue) + { + while (!data.empty()) + { + m_queue->push(data.front()); + data.pop(); + } + std::cerr << "Dispatch handler error. Elements reinserted: " << ex.what() << "\n"; + } + else if constexpr (isTSafeMultiQueue) + { + while (!data.empty()) + { + m_queue->push(data.front()); + data.pop(); + } + std::cerr << "Dispatch handler error. 
Elements reinserted: " << ex.what() << "\n"; } } } - void joinThread() + void joinThreads() { - if (m_thread.joinable()) + for (auto& thread : m_threads) { - m_thread.join(); + if (thread.joinable()) + { + thread.join(); + } } } Functor m_functor; std::unique_ptr m_queue; - std::thread m_thread; + std::vector m_threads; std::atomic_bool m_running = true; const size_t m_maxQueueSize; const uint64_t m_bulkSize; }; -template -using ThreadEventDispatcher = TThreadEventDispatcher; +template +using ThreadEventDispatcher = TThreadEventDispatcher; #endif // _THREAD_EVENT_DISPATCHER_HPP diff --git a/src/shared_modules/utils/threadSafeQueue.h b/src/shared_modules/utils/threadSafeQueue.h index 4d354d8622f..724a0027a61 100644 --- a/src/shared_modules/utils/threadSafeQueue.h +++ b/src/shared_modules/utils/threadSafeQueue.h @@ -139,6 +139,43 @@ namespace Utils } } + std::queue getBulkAndPop(const uint64_t elementsQuantity, + const std::chrono::seconds& timeout = std::chrono::seconds(5)) + { + std::unique_lock lock {m_mutex}; + std::queue bulkQueue; + + // If we have less elements than requested, wait for more elements to be pushed. + // coverity[missing_lock] + if (m_queue.size() < elementsQuantity) + { + m_cv.wait_for(lock, + timeout, + [this, elementsQuantity]() + { + // coverity[missing_lock] + return m_canceled.load() || m_queue.size() >= elementsQuantity; + }); + } + + // If the queue is not canceled, get the elements. + if (!m_canceled) + { + for (auto i = 0; i < elementsQuantity && i < m_queue.size(); ++i) + { + bulkQueue.push(std::move(m_queue.at(i))); + } + } + + // Pop the elements from the queue after getting them. 
+            for (auto i = 0; i < elementsQuantity && !m_queue.empty(); ++i)
+            {
+                m_queue.pop();
+            }
+
+            return bulkQueue;
+        }
+
         bool empty() const
         {
             std::lock_guard<std::mutex> lock {m_mutex};
diff --git a/src/wazuh_modules/vulnerability_scanner/src/scanOrchestrator/scanOrchestrator.hpp b/src/wazuh_modules/vulnerability_scanner/src/scanOrchestrator/scanOrchestrator.hpp
index 83c2f5f2617..a8e7ae1583a 100644
--- a/src/wazuh_modules/vulnerability_scanner/src/scanOrchestrator/scanOrchestrator.hpp
+++ b/src/wazuh_modules/vulnerability_scanner/src/scanOrchestrator/scanOrchestrator.hpp
@@ -27,15 +27,18 @@ constexpr auto INVENTORY_DB_PATH = "queue/vd/inventory";
 constexpr auto DELAYED_EVENTS_BULK_SIZE {1};
 constexpr auto DELAYED_QUEUE_PATH = "queue/vd/delayed";
 constexpr auto DELAYED_POSTPONE_SECONDS {60};
+constexpr auto MAX_THREADS {1};
 
 using EventDispatcher = TThreadEventDispatcher<std::string,
                                                std::string,
-                                               std::function<void(std::queue<std::string>&)>>;
+                                               std::function<void(std::queue<std::string>&)>,
+                                               MAX_THREADS>;
 
 using EventDelayedDispatcher = TThreadEventDispatcher<std::string,
                                                       std::string,
                                                       std::function<void(std::string_view, std::string&)>,
+                                                      MAX_THREADS,
                                                       RocksDBQueueCF<std::string>,
                                                       Utils::TSafeMultiQueue