Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 27 additions & 16 deletions lib/async/priority_queue.rb
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ def wait_for_value(mutex)
condition.wait(mutex)
return self.value
end

# Invalidate this waiter, making it unusable and detectable as abandoned.
def invalidate!
self.fiber = nil
end

# Check if this waiter has been invalidated.
def valid?
self.fiber&.alive?
end
end

# Create a new priority queue.
Expand All @@ -64,12 +74,9 @@ def close
@mutex.synchronize do
@closed = true

# Signal all waiting fibers with nil, skipping dead ones:
# Signal all waiting fibers with nil, skipping dead/invalid ones:
while waiter = @waiting.pop
if waiter.fiber.alive?
waiter.signal(nil)
end
# Dead waiter discarded, continue to next one.
waiter.signal(nil)
end
end
end
Expand Down Expand Up @@ -105,14 +112,14 @@ def push(item)

@items << item

# Wake up the highest priority waiter if any, skipping dead waiters:
# Wake up the highest priority waiter if any, skipping dead/invalid waiters:
while waiter = @waiting.pop
if waiter.fiber.alive?
if waiter.valid?
value = @items.shift
waiter.signal(value)
break
end
# Dead waiter discarded, try next one.
# Dead/invalid waiter discarded, try next one.
end
end
end
Expand All @@ -133,13 +140,13 @@ def enqueue(*items)

@items.concat(items)

# Wake up waiting fibers in priority order, skipping dead waiters:
# Wake up waiting fibers in priority order, skipping dead/invalid waiters:
while !@items.empty? && (waiter = @waiting.pop)
if waiter.fiber.alive?
if waiter.valid?
value = @items.shift
waiter.signal(value)
end
# Dead waiter discarded, continue to next one.
# Dead/invalid waiter discarded, continue to next one.
end
end
end
Expand Down Expand Up @@ -172,12 +179,16 @@ def dequeue(priority: 0)
@sequence += 1

condition = ConditionVariable.new
waiter = Waiter.new(Fiber.current, priority, sequence, condition, nil)
@waiting.push(waiter)

# Wait for our specific condition variable to be signaled:
# The mutex is released during wait, reacquired after:
return waiter.wait_for_value(@mutex)
begin
waiter = Waiter.new(Fiber.current, priority, sequence, condition, nil)
@waiting.push(waiter)

# Wait for our specific condition variable to be signaled:
return waiter.wait_for_value(@mutex)
ensure
waiter&.invalidate!
end
end
end

Expand Down
78 changes: 78 additions & 0 deletions test/async/priority_queue.rb
Original file line number Diff line number Diff line change
Expand Up @@ -547,4 +547,82 @@
]
end
end

with "waiter invalidation" do
it "should invalidate waiters when tasks are stopped to prevent memory leaks" do
# Start a task that will wait and then be stopped
task = reactor.async do
queue.dequeue(priority: 1)
end

expect(queue.waiting).to be == 1

# Stop the task (simulates exception)
task.stop
task.wait

# Now enqueue an item - should not try to wake the invalid waiter
queue.enqueue("test_item")

# The item should still be available for a new waiter
result = nil
new_task = reactor.async do
result = queue.dequeue
end

new_task.wait
expect(result).to be == "test_item"
end

it "should skip invalid waiters during enqueue" do
received_items = []

# Start multiple waiters
tasks = []
3.times do |i|
tasks << reactor.async do
item = queue.dequeue(priority: i)
received_items << [i, item]
end
end

# Give tasks time to start waiting
expect(queue.waiting).to be == 3

# Stop the middle priority task (priority 1)
tasks[1].stop
tasks[1].wait

# Add items to the queue
queue.enqueue("item1", "item2")

tasks[0].wait
tasks[2].wait

# Should have received items in the valid waiters only
# Invalid waiter (priority 1) should be skipped
expect(received_items.size).to be == 2

# Items should go to highest priority waiters (2, then 0)
priorities_served = received_items.map(&:first).sort.reverse
expect(priorities_served).to be == [2, 0]
end
end

describe Async::PriorityQueue::Waiter do
it "should invalidate correctly" do
condition = ConditionVariable.new
fiber = Fiber.current
waiter = Async::PriorityQueue::Waiter.new(fiber, 1, 1, condition, nil)

expect(waiter).to be(:valid?)
expect(waiter.fiber).to be == fiber
expect(waiter.condition).to be == condition

waiter.invalidate!

expect(waiter).not.to be(:valid?)
expect(waiter.fiber).to be_nil
end
end
end
Loading