diff --git a/cdc/sink/common/flow_control.go b/cdc/sink/common/flow_control.go index 8f650a4f24f..47ad19fc6b7 100644 --- a/cdc/sink/common/flow_control.go +++ b/cdc/sink/common/flow_control.go @@ -60,25 +60,28 @@ func (c *TableMemoryQuota) ConsumeWithBlocking(nBytes uint64, blockCallBack func return cerrors.ErrFlowControllerEventLargerThanQuota.GenWithStackByArgs(nBytes, c.Quota) } + c.mu.Lock() + if c.Consumed+nBytes >= c.Quota { + c.mu.Unlock() + err := blockCallBack() + if err != nil { + return errors.Trace(err) + } + } else { + c.mu.Unlock() + } + c.mu.Lock() defer c.mu.Unlock() - calledBack := false for { if atomic.LoadUint32(&c.IsAborted) == 1 { return cerrors.ErrFlowControllerAborted.GenWithStackByArgs() } + if c.Consumed+nBytes < c.Quota { break } - - if !calledBack { - calledBack = true - err := blockCallBack() - if err != nil { - return errors.Trace(err) - } - } c.cond.Wait() } diff --git a/cdc/sink/common/flow_control_test.go b/cdc/sink/common/flow_control_test.go index a31fbf3432a..2714dde80b5 100644 --- a/cdc/sink/common/flow_control_test.go +++ b/cdc/sink/common/flow_control_test.go @@ -417,6 +417,48 @@ func (s *flowControlSuite) TestFlowControlCallBack(c *check.C) { c.Assert(atomic.LoadUint64(&consumedBytes), check.Equals, uint64(0)) } +func (s *flowControlSuite) TestFlowControlCallBackNotBlockingRelease(c *check.C) { + defer testleak.AfterTest(c)() + + var wg sync.WaitGroup + controller := NewTableFlowController(512) + wg.Add(1) + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + go func() { + defer wg.Done() + err := controller.Consume(1, 511, func() error { + c.Fatalf("unreachable") + return nil + }) + c.Assert(err, check.IsNil) + + var isBlocked int32 + wg.Add(1) + go func() { + defer wg.Done() + <-time.After(time.Second * 1) + // makes sure that this test case is valid + c.Assert(atomic.LoadInt32(&isBlocked), check.Equals, int32(1)) + controller.Release(1) + cancel() + }() + + err = controller.Consume(2, 511, func() error { + atomic.StoreInt32(&isBlocked, 1) + <-ctx.Done() + atomic.StoreInt32(&isBlocked, 0) + return ctx.Err() + }) + + c.Assert(err, check.ErrorMatches, ".*context canceled.*") + }() + + wg.Wait() +} + func (s *flowControlSuite) TestFlowControlCallBackError(c *check.C) { defer testleak.AfterTest(c)()