Skip to content

Commit 2afc15a

Browse files
committed
buffer & grpcsync: various cleanups and improvements (grpc#6785)
1 parent dd39cdb commit 2afc15a

File tree

3 files changed

+55
-58
lines changed

3 files changed

+55
-58
lines changed

internal/buffer/unbounded.go

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
// Package buffer provides an implementation of an unbounded buffer.
1919
package buffer
2020

21-
import "sync"
21+
import (
22+
"errors"
23+
"sync"
24+
)
2225

2326
// Unbounded is an implementation of an unbounded buffer which does not use
2427
// extra goroutines. This is typically used for passing updates from one entity
@@ -36,6 +39,7 @@ import "sync"
3639
type Unbounded struct {
3740
c chan any
3841
closed bool
42+
closing bool
3943
mu sync.Mutex
4044
backlog []any
4145
}
@@ -45,39 +49,41 @@ func NewUnbounded() *Unbounded {
4549
return &Unbounded{c: make(chan any, 1)}
4650
}
4751

52+
var errBufferClosed = errors.New("Put called on closed buffer.Unbounded")
53+
4854
// Put adds t to the unbounded buffer.
49-
func (b *Unbounded) Put(t any) {
55+
func (b *Unbounded) Put(t any) error {
5056
b.mu.Lock()
5157
defer b.mu.Unlock()
52-
if b.closed {
53-
return
58+
if b.closing {
59+
return errBufferClosed
5460
}
5561
if len(b.backlog) == 0 {
5662
select {
5763
case b.c <- t:
58-
return
64+
return nil
5965
default:
6066
}
6167
}
6268
b.backlog = append(b.backlog, t)
69+
return nil
6370
}
6471

65-
// Load sends the earliest buffered data, if any, onto the read channel
66-
// returned by Get(). Users are expected to call this every time they read a
72+
// Load sends the earliest buffered data, if any, onto the read channel returned
73+
// by Get(). Users are expected to call this every time they successfully read a
6774
// value from the read channel.
6875
func (b *Unbounded) Load() {
6976
b.mu.Lock()
7077
defer b.mu.Unlock()
71-
if b.closed {
72-
return
73-
}
7478
if len(b.backlog) > 0 {
7579
select {
7680
case b.c <- b.backlog[0]:
7781
b.backlog[0] = nil
7882
b.backlog = b.backlog[1:]
7983
default:
8084
}
85+
} else if b.closing && !b.closed {
86+
close(b.c)
8187
}
8288
}
8389

@@ -88,18 +94,23 @@ func (b *Unbounded) Load() {
8894
// send the next buffered value onto the channel if there is any.
8995
//
9096
// If the unbounded buffer is closed, the read channel returned by this method
91-
// is closed.
97+
// is closed after all data is drained.
9298
func (b *Unbounded) Get() <-chan any {
9399
return b.c
94100
}
95101

96-
// Close closes the unbounded buffer.
102+
// Close closes the unbounded buffer. No subsequent data may be Put(), and the
103+
// channel returned from Get() will be closed after all the data is read and
104+
// Load() is called for the final time.
97105
func (b *Unbounded) Close() {
98106
b.mu.Lock()
99107
defer b.mu.Unlock()
100-
if b.closed {
108+
if b.closing {
101109
return
102110
}
103-
b.closed = true
104-
close(b.c)
111+
b.closing = true
112+
if len(b.backlog) == 0 {
113+
b.closed = true
114+
close(b.c)
115+
}
105116
}

internal/buffer/unbounded_test.go

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func init() {
5252
}
5353

5454
// TestSingleWriter starts one reader and one writer goroutine and makes sure
55-
// that the reader gets all the value added to the buffer by the writer.
55+
// that the reader gets all the values added to the buffer by the writer.
5656
func (s) TestSingleWriter(t *testing.T) {
5757
ub := NewUnbounded()
5858
reads := []int{}
@@ -124,14 +124,25 @@ func (s) TestMultipleWriters(t *testing.T) {
124124
// buffer is closed.
125125
func (s) TestClose(t *testing.T) {
126126
ub := NewUnbounded()
127+
if err := ub.Put(1); err != nil {
128+
t.Fatalf("Unbounded.Put() = %v; want nil", err)
129+
}
127130
ub.Close()
128-
if v, ok := <-ub.Get(); ok {
129-
t.Errorf("Unbounded.Get() = %v, want closed channel", v)
131+
if err := ub.Put(1); err == nil {
132+
t.Fatalf("Unbounded.Put() = <nil>; want non-nil error")
133+
}
134+
if v, ok := <-ub.Get(); !ok {
135+
t.Errorf("Unbounded.Get() = %v, %v, want %v, %v", v, ok, 1, true)
136+
}
137+
if err := ub.Put(1); err == nil {
138+
t.Fatalf("Unbounded.Put() = <nil>; want non-nil error")
130139
}
131-
ub.Put(1)
132140
ub.Load()
133141
if v, ok := <-ub.Get(); ok {
134142
t.Errorf("Unbounded.Get() = %v, want closed channel", v)
135143
}
136-
ub.Close()
144+
if err := ub.Put(1); err == nil {
145+
t.Fatalf("Unbounded.Put() = <nil>; want non-nil error")
146+
}
147+
ub.Close() // ignored
137148
}

internal/grpcsync/callback_serializer.go

Lines changed: 13 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ package grpcsync
2020

2121
import (
2222
"context"
23-
"sync"
2423

2524
"google.golang.org/grpc/internal/buffer"
2625
)
@@ -38,8 +37,6 @@ type CallbackSerializer struct {
3837
done chan struct{}
3938

4039
callbacks *buffer.Unbounded
41-
closedMu sync.Mutex
42-
closed bool
4340
}
4441

4542
// NewCallbackSerializer returns a new CallbackSerializer instance. The provided
@@ -65,56 +62,34 @@ func NewCallbackSerializer(ctx context.Context) *CallbackSerializer {
6562
// callbacks to be executed by the serializer. It is not possible to add
6663
// callbacks once the context passed to NewCallbackSerializer is cancelled.
6764
func (cs *CallbackSerializer) Schedule(f func(ctx context.Context)) bool {
68-
cs.closedMu.Lock()
69-
defer cs.closedMu.Unlock()
70-
71-
if cs.closed {
72-
return false
73-
}
74-
cs.callbacks.Put(f)
75-
return true
65+
return cs.callbacks.Put(f) == nil
7666
}
7767

7868
func (cs *CallbackSerializer) run(ctx context.Context) {
79-
var backlog []func(context.Context)
80-
8169
defer close(cs.done)
70+
71+
// TODO: when Go 1.21 is the oldest supported version, this loop and Close
72+
// can be replaced with:
73+
//
74+
// context.AfterFunc(ctx, cs.callbacks.Close)
8275
for ctx.Err() == nil {
8376
select {
8477
case <-ctx.Done():
8578
// Do nothing here. Next iteration of the for loop will not happen,
8679
// since ctx.Err() would be non-nil.
87-
case callback, ok := <-cs.callbacks.Get():
88-
if !ok {
89-
return
90-
}
80+
case cb := <-cs.callbacks.Get():
9181
cs.callbacks.Load()
92-
callback.(func(ctx context.Context))(ctx)
82+
cb.(func(context.Context))(ctx)
9383
}
9484
}
9585

96-
// Fetch pending callbacks if any, and execute them before returning from
97-
// this method and closing cs.done.
98-
cs.closedMu.Lock()
99-
cs.closed = true
100-
backlog = cs.fetchPendingCallbacks()
86+
// Close the buffer to prevent new callbacks from being added.
10187
cs.callbacks.Close()
102-
cs.closedMu.Unlock()
103-
for _, b := range backlog {
104-
b(ctx)
105-
}
106-
}
10788

108-
func (cs *CallbackSerializer) fetchPendingCallbacks() []func(context.Context) {
109-
var backlog []func(context.Context)
110-
for {
111-
select {
112-
case b := <-cs.callbacks.Get():
113-
backlog = append(backlog, b.(func(context.Context)))
114-
cs.callbacks.Load()
115-
default:
116-
return backlog
117-
}
89+
// Run all pending callbacks.
90+
for cb := range cs.callbacks.Get() {
91+
cs.callbacks.Load()
92+
cb.(func(context.Context))(ctx)
11893
}
11994
}
12095

0 commit comments

Comments
 (0)