@@ -20,7 +20,6 @@ package grpcsync
20
20
21
21
import (
22
22
"context"
23
- "sync"
24
23
25
24
"google.golang.org/grpc/internal/buffer"
26
25
)
@@ -38,8 +37,6 @@ type CallbackSerializer struct {
38
37
done chan struct {}
39
38
40
39
callbacks * buffer.Unbounded
41
- closedMu sync.Mutex
42
- closed bool
43
40
}
44
41
45
42
// NewCallbackSerializer returns a new CallbackSerializer instance. The provided
@@ -65,53 +62,34 @@ func NewCallbackSerializer(ctx context.Context) *CallbackSerializer {
65
62
// callbacks to be executed by the serializer. It is not possible to add
66
63
// callbacks once the context passed to NewCallbackSerializer is cancelled.
67
64
func (cs * CallbackSerializer ) Schedule (f func (ctx context.Context )) bool {
68
- cs .closedMu .Lock ()
69
- defer cs .closedMu .Unlock ()
70
-
71
- if cs .closed {
72
- return false
73
- }
74
- cs .callbacks .Put (f )
75
- return true
65
+ return cs .callbacks .Put (f ) == nil
76
66
}
77
67
78
68
func (cs * CallbackSerializer ) run (ctx context.Context ) {
79
- var backlog []func (context.Context )
80
-
81
69
defer close (cs .done )
70
+
71
+ // TODO: when Go 1.21 is the oldest supported version, this loop and Close
72
+ // can be replaced with:
73
+ //
74
+ // context.AfterFunc(ctx, cs.callbacks.Close)
82
75
for ctx .Err () == nil {
83
76
select {
84
77
case <- ctx .Done ():
85
78
// Do nothing here. Next iteration of the for loop will not happen,
86
79
// since ctx.Err() would be non-nil.
87
- case callback := <- cs .callbacks .Get ():
80
+ case cb := <- cs .callbacks .Get ():
88
81
cs .callbacks .Load ()
89
- callback .(func (ctx context.Context ))(ctx )
82
+ cb .(func (context.Context ))(ctx )
90
83
}
91
84
}
92
85
93
- // Fetch pending callbacks if any, and execute them before returning from
94
- // this method and closing cs.done.
95
- cs .closedMu .Lock ()
96
- cs .closed = true
97
- backlog = cs .fetchPendingCallbacks ()
86
+ // Close the buffer to prevent new callbacks from being added.
98
87
cs .callbacks .Close ()
99
- cs .closedMu .Unlock ()
100
- for _ , b := range backlog {
101
- b (ctx )
102
- }
103
- }
104
88
105
- func (cs * CallbackSerializer ) fetchPendingCallbacks () []func (context.Context ) {
106
- var backlog []func (context.Context )
107
- for {
108
- select {
109
- case b := <- cs .callbacks .Get ():
110
- backlog = append (backlog , b .(func (context.Context )))
111
- cs .callbacks .Load ()
112
- default :
113
- return backlog
114
- }
89
+ // Run all pending callbacks.
90
+ for cb := range cs .callbacks .Get () {
91
+ cs .callbacks .Load ()
92
+ cb .(func (context.Context ))(ctx )
115
93
}
116
94
}
117
95
0 commit comments