Messaging: Avoid deadlocks related to 0 receiver behavior #10132

Merged (16 commits) on Apr 28, 2022
Changes from 5 commits
6 changes: 2 additions & 4 deletions go/test/endtoend/messaging/main_test.go
@@ -45,8 +45,7 @@ var (
 		time_acked bigint,
 		message varchar(128),
 		primary key(id),
-		index next_idx(priority, time_next desc),
-		index ack_idx(time_acked)
+		index poller_idx(time_acked, priority, time_next desc)
 	) comment 'vitess_message,vt_ack_wait=1,vt_purge_after=3,vt_batch_size=2,vt_cache_size=10,vt_poller_interval=1'`
 	createUnshardedMessage = `create table unsharded_message(
 		id bigint,
@@ -56,8 +55,7 @@ var (
 		time_acked bigint,
 		message varchar(128),
 		primary key(id),
-		index next_idx(priority, time_next desc),
-		index ack_idx(time_acked)
+		index poller_idx(time_acked, priority, time_next desc)
 	) comment 'vitess_message,vt_ack_wait=1,vt_purge_after=3,vt_batch_size=2,vt_cache_size=10,vt_poller_interval=1'`
 	userVschema = `{
 		"sharded": true,
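
An aside for readers unfamiliar with Vitess message tables: the table comment is what turns an ordinary table into a message queue, and the vt_* keys configure its behavior. Below is the comment string these tests use, split out and annotated; the annotations reflect my reading of the Vitess messaging docs, not anything stated in this PR.

// The comment string from the test schema above, with each parameter annotated.
const messageTableComment = "vitess_message," + // marks the table as a message table
	"vt_ack_wait=1," + // seconds to wait for an ack before redelivering
	"vt_purge_after=3," + // seconds after ack before the row is purged
	"vt_batch_size=2," + // max messages sent per batch
	"vt_cache_size=10," + // max rows held in the in-memory cache
	"vt_poller_interval=1" // seconds between poller runs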
16 changes: 7 additions & 9 deletions go/test/endtoend/messaging/message_test.go
@@ -52,8 +52,7 @@ var createMessage = `create table vitess_message(
 	time_acked bigint,
 	message varchar(128),
 	primary key(id),
-	index next_idx(priority, time_next desc),
-	index ack_idx(time_acked))
+	index poller_idx(time_acked, priority, time_next desc))
 	comment 'vitess_message,vt_ack_wait=1,vt_purge_after=3,vt_batch_size=2,vt_cache_size=10,vt_poller_interval=1'`
 
 func TestMessage(t *testing.T) {
@@ -172,8 +171,7 @@ var createThreeColMessage = `create table vitess_message3(
 	msg1 varchar(128),
 	msg2 bigint,
 	primary key(id),
-	index next_idx(priority, time_next desc),
-	index ack_idx(time_acked))
+	index poller_idx(time_acked, priority, time_next desc))
 	comment 'vitess_message,vt_ack_wait=1,vt_purge_after=3,vt_batch_size=2,vt_cache_size=10,vt_poller_interval=1'`
 
 func TestThreeColMessage(t *testing.T) {
@@ -519,18 +517,18 @@ func assertClientCount(t *testing.T, expected int, vttablet *cluster.Vttablet) {
 }
 
 func parseDebugVars(t *testing.T, output interface{}, vttablet *cluster.Vttablet) {
-	debugVarUrl := fmt.Sprintf("http://%s:%d/debug/vars", vttablet.VttabletProcess.TabletHostname, vttablet.HTTPPort)
-	resp, err := http.Get(debugVarUrl)
+	debugVarURL := fmt.Sprintf("http://%s:%d/debug/vars", vttablet.VttabletProcess.TabletHostname, vttablet.HTTPPort)
+	resp, err := http.Get(debugVarURL)
 	if err != nil {
-		t.Fatalf("failed to fetch %q: %v", debugVarUrl, err)
+		t.Fatalf("failed to fetch %q: %v", debugVarURL, err)
 	}
 
 	respByte, _ := io.ReadAll(resp.Body)
 	if resp.StatusCode != 200 {
-		t.Fatalf("status code %d while fetching %q:\n%s", resp.StatusCode, debugVarUrl, respByte)
+		t.Fatalf("status code %d while fetching %q:\n%s", resp.StatusCode, debugVarURL, respByte)
 	}
 
 	if err := json.Unmarshal(respByte, output); err != nil {
-		t.Fatalf("failed to unmarshal JSON from %q: %v", debugVarUrl, err)
+		t.Fatalf("failed to unmarshal JSON from %q: %v", debugVarURL, err)
 	}
 }
6 changes: 3 additions & 3 deletions go/test/endtoend/vtgate/godriver/main_test.go
@@ -44,15 +44,15 @@ var (
 	create table my_message(
 	time_scheduled bigint,
 	id bigint,
-	time_next bigint,
+	time_next bigint DEFAULT 0,
 	epoch bigint,
 	time_created bigint,
 	time_acked bigint,
 	message varchar(128),
-	priority tinyint NOT NULL DEFAULT '0',
+	priority tinyint NOT NULL DEFAULT 0,
 	primary key(time_scheduled, id),
 	unique index id_idx(id),
-	index next_idx(priority, time_next)
+	index poller_idx(time_acked, priority, time_next)
 	) comment 'vitess_message,vt_ack_wait=30,vt_purge_after=86400,vt_batch_size=10,vt_cache_size=10000,vt_poller_interval=30';
 	`
 	VSchema = `
82 changes: 46 additions & 36 deletions go/vt/vttablet/tabletserver/messager/message_manager.go
@@ -183,11 +183,13 @@ type messageManager struct {
 	isOpen bool
 	// cond waits on curReceiver == -1 || cache.IsEmpty():
 	// No current receivers available or cache is empty.
-	cond            sync.Cond
-	cache           *cache
-	receivers       []*receiverWithStatus
-	curReceiver     int
-	messagesPending bool
+	cond      sync.Cond
+	cache     *cache
+	receivers []*receiverWithStatus
+	// receiverCount tracks the number of receivers without requiring locks.
+	receiverCount   sync2.AtomicInt64
+	curReceiver     int64
+	messagesPending sync2.AtomicBool
 
 	// streamMu keeps the cache and database consistent with each other.
 	// Specifically:
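
The cond comment in the hunk above is the key to the sender loop: runSend blocks on cond until either a receiver frees up (curReceiver != -1) or there is something to drain from the cache. Here is a minimal, self-contained sketch of that wait/broadcast pattern with sync.Cond; the type and field names are illustrative, not the actual messageManager API.

package sketch

import "sync"

type manager struct {
	mu          sync.Mutex
	cond        sync.Cond
	curReceiver int64 // -1 means no receiver is currently free
}

func newManager() *manager {
	m := &manager{curReceiver: -1}
	m.cond.L = &m.mu // the cond shares the manager's mutex, as in message_manager.go
	return m
}

// waitForReceiver blocks until some receiver becomes available.
func (m *manager) waitForReceiver() {
	m.mu.Lock()
	defer m.mu.Unlock()
	for m.curReceiver == -1 { // always re-check the condition after a wakeup
		m.cond.Wait()
	}
}

// markFree records a newly free receiver and wakes every waiter,
// mirroring the Broadcast in rescanReceivers.
func (m *manager) markFree(i int64) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.curReceiver = i
	m.cond.Broadcast()
}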
@@ -239,7 +241,7 @@ func newMessageManager(tsv TabletService, vs VStreamer, table *schema.Table, pos
 		pollerTicks:  timer.NewTimer(table.MessageInfo.PollInterval),
 		purgeTicks:   timer.NewTimer(table.MessageInfo.PollInterval),
 		postponeSema: postponeSema,
-		messagesPending: true,
+		messagesPending: sync2.NewAtomicBool(true),
 	}
 	mm.cond.L = &mm.mu

@@ -252,7 +254,9 @@
 		}},
 	}
 	mm.readByPriorityAndTimeNext = sqlparser.BuildParsedQuery(
-		"select priority, time_next, epoch, time_acked, %s from %v where time_next < %a order by priority, time_next desc limit %a",
+		// A poller_idx defined on (time_acked, priority, time_next desc)
+		// should exist for this query to be as efficient as possible.
+		"select priority, time_next, epoch, time_acked, %s from %v where time_acked is null and time_next < %a order by priority, time_next desc limit %a",
 		columnList, mm.name, ":time_next", ":max")
 	mm.ackQuery = sqlparser.BuildParsedQuery(
 		"update %v set time_acked = %a, time_next = null where id in %a and time_acked is null",
@@ -362,6 +366,7 @@ func (mm *messageManager) Close() {
 	for _, rcvr := range mm.receivers {
 		rcvr.receiver.cancel()
 	}
+	mm.receiverCount.Set(0)
 	mm.receivers = nil
 	MessageStats.Set([]string{mm.name.String(), "ClientCount"}, 0)
 	log.Infof("messageManager - clearing cache")
@@ -401,11 +406,12 @@ func (mm *messageManager) Subscribe(ctx context.Context, send func(*sqltypes.Res
 	withStatus := &receiverWithStatus{
 		receiver: receiver,
 	}
-	if len(mm.receivers) == 0 {
+	if mm.receiverCount.Get() == 0 {
 		mm.startVStream()
 	}
 	mm.receivers = append(mm.receivers, withStatus)
-	MessageStats.Set([]string{mm.name.String(), "ClientCount"}, int64(len(mm.receivers)))
+	mm.receiverCount.Add(1)
+	MessageStats.Set([]string{mm.name.String(), "ClientCount"}, mm.receiverCount.Get())
 	if mm.curReceiver == -1 {
 		mm.rescanReceivers(-1)
 	}
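
Replacing len(mm.receivers) with a sync2.AtomicInt64 is what lets paths like the runPoller fast-path further down check for subscribers without touching mm.mu. A minimal sketch of the same idea using only the standard library's sync/atomic (the real code uses Vitess's sync2 wrappers, which expose the same Get/Set/Add shape):

package sketch

import "sync/atomic"

// receiverCounter mimics what sync2.AtomicInt64 provides: a counter that can
// be read safely by goroutines that do not hold the manager's mutex.
type receiverCounter struct{ n int64 }

func (c *receiverCounter) Add(delta int64) int64 { return atomic.AddInt64(&c.n, delta) }
func (c *receiverCounter) Get() int64            { return atomic.LoadInt64(&c.n) }
func (c *receiverCounter) Set(v int64)           { atomic.StoreInt64(&c.n, v) }

// Fast-path check, as in runPoller: skip all work when nobody is
// subscribed, without acquiring any lock.
func shouldPoll(c *receiverCounter) bool {
	return c.Get() != 0
}

The trade-off is that the count can be momentarily stale relative to the receivers slice, which is why the slice itself is still mutated only while holding mm.mu.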
@@ -426,16 +432,17 @@ func (mm *messageManager) unsubscribe(receiver *messageReceiver) {
 			continue
 		}
 		// Delete the item at current position.
-		n := len(mm.receivers)
+		n := mm.receiverCount.Get()
 		copy(mm.receivers[i:n-1], mm.receivers[i+1:n])
 		mm.receivers = mm.receivers[0 : n-1]
-		MessageStats.Set([]string{mm.name.String(), "ClientCount"}, int64(len(mm.receivers)))
+		mm.receiverCount.Add(-1)
+		MessageStats.Set([]string{mm.name.String(), "ClientCount"}, mm.receiverCount.Get())
 		break
 	}
 	// curReceiver is obsolete. Recompute.
 	mm.rescanReceivers(-1)
 	// If there are no receivers. Shut down the cache.
-	if len(mm.receivers) == 0 {
+	if mm.receiverCount.Get() == 0 {
 		mm.stopVStream()
 		mm.cache.Clear()
 	}
@@ -447,10 +454,10 @@
 // was previously -1, it broadcasts. If none was found,
 // curReceiver is set to -1. If there's no starting point,
 // it must be specified as -1.
-func (mm *messageManager) rescanReceivers(start int) {
+func (mm *messageManager) rescanReceivers(start int64) {
 	cur := start
 	for range mm.receivers {
-		cur = (cur + 1) % len(mm.receivers)
+		cur = (cur + 1) % mm.receiverCount.Get()
 		if !mm.receivers[cur].busy {
 			if mm.curReceiver == -1 {
 				mm.cond.Broadcast()
@@ -467,9 +474,7 @@
 // if successful. If the message is already present,
 // it still returns true.
 func (mm *messageManager) Add(mr *MessageRow) bool {
-	mm.mu.Lock()
-	defer mm.mu.Unlock()
-	if len(mm.receivers) == 0 {
+	if mm.receiverCount.Get() == 0 {
 		return false
 	}
 	// If cache is empty, we have to broadcast that we're not empty
@@ -479,7 +484,7 @@ func (mm *messageManager) Add(mr *MessageRow) bool {
 	}
 	if !mm.cache.Add(mr) {
 		// Cache is full. Enter "messagesPending" mode.
-		mm.messagesPending = true
+		mm.messagesPending.Set(true)
 		return false
 	}
 	return true
Expand Down Expand Up @@ -510,7 +515,7 @@ func (mm *messageManager) runSend() {

// If cache became empty, there are messages pending, and there are subscribed
// receivers, we have to trigger the poller to fetch more.
if mm.cache.IsEmpty() && mm.messagesPending && len(mm.receivers) != 0 {
if mm.cache.IsEmpty() && mm.messagesPending.Get() && mm.receiverCount.Get() != 0 {
// Do this as a separate goroutine. Otherwise, this could cause
// the following deadlock:
// 1. runSend obtains a lock
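
The (truncated) comment above describes the hazard this PR is named for: triggering the poller synchronously while holding a lock the poller itself needs. A generic sketch of the failure mode and the go workaround; the names are hypothetical, and the channel stands in for the timer's trigger mechanism:

package sketch

import "sync"

type manager struct {
	mu      sync.Mutex
	trigger chan struct{} // unbuffered: a send blocks until the poller is ready
}

// runPoller services triggers; each poll needs mu.
func (m *manager) runPoller() {
	for range m.trigger {
		m.mu.Lock()
		// ... fetch more messages into the cache ...
		m.mu.Unlock()
	}
}

// runSend holds mu when it notices more messages are needed. Sending on
// m.trigger synchronously here can deadlock: the poller may already be
// blocked waiting on m.mu, so nothing ever receives the trigger. Handing
// the send to a new goroutine lets runSend release mu first.
func (m *manager) runSend() {
	m.mu.Lock()
	defer m.mu.Unlock()
	// ... cache drained, messages pending, receivers subscribed ...
	go func() { m.trigger <- struct{}{} }()
}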
@@ -750,12 +755,12 @@ func (mm *messageManager) processRowEvent(fields []*querypb.Field, rowEvent *bin
 
 func (mm *messageManager) runPoller() {
 	// Fast-path. Skip all the work.
-	if mm.receiverCount() == 0 {
+	if mm.receiverCount.Get() == 0 {
 		return
 	}
 
-	mm.streamMu.Lock()
-	defer mm.streamMu.Unlock()
+	mm.getExclusiveLock()
+	defer mm.releaseExclusiveLock()
 
 	ctx, cancel := context.WithTimeout(tabletenv.LocalContext(), mm.pollerTicks.Interval())
 	defer func() {
@@ -768,20 +773,18 @@
 		"time_next": sqltypes.Int64BindVariable(time.Now().UnixNano()),
 		"max":       sqltypes.Int64BindVariable(int64(size)),
 	}
 
 	qr, err := mm.readPending(ctx, bindVars)
 	if err != nil {
 		return
 	}
 
-	// Obtain mu lock to verify and preserve that len(receivers) != 0.
-	mm.mu.Lock()
-	defer mm.mu.Unlock()
-	mm.messagesPending = false
+	mm.messagesPending.Set(false)
 	if len(qr.Rows) >= size {
 		// There are probably more messages to be sent.
-		mm.messagesPending = true
+		mm.messagesPending.Set(true)
 	}
-	if len(mm.receivers) == 0 {
+	if mm.receiverCount.Get() == 0 {
 		// Almost never reachable because we just checked this.
 		return
 	}
@@ -798,7 +801,7 @@
 			continue
 		}
 		if !mm.cache.Add(mr) {
-			mm.messagesPending = true
+			mm.messagesPending.Set(true)
 			return
 		}
 	}
Expand Down Expand Up @@ -880,7 +883,7 @@ func (mm *messageManager) GeneratePurgeQuery(timeCutoff int64) (string, map[stri
}
}

// BuildMessageRow builds a MessageRow for a db row.
// BuildMessageRow builds a MessageRow from a db row.
func BuildMessageRow(row []sqltypes.Value) (*MessageRow, error) {
mr := &MessageRow{Row: row[4:]}
if !row[0].IsNull() {
@@ -914,12 +917,6 @@ func BuildMessageRow(row []sqltypes.Value) (*MessageRow, error) {
 	return mr, nil
 }
 
-func (mm *messageManager) receiverCount() int {
-	mm.mu.Lock()
-	defer mm.mu.Unlock()
-	return len(mm.receivers)
-}
-
 func (mm *messageManager) readPending(ctx context.Context, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
 	query, err := mm.readByPriorityAndTimeNext.GenerateQuery(bindVars, nil)
 	if err != nil {
@@ -949,3 +946,16 @@ func (mm *messageManager) readPending(ctx context.Context, bindVars map[string]*
 	}
 	return qr, err
 }
+
+// getExclusiveLock grants the caller exclusive access to the message
+// manager's state. Use it in any function that needs both locks, so
+// that the locking order stays consistent.
+func (mm *messageManager) getExclusiveLock() {
+	mm.mu.Lock()
+	mm.streamMu.Lock()
+}
+
+func (mm *messageManager) releaseExclusiveLock() {
+	mm.streamMu.Unlock()
+	mm.mu.Unlock()
+}
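
The new helpers pin down a single acquisition order, mu before streamMu, for every code path that needs exclusive access. That rules out the classic lock-order inversion deadlock; a minimal illustration with plain mutexes (hypothetical names, not the messager API):

package sketch

import "sync"

type twoLocks struct {
	mu       sync.Mutex
	streamMu sync.Mutex
}

// exclusive takes both locks in one fixed order, mu then streamMu,
// exactly as getExclusiveLock does above.
func (t *twoLocks) exclusive(f func()) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.streamMu.Lock()
	defer t.streamMu.Unlock()
	f()
}

// The bug this convention prevents: goroutine A locks mu and waits for
// streamMu while goroutine B locks streamMu and waits for mu. Each holds
// the lock the other needs, and both block forever.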
go/vt/vttablet/tabletserver/messager/message_manager_test.go
@@ -148,12 +148,12 @@ func TestReceiverCancel(t *testing.T) {
 	for i := 0; i < 10; i++ {
 		runtime.Gosched()
 		time.Sleep(10 * time.Millisecond)
-		if mm.receiverCount() != 0 {
+		if mm.receiverCount.Get() != 0 {
 			continue
 		}
 		return
 	}
-	t.Errorf("receivers were not cleared: %d", len(mm.receivers))
+	t.Errorf("receivers were not cleared: %d", mm.receiverCount.Get())
 }
 
 func TestMessageManagerState(t *testing.T) {
@@ -281,7 +281,7 @@ func TestMessageManagerSend(t *testing.T) {
 		runtime.Gosched()
 		time.Sleep(10 * time.Millisecond)
 		mm.mu.Lock()
-		if len(mm.receivers) != 1 {
+		if mm.receiverCount.Get() != 1 {
 			mm.mu.Unlock()
 			continue
 		}