Skip to content

Commit

Permalink
Rework /internal/queue package (#1449)
Browse files Browse the repository at this point in the history
* Rework /internal/queue package

Given our use cases for this package, we don't need methods that don't block
on reads if there's no value to be read. Due to this, I've removed the
ReadOrWait function and did a small redesign of the methods to be more
in line with standard queue method naming.

* Change Read/Write/IsEmpty to Dequeue/Enqueue/Size and remove ReadOrWait.
Now there is no version of Read/Dequeue that doesn't block if the queue
is empty.
* Fix up tests to be in line with this removal of the non-blocking read
and simplified most of the tests.

Signed-off-by: Daniel Canter <dcanter@microsoft.com>
  • Loading branch information
dcantah authored Jul 13, 2022
1 parent 94f78da commit 12d4cd8
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 130 deletions.
2 changes: 1 addition & 1 deletion internal/jobobject/iocp.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func pollIOCP(ctx context.Context, iocpHandle windows.Handle) {
}).Warn("failed to parse job object message")
continue
}
if err := msq.Write(notification); err == queue.ErrQueueClosed {
if err := msq.Enqueue(notification); err == queue.ErrQueueClosed {
// Write will only return an error when the queue is closed.
// The only time a queue would ever be closed is when we call `Close` on
// the job it belongs to which also removes it from the jobMap, so something
Expand Down
2 changes: 1 addition & 1 deletion internal/jobobject/jobobject.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ func (job *JobObject) PollNotification() (interface{}, error) {
if job.mq == nil {
return nil, ErrNotRegistered
}
return job.mq.ReadOrWait()
return job.mq.Dequeue()
}

// UpdateProcThreadAttribute updates the passed in ProcThreadAttributeList to contain what is necessary to
Expand Down
61 changes: 21 additions & 40 deletions internal/queue/mq.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@ import (
"sync"
)

var (
ErrQueueClosed = errors.New("the queue is closed for reading and writing")
ErrQueueEmpty = errors.New("the queue is empty")
)
var ErrQueueClosed = errors.New("the queue is closed for reading and writing")

// MessageQueue represents a threadsafe message queue to be used to retrieve or
// write messages to.
Expand All @@ -29,8 +26,8 @@ func NewMessageQueue() *MessageQueue {
}
}

// Write writes `msg` to the queue.
func (mq *MessageQueue) Write(msg interface{}) error {
// Enqueue writes `msg` to the queue.
func (mq *MessageQueue) Enqueue(msg interface{}) error {
mq.m.Lock()
defer mq.m.Unlock()

Expand All @@ -43,69 +40,53 @@ func (mq *MessageQueue) Write(msg interface{}) error {
return nil
}

// Read will read a value from the queue if available, otherwise return an error.
func (mq *MessageQueue) Read() (interface{}, error) {
// Dequeue will read a value from the queue and remove it. If the queue
// is empty, this will block until the queue is closed or a value gets enqueued.
func (mq *MessageQueue) Dequeue() (interface{}, error) {
mq.m.Lock()
defer mq.m.Unlock()
if mq.closed {
return nil, ErrQueueClosed
}
if mq.isEmpty() {
return nil, ErrQueueEmpty

for !mq.closed && mq.size() == 0 {
mq.c.Wait()
}
val := mq.messages[0]
mq.messages[0] = nil
mq.messages = mq.messages[1:]
return val, nil
}

// ReadOrWait will read a value from the queue if available, else it will wait for a
// value to become available. This will block forever if nothing gets written or until
// the queue gets closed.
func (mq *MessageQueue) ReadOrWait() (interface{}, error) {
mq.m.Lock()
// We got woken up, check if it's because the queue got closed.
if mq.closed {
mq.m.Unlock()
return nil, ErrQueueClosed
}
if mq.isEmpty() {
for !mq.closed && mq.isEmpty() {
mq.c.Wait()
}
mq.m.Unlock()
return mq.Read()
}

val := mq.messages[0]
mq.messages[0] = nil
mq.messages = mq.messages[1:]
mq.m.Unlock()
return val, nil
}

// IsEmpty returns if the queue is empty
func (mq *MessageQueue) IsEmpty() bool {
// Size returns the size of the queue.
func (mq *MessageQueue) Size() int {
mq.m.RLock()
defer mq.m.RUnlock()
return len(mq.messages) == 0
return mq.size()
}

// Nonexported empty check that doesn't lock so we can call this in Read and Write.
func (mq *MessageQueue) isEmpty() bool {
return len(mq.messages) == 0
// Nonexported size check to check if the queue is empty inside already locked functions.
func (mq *MessageQueue) size() int {
return len(mq.messages)
}

// Close closes the queue for future writes or reads. Any attempts to read or write from the
// queue after close will return ErrQueueClosed. This is safe to call multiple times.
func (mq *MessageQueue) Close() {
mq.m.Lock()
defer mq.m.Unlock()
// Already closed

// Already closed, noop
if mq.closed {
return
}

mq.messages = nil
mq.closed = true
// If there's anybody currently waiting on a value from ReadOrWait, we need to
// If there's anybody currently waiting on a value from Dequeue, we need to
// broadcast so the read(s) can return ErrQueueClosed.
mq.c.Broadcast()
}
158 changes: 70 additions & 88 deletions internal/queue/queue_test.go
Original file line number Diff line number Diff line change
@@ -1,105 +1,59 @@
package queue

import (
"fmt"
"sync"
"testing"
"time"
)

func TestReadWrite(t *testing.T) {
func TestEnqueueDequeue(t *testing.T) {
q := NewMessageQueue()

// Reading from an empty queue should return ErrQueueEmpty
if _, err := q.Read(); err != ErrQueueEmpty {
t.Fatal("expected to receive `ErrQueueEmpty` for reading from empty queue")
}

// Write 1 to the queue and read this later.
if err := q.Write(1); err != nil {
t.Fatal(err)
}

// Read the value. Value will be dequeued.
if msg, err := q.Read(); err != nil || msg != 1 {
t.Fatal(err)
vals := []int{1, 2, 3, 4, 5}
for _, val := range vals {
// Enqueue vals to the queue and read later.
if err := q.Enqueue(val); err != nil {
t.Fatal(err)
}
}

// We just read a value, now try and read again and verify that we get ErrQueueEmpty again.
if _, err := q.Read(); err != ErrQueueEmpty {
t.Fatal(err)
}
for _, val := range vals {
// Dequeueing from an empty queue should block forever until a write occurs.
qVal, err := q.Dequeue()
if err != nil {
t.Fatal(err)
}

// Close the queue and verify that we get an error on write.
q.Close()
if err := q.Write(1); err != ErrQueueClosed {
t.Fatal(err)
if qVal != val {
t.Fatalf("expected %d, got: %d", val, qVal)
}
}
}

func TestReadOrWaitClose(t *testing.T) {
func TestEnqueueDequeueClose(t *testing.T) {
q := NewMessageQueue()

vals := []int{1, 2, 3}
go func() {
_ = q.Write(1)
_ = q.Write(2)
_ = q.Write(3)
time.Sleep(time.Second * 5)
q.Close()
for _, val := range vals {
_ = q.Enqueue(val)
}
}()

time.Sleep(time.Second * 2)

read := 0
for {
if _, err := q.ReadOrWait(); err != nil {
if err == ErrQueueClosed && read == 3 {
break
}
t.Fatal(err)
}
read++
}
}

func TestReadOrWait(t *testing.T) {
q := NewMessageQueue()

go func() {
_ = q.Write(1)
_ = q.Write(2)
_ = q.Write(3)
time.Sleep(time.Second * 5)
_ = q.Write(4)
}()

// Small sleep so that we can give time to ensure a value is written to the queue so we
// can test both states ReadOrWait could be in. These states being there is already a value
// ready for consumption and all we have to do is just read it, or we wait to get signalled of
// an available value.
time.Sleep(time.Second * 1)
timeout := time.After(time.Second * 20)
done := make(chan struct{})
readErr := make(chan error)

go func() {
for {
if msg, err := q.ReadOrWait(); err != nil {
readErr <- err
} else {
if msg == 4 {
done <- struct{}{}
break
}
if _, err := q.Dequeue(); err == nil {
read++
if read == len(vals) {
// Close after we've read all of our values, then on the next
// go around make sure we get ErrClosed()
q.Close()
}
} else if err != ErrQueueClosed {
t.Fatalf("expected to receive ErrQueueClosed, instead got: %s", err)
}
}()

select {
case <-timeout:
t.Fatal("timed out waiting for all queue values to be read")
case <-done:
case err := <-readErr:
t.Fatal(err)
break
}
}

Expand All @@ -109,7 +63,7 @@ func TestMultipleReaders(t *testing.T) {
done := make(chan struct{})
go func() {
for i := 0; i < 50; i++ {
if err := q.Write(1); err != nil {
if err := q.Enqueue(1); err != nil {
errChan <- err
}
}
Expand All @@ -121,7 +75,7 @@ func TestMultipleReaders(t *testing.T) {
// Reader 1
go func() {
for i := 0; i < 25; i++ {
if _, err := q.ReadOrWait(); err != nil {
if _, err := q.Dequeue(); err != nil {
errChan <- err
}
}
Expand All @@ -131,7 +85,7 @@ func TestMultipleReaders(t *testing.T) {
// Reader 2
go func() {
for i := 0; i < 25; i++ {
if _, err := q.ReadOrWait(); err != nil {
if _, err := q.Dequeue(); err != nil {
errChan <- err
}
}
Expand All @@ -143,13 +97,11 @@ func TestMultipleReaders(t *testing.T) {
done <- struct{}{}
}()

timeout := time.After(time.Second * 20)

select {
case err := <-errChan:
t.Fatalf("failed in read or write: %s", err)
case <-done:
case <-timeout:
case <-time.After(time.Second * 20):
t.Fatalf("timeout exceeded waiting for reads to complete")
}
}
Expand All @@ -164,15 +116,15 @@ func TestMultipleReadersClose(t *testing.T) {

// Reader 1
go func() {
if _, err := q.ReadOrWait(); err != ErrQueueClosed {
if _, err := q.Dequeue(); err != ErrQueueClosed {
errChan <- err
}
wg.Done()
}()

// Reader 2
go func() {
if _, err := q.ReadOrWait(); err != ErrQueueClosed {
if _, err := q.Dequeue(); err != ErrQueueClosed {
errChan <- err
}
wg.Done()
Expand All @@ -187,13 +139,43 @@ func TestMultipleReadersClose(t *testing.T) {
// Close the queue and this should signal both readers to return ErrQueueClosed.
q.Close()

timeout := time.After(time.Second * 20)

select {
case err := <-errChan:
t.Fatalf("failed in read or write: %s", err)
case <-done:
case <-timeout:
case <-time.After(time.Second * 20):
t.Fatalf("timeout exceeded waiting for reads to complete")
}
}

func TestDequeueBlock(t *testing.T) {
q := NewMessageQueue()
errChan := make(chan error)
testVal := 1

go func() {
// Intentionally dequeue right away with no elements so we know we actually block on
// no elements.
val, err := q.Dequeue()
if err != nil {
errChan <- err
}
if val != testVal {
errChan <- fmt.Errorf("expected %d, but got %d", testVal, val)
}
close(errChan)
}()

// Ensure dequeue has started
time.Sleep(time.Second * 3)
if err := q.Enqueue(testVal); err != nil {
t.Fatal(err)
}

select {
case err := <-errChan:
if err != nil {
t.Fatal(err)
}
}
}

0 comments on commit 12d4cd8

Please sign in to comment.