diff --git a/br/pkg/streamhelper/BUILD.bazel b/br/pkg/streamhelper/BUILD.bazel
index cb8442cb618dc..19926e96aa0df 100644
--- a/br/pkg/streamhelper/BUILD.bazel
+++ b/br/pkg/streamhelper/BUILD.bazel
@@ -68,7 +68,7 @@ go_test(
     ],
     flaky = True,
     race = "on",
-    shard_count = 19,
+    shard_count = 20,
     deps = [
         ":streamhelper",
         "//br/pkg/errors",
diff --git a/br/pkg/streamhelper/advancer.go b/br/pkg/streamhelper/advancer.go
index 6fca0f8686c16..df36b3773a15b 100644
--- a/br/pkg/streamhelper/advancer.go
+++ b/br/pkg/streamhelper/advancer.go
@@ -443,7 +443,7 @@ func (c *CheckpointAdvancer) stopSubscriber() {
     c.subscriber = nil
 }
 
-func (c *CheckpointAdvancer) spawnSubscriptionHandler(ctx context.Context) {
+func (c *CheckpointAdvancer) SpawnSubscriptionHandler(ctx context.Context) {
     c.subscriberMu.Lock()
     defer c.subscriberMu.Unlock()
     c.subscriber = NewSubscriber(c.env, c.env, WithMasterContext(ctx))
@@ -470,9 +470,12 @@ func (c *CheckpointAdvancer) spawnSubscriptionHandler(ctx context.Context) {
 }
 
 func (c *CheckpointAdvancer) subscribeTick(ctx context.Context) error {
+    c.subscriberMu.Lock()
+    defer c.subscriberMu.Unlock()
     if c.subscriber == nil {
         return nil
     }
+    failpoint.Inject("get_subscriber", nil)
     if err := c.subscriber.UpdateStoreTopology(ctx); err != nil {
         log.Warn("[log backup advancer] Error when updating store topology.", logutil.ShortError(err))
     }
diff --git a/br/pkg/streamhelper/advancer_daemon.go b/br/pkg/streamhelper/advancer_daemon.go
index 4e3b68eb3fbf5..5bac78fe83604 100644
--- a/br/pkg/streamhelper/advancer_daemon.go
+++ b/br/pkg/streamhelper/advancer_daemon.go
@@ -34,10 +34,10 @@ func (c *CheckpointAdvancer) OnStart(ctx context.Context) {
 
 // OnBecomeOwner implements daemon.Interface. If the tidb-server become owner, this function will be called.
 func (c *CheckpointAdvancer) OnBecomeOwner(ctx context.Context) {
     metrics.AdvancerOwner.Set(1.0)
-    c.spawnSubscriptionHandler(ctx)
+    c.SpawnSubscriptionHandler(ctx)
     go func() {
         <-ctx.Done()
-        c.onStop()
+        c.OnStop()
     }()
 }
@@ -46,7 +46,7 @@ func (c *CheckpointAdvancer) Name() string {
     return "LogBackup::Advancer"
 }
 
-func (c *CheckpointAdvancer) onStop() {
+func (c *CheckpointAdvancer) OnStop() {
     metrics.AdvancerOwner.Set(0.0)
     c.stopSubscriber()
 }
diff --git a/br/pkg/streamhelper/advancer_test.go b/br/pkg/streamhelper/advancer_test.go
index 05e0578e0721b..3d6fdcee79cba 100644
--- a/br/pkg/streamhelper/advancer_test.go
+++ b/br/pkg/streamhelper/advancer_test.go
@@ -361,3 +361,38 @@ func TestResolveLock(t *testing.T) {
     require.Len(t, r.FailureSubRanges, 0)
     require.Equal(t, r.Checkpoint, minCheckpoint, "%d %d", r.Checkpoint, minCheckpoint)
 }
+
+func TestOwnerDropped(t *testing.T) {
+    ctx := context.Background()
+    c := createFakeCluster(t, 4, false)
+    c.splitAndScatter("01", "02", "022", "023", "033", "04", "043")
+    installSubscribeSupport(c)
+    env := &testEnv{testCtx: t, fakeCluster: c}
+    fp := "github.com/pingcap/tidb/br/pkg/streamhelper/get_subscriber"
+    defer func() {
+        if t.Failed() {
+            fmt.Println(c)
+        }
+    }()
+
+    adv := streamhelper.NewCheckpointAdvancer(env)
+    adv.OnStart(ctx)
+    adv.SpawnSubscriptionHandler(ctx)
+    require.NoError(t, adv.OnTick(ctx))
+    failpoint.Enable(fp, "pause")
+    ch := make(chan struct{})
+    go func() {
+        defer close(ch)
+        require.NoError(t, adv.OnTick(ctx))
+    }()
+    adv.OnStop()
+    failpoint.Disable(fp)
+
+    cp := c.advanceCheckpoints()
+    c.flushAll()
+    <-ch
+    adv.WithCheckpoints(func(vsf *spans.ValueSortedFull) {
+        // Advancer will manually poll the checkpoint...
+        require.Equal(t, vsf.MinValue(), cp)
+    })
+}
diff --git a/br/pkg/streamhelper/flush_subscriber.go b/br/pkg/streamhelper/flush_subscriber.go
index 34148f11d2a2f..af34a1cba59d4 100644
--- a/br/pkg/streamhelper/flush_subscriber.go
+++ b/br/pkg/streamhelper/flush_subscriber.go
@@ -230,8 +230,10 @@ func (s *subscription) doConnect(ctx context.Context, dialer LogBackupService) e
         cancel()
         return errors.Annotate(err, "failed to subscribe events")
     }
+    lcx := logutil.ContextWithField(cx, zap.Uint64("store-id", s.storeID),
+        zap.String("category", "log backup flush subscriber"))
     s.cancel = cancel
-    s.background = spawnJoinable(func() { s.listenOver(cli) })
+    s.background = spawnJoinable(func() { s.listenOver(lcx, cli) })
     return nil
 }
 
@@ -244,15 +246,16 @@ func (s *subscription) close() {
     // because it is a ever-sharing channel.
 }
 
-func (s *subscription) listenOver(cli eventStream) {
+func (s *subscription) listenOver(ctx context.Context, cli eventStream) {
     storeID := s.storeID
-    log.Info("[log backup flush subscriber] Listen starting.", zap.Uint64("store", storeID))
+    logutil.CL(ctx).Info("Listen starting.", zap.Uint64("store", storeID))
     for {
         // Shall we use RecvMsg for better performance?
         // Note that the spans.Full requires the input slice be immutable.
         msg, err := cli.Recv()
         if err != nil {
-            log.Info("[log backup flush subscriber] Listen stopped.", zap.Uint64("store", storeID), logutil.ShortError(err))
+            logutil.CL(ctx).Info("Listen stopped.",
+                zap.Uint64("store", storeID), logutil.ShortError(err))
             if err == io.EOF || err == context.Canceled || status.Code(err) == codes.Canceled {
                 return
             }
@@ -263,13 +266,13 @@
             for _, m := range msg.Events {
                 start, err := decodeKey(m.StartKey)
                 if err != nil {
-                    log.Warn("start key not encoded, skipping",
+                    logutil.CL(ctx).Warn("start key not encoded, skipping",
                         logutil.Key("event", m.StartKey), logutil.ShortError(err))
                     continue
                 }
                 end, err := decodeKey(m.EndKey)
                 if err != nil {
-                    log.Warn("end key not encoded, skipping",
+                    logutil.CL(ctx).Warn("end key not encoded, skipping",
                         logutil.Key("event", m.EndKey), logutil.ShortError(err))
                     continue
                 }
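
Reviewer note on the locking change in subscribeTick: the tick path now takes subscriberMu before the nil check, the same mutex SpawnSubscriptionHandler takes when installing the subscriber, so dropping ownership can no longer tear the subscriber down while a tick is mid-flight. TestOwnerDropped reproduces exactly that interleaving by pausing a tick at the new get_subscriber failpoint, calling OnStop, then letting the tick resume. Below is a minimal, self-contained sketch of the pattern in isolation; the names (ticker, handler, Tick, Stop) are illustrative, not the real CheckpointAdvancer API.

package main

import (
    "context"
    "fmt"
    "sync"
)

// ticker guards an optional handler with one mutex on both the poll
// path and the teardown path, mirroring how subscribeTick and the
// subscriber teardown now share subscriberMu.
type ticker struct {
    mu      sync.Mutex
    handler func(context.Context) error
}

// Tick polls the handler if one is installed. Holding mu across the
// nil check and the call means a concurrent Stop cannot clear the
// handler between the two.
func (t *ticker) Tick(ctx context.Context) error {
    t.mu.Lock()
    defer t.mu.Unlock()
    if t.handler == nil {
        return nil // not the owner any more: nothing to poll
    }
    return t.handler(ctx)
}

// Stop clears the handler under the same lock, so it waits out any
// in-flight Tick instead of racing with it.
func (t *ticker) Stop() {
    t.mu.Lock()
    defer t.mu.Unlock()
    t.handler = nil
}

func main() {
    tk := &ticker{handler: func(context.Context) error {
        fmt.Println("polling subscriber")
        return nil
    }}
    ctx := context.Background()
    _ = tk.Tick(ctx) // polls
    tk.Stop()
    _ = tk.Tick(ctx) // no-op: handler was cleared under the lock
}

The trade-off of this pattern is that Stop now contends with Tick on the same mutex, so a long-running tick delays shutdown; the failpoint-paused tick in TestOwnerDropped is what checks that the advancer still stops cleanly in that case.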