Skip to content

Commit

Permalink
*: add a reference count for StmtCtx
Browse files Browse the repository at this point in the history
  • Loading branch information
XuHuaiyu committed Nov 24, 2022
1 parent f50113f commit 68e46cb
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 6 deletions.
1 change: 1 addition & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1585,6 +1585,7 @@ func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecu
Info: sql,
CurTxnStartTS: curTxnStartTS,
StmtCtx: s.sessionVars.StmtCtx,
RefCountOfStmtCtx: &s.sessionVars.RefCountOfStmtCtx,
MemTracker: s.sessionVars.MemTracker,
DiskTracker: s.sessionVars.DiskTracker,
StatsInfo: plannercore.GetStatsInfo,
Expand Down
36 changes: 36 additions & 0 deletions sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,42 @@ func (warn *SQLWarn) UnmarshalJSON(data []byte) error {
return nil
}

type ReferenceCount int32

const (
ReferenceCountIsFrozen int32 = -1
ReferenceCountNoReference int32 = 0
)

// TryIncrease tries to increase the reference count.
// There is a small chance that TryIncrease returns true while TryFreeze and
// UnFreeze are invoked successfully during the execution of TryIncrease.
func (rf *ReferenceCount) TryIncrease() bool {
refCnt := atomic.LoadInt32((*int32)(rf))
for ; refCnt != ReferenceCountIsFrozen && !atomic.CompareAndSwapInt32((*int32)(rf), refCnt, refCnt+1); refCnt = atomic.LoadInt32((*int32)(rf)) {
}
if refCnt == ReferenceCountIsFrozen {
return false
}
return true
}

// Decrease decreases the reference count.
func (rf *ReferenceCount) Decrease() {
for refCnt := atomic.LoadInt32((*int32)(rf)); !atomic.CompareAndSwapInt32((*int32)(rf), refCnt, refCnt-1); refCnt = atomic.LoadInt32((*int32)(rf)) {
}
}

// TryFreeze tries to freeze the StmtCtx to frozen before resetting the old StmtCtx.
func (rf *ReferenceCount) TryFreeze() bool {
return atomic.LoadInt32((*int32)(rf)) == ReferenceCountNoReference && atomic.CompareAndSwapInt32((*int32)(rf), ReferenceCountNoReference, ReferenceCountIsFrozen)
}

// UnFreeze unfreeze the frozen StmtCtx thus the other session can access this StmtCtx.
func (rf *ReferenceCount) UnFreeze() {
atomic.StoreInt32((*int32)(rf), ReferenceCountNoReference)
}

// StatementContext contains variables for a statement.
// It should be reset before executing a statement.
type StatementContext struct {
Expand Down
12 changes: 11 additions & 1 deletion sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,11 @@ type SessionVars struct {
// StmtCtx holds variables for current executing statement.
StmtCtx *stmtctx.StatementContext

// RefCountOfStmtCtx indicates the reference count of StmtCtx. When the
// StmtCtx is accessed by other sessions, e.g. oom-alarm-handler/expensive-query-handler, add one first.
// Note: this variable should be accessed and updated by atomic operations.
RefCountOfStmtCtx stmtctx.ReferenceCount

// AllowAggPushDown can be set to false to forbid aggregation push down.
AllowAggPushDown bool

Expand Down Expand Up @@ -1383,7 +1388,12 @@ func (s *SessionVars) InitStatementContext() *stmtctx.StatementContext {
if sc == s.StmtCtx {
sc = &s.cachedStmtCtx[1]
}
*sc = stmtctx.StatementContext{}
if s.RefCountOfStmtCtx.TryFreeze() {
*sc = stmtctx.StatementContext{}
s.RefCountOfStmtCtx.UnFreeze()
} else {
sc = &stmtctx.StatementContext{}
}
return sc
}

Expand Down
6 changes: 5 additions & 1 deletion util/expensivequery/expensivequery.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,9 @@ func (eqh *Handle) LogOnQueryExceedMemQuota(connID uint64) {

// logExpensiveQuery logs the queries which exceed the time threshold or memory threshold.
func logExpensiveQuery(costTime time.Duration, info *util.ProcessInfo, msg string) {
logutil.BgLogger().Warn(msg, util.GenLogFields(costTime, info, true)...)
fields := util.GenLogFields(costTime, info, true)
if fields == nil {
return
}
logutil.BgLogger().Warn(msg, fields...)
}
9 changes: 5 additions & 4 deletions util/memoryusagealarm/memoryusagealarm.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,15 +266,15 @@ func (record *memoryUsageAlarm) printTop10SqlInfo(pinfo []*util.ProcessInfo, f *
func (record *memoryUsageAlarm) getTop10SqlInfo(cmp func(i, j *util.ProcessInfo) bool, pinfo []*util.ProcessInfo) strings.Builder {
slices.SortFunc(pinfo, cmp)
list := pinfo
if len(list) > 10 {
list = list[:10]
}
var buf strings.Builder
oomAction := variable.OOMAction.Load()
serverMemoryLimit := memory.ServerMemoryLimit.Load()
for i, info := range list {
for i, info, totalCnt := 0, list[0], 10; i < len(list) && totalCnt > 0; i++ {
buf.WriteString(fmt.Sprintf("SQL %v: \n", i))
fields := util.GenLogFields(record.lastCheckTime.Sub(info.Time), info, false)
if fields == nil {
continue
}
fields = append(fields, zap.String("tidb_mem_oom_action", oomAction))
fields = append(fields, zap.Uint64("tidb_server_memory_limit", serverMemoryLimit))
fields = append(fields, zap.Int64("tidb_mem_quota_query", info.OOMAlarmVariablesInfo.SessionMemQuotaQuery))
Expand All @@ -294,6 +294,7 @@ func (record *memoryUsageAlarm) getTop10SqlInfo(cmp func(i, j *util.ProcessInfo)
}
buf.WriteString("\n")
}
totalCnt--
}
buf.WriteString("\n")
return buf
Expand Down
1 change: 1 addition & 0 deletions util/processinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type ProcessInfo struct {
Time time.Time
Plan interface{}
StmtCtx *stmtctx.StatementContext
RefCountOfStmtCtx *stmtctx.ReferenceCount
MemTracker *memory.Tracker
DiskTracker *disk.Tracker
StatsInfo func(interface{}) map[string]uint64
Expand Down
6 changes: 6 additions & 0 deletions util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"net/http"
"strconv"
"strings"
"sync/atomic"
"time"

"github.com/pingcap/errors"
Expand Down Expand Up @@ -106,6 +107,11 @@ func Str2Int64Map(str string) map[int64]struct{} {

// GenLogFields generate log fields.
func GenLogFields(costTime time.Duration, info *ProcessInfo, needTruncateSQL bool) []zap.Field {
if !info.RefCountOfStmtCtx.TryIncrease() {
return nil
}
defer info.RefCountOfStmtCtx.Decrease()

logFields := make([]zap.Field, 0, 20)
logFields = append(logFields, zap.String("cost_time", strconv.FormatFloat(costTime.Seconds(), 'f', -1, 64)+"s"))
execDetail := info.StmtCtx.GetExecDetails()
Expand Down

0 comments on commit 68e46cb

Please sign in to comment.