Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ class ThreadPool {
struct TaskEvent { audioapi::move_only_function<void()> task; };
using Event = std::variant<TaskEvent, StopEvent>;

struct Cntrl {
std::atomic<bool> waitingForTasks{false};
std::atomic<size_t> tasksScheduled{0};
};

using Sender = channels::spsc::Sender<Event, channels::spsc::OverflowStrategy::WAIT_ON_FULL, channels::spsc::WaitStrategy::ATOMIC_WAIT>;
using Receiver = channels::spsc::Receiver<Event, channels::spsc::OverflowStrategy::WAIT_ON_FULL, channels::spsc::WaitStrategy::ATOMIC_WAIT>;
public:
Expand All @@ -38,8 +43,30 @@ class ThreadPool {
workerSenders.emplace_back(std::move(workerSender));
}
loadBalancerThread = std::thread(&ThreadPool::loadBalancerThreadFunc, this, std::move(receiver), std::move(workerSenders));
controlBlock_ = std::make_unique<Cntrl>();
}
ThreadPool(const ThreadPool&) = delete;
ThreadPool& operator=(const ThreadPool&) = delete;
ThreadPool(ThreadPool&& other):
loadBalancerThread(std::move(other.loadBalancerThread)),
workers(std::move(other.workers)),
loadBalancerSender(std::move(other.loadBalancerSender)),
controlBlock_(std::move(other.controlBlock_)) {}
ThreadPool& operator=(ThreadPool&& other) {
if (this != &other) {
loadBalancerThread = std::move(other.loadBalancerThread);
workers = std::move(other.workers);
loadBalancerSender = std::move(other.loadBalancerSender);
controlBlock_ = std::move(other.controlBlock_);
other.movedFrom_ = true;
}
return *this;
}

~ThreadPool() {
if (movedFrom_) {
return;
}
loadBalancerSender.send(StopEvent{});
loadBalancerThread.join();
for (auto& worker : workers) {
Expand All @@ -59,16 +86,47 @@ class ThreadPool {
/// @note IMPORTANT: This function is not thread-safe and should be called from a single thread only.
template<typename Func, typename ... Args, typename = std::enable_if_t<std::is_invocable_r_v<void, Func, Args...>>>
void schedule(Func &&task, Args &&... args) noexcept {
auto boundTask = [f = std::forward<Func>(task), ...capturedArgs = std::forward<Args>(args)]() mutable {
controlBlock_->tasksScheduled.fetch_add(1, std::memory_order_release);

/// We know that lifetime of each worker thus spsc thus lambda is strongly bounded by ThreadPool lifetime
/// so we can safely capture control block pointer unsafely here
Cntrl *cntrl = controlBlock_.get();
auto boundTask = [cntrl, f= std::forward<Func>(task), ...capturedArgs = std::forward<Args>(args)]() mutable {
f(std::forward<Args>(capturedArgs)...);
size_t left = cntrl->tasksScheduled.fetch_sub(1, std::memory_order_acq_rel) - 1;
if (left == 0) {
cntrl->waitingForTasks.store(false, std::memory_order_release);
cntrl->waitingForTasks.notify_one();
}
};
loadBalancerSender.send(TaskEvent{audioapi::move_only_function<void()>(std::move(boundTask))});
}

/// @brief Waits for all scheduled tasks to complete
void wait() {
/// This logic might seem incorrect at first glance
/// Main principle for this is that there is only one thread scheduling tasks
/// If he is waiting for the tasks he CANNOT schedule new tasks so we can assume partial
/// synchronization here.
/// We first store true so if any task finishes at this moment he will flip it
/// Then we check if there are any tasks scheduled
/// If there are none we can return immediately
/// If there are some we wait until the last task flips the flag to false
controlBlock_->waitingForTasks.store(true, std::memory_order_release);
if (controlBlock_->tasksScheduled.load(std::memory_order_acquire) == 0) {
controlBlock_->waitingForTasks.store(false, std::memory_order_release);
return;
}
controlBlock_->waitingForTasks.wait(true, std::memory_order_acquire);
return;
}

private:
std::thread loadBalancerThread;
std::vector<std::thread> workers;
Sender loadBalancerSender;
std::unique_ptr<Cntrl> controlBlock_;
bool movedFrom_ = false;

void workerThreadFunc(Receiver &&receiver) {
Receiver localReceiver = std::move(receiver);
Expand Down