Skip to content

Commit

Permalink
Merge pull request #7699 from dolthub/nicktobey/pertablelocking
Browse files Browse the repository at this point in the history
Add per-table locking for AutoIncrementTracker
  • Loading branch information
nicktobey authored Apr 9, 2024
2 parents 8501491 + fef14f3 commit 02b3213
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 36 deletions.
75 changes: 39 additions & 36 deletions go/libraries/doltcore/sqle/dsess/autoincrement_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess/mutexmap"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
"github.com/dolthub/dolt/go/store/prolly/tree"
"github.com/dolthub/dolt/go/store/types"
)

type AutoIncrementTracker struct {
dbName string
sequences map[string]uint64
mu *sync.Mutex
sequences *sync.Map // map[string]uint64
mm *mutexmap.MutexMap
}

var _ globalstate.AutoIncrementTracker = AutoIncrementTracker{}
Expand All @@ -48,8 +49,8 @@ var _ globalstate.AutoIncrementTracker = AutoIncrementTracker{}
func NewAutoIncrementTracker(ctx context.Context, dbName string, roots ...doltdb.Rootish) (AutoIncrementTracker, error) {
ait := AutoIncrementTracker{
dbName: dbName,
sequences: make(map[string]uint64),
mu: &sync.Mutex{},
sequences: &sync.Map{},
mm: mutexmap.NewMutexMap(),
}

for _, root := range roots {
Expand All @@ -71,8 +72,9 @@ func NewAutoIncrementTracker(ctx context.Context, dbName string, roots ...doltdb
return true, err
}

if seq > ait.sequences[tableName] {
ait.sequences[tableName] = seq
oldValue, loaded := ait.sequences.LoadOrStore(tableName, seq)
if loaded && seq > oldValue.(uint64) {
ait.sequences.Store(tableName, seq)
}

return false, nil
Expand All @@ -82,41 +84,46 @@ func NewAutoIncrementTracker(ctx context.Context, dbName string, roots ...doltdb
return AutoIncrementTracker{}, err
}
}

return ait, nil
}

func loadAutoIncValue(sequences *sync.Map, tableName string) uint64 {
tableName = strings.ToLower(tableName)
current, hasCurrent := sequences.Load(tableName)
if !hasCurrent {
return 0
}
return current.(uint64)
}

// Current returns the next value to be generated in the auto increment sequence for the table named
func (a AutoIncrementTracker) Current(tableName string) uint64 {
a.mu.Lock()
defer a.mu.Unlock()
return a.sequences[strings.ToLower(tableName)]
return loadAutoIncValue(a.sequences, tableName)
}

// Next returns the next auto increment value for the table named using the provided value from an insert (which may
// be null or 0, in which case it will be generated from the sequence).
func (a AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, error) {
a.mu.Lock()
defer a.mu.Unlock()

tbl = strings.ToLower(tbl)

given, err := CoerceAutoIncrementValue(insertVal)
if err != nil {
return 0, err
}

curr := a.sequences[tbl]
release := a.mm.Lock(tbl)
defer release()

curr := loadAutoIncValue(a.sequences, tbl)

if given == 0 {
// |given| is 0 or NULL
a.sequences[tbl]++
a.sequences.Store(tbl, curr+1)
return curr, nil
}

if given >= curr {
a.sequences[tbl] = given
a.sequences[tbl]++
a.sequences.Store(tbl, given+1)
return given, nil
}

Expand Down Expand Up @@ -152,14 +159,14 @@ func CoerceAutoIncrementValue(val interface{}) (uint64, error) {
// table. Otherwise, the update is silently disregarded. So far this matches the MySQL behavior, but Dolt uses the
// maximum value for this table across all branches.
func (a AutoIncrementTracker) Set(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) {
a.mu.Lock()
defer a.mu.Unlock()

tableName = strings.ToLower(tableName)

existing := a.sequences[tableName]
release := a.mm.Lock(tableName)
defer release()

existing := loadAutoIncValue(a.sequences, tableName)
if newAutoIncVal > existing {
a.sequences[strings.ToLower(tableName)] = newAutoIncVal
a.sequences.Store(tableName, newAutoIncVal)
return table.SetAutoIncrementValue(ctx, newAutoIncVal)
} else {
// If the value is not greater than the current tracker, we have more work to do
Expand Down Expand Up @@ -310,7 +317,7 @@ func (a AutoIncrementTracker) deepSet(ctx *sql.Context, tableName string, table
}
}

a.sequences[tableName] = maxAutoInc
a.sequences.Store(tableName, maxAutoInc)
return table, nil
}

Expand Down Expand Up @@ -351,27 +358,21 @@ func getMaxIndexValue(ctx context.Context, indexData durable.Index) (uint64, err

// AddNewTable initializes a new table with an auto increment column to the tracker, as necessary
func (a AutoIncrementTracker) AddNewTable(tableName string) {
a.mu.Lock()
defer a.mu.Unlock()

tableName = strings.ToLower(tableName)
// only initialize the sequence for this table if no other branch has such a table
if _, ok := a.sequences[tableName]; !ok {
a.sequences[tableName] = uint64(1)
}
a.sequences.LoadOrStore(tableName, uint64(1))
}

// DropTable drops the table with the name given.
// To establish the new auto increment value, callers must also pass all other working sets in scope that may include
// a table with the same name, omitting the working set that just deleted the table named.
func (a AutoIncrementTracker) DropTable(ctx *sql.Context, tableName string, wses ...*doltdb.WorkingSet) error {
a.mu.Lock()
defer a.mu.Unlock()

tableName = strings.ToLower(tableName)

// reset sequence to the minimum value
a.sequences[tableName] = 1
release := a.mm.Lock(tableName)
defer release()

newHighestValue := uint64(1)

// Get the new highest value from all tables in the working sets given
for _, ws := range wses {
Expand All @@ -395,11 +396,13 @@ func (a AutoIncrementTracker) DropTable(ctx *sql.Context, tableName string, wses
return err
}

if seq > a.sequences[tableName] {
a.sequences[tableName] = seq
if seq > newHighestValue {
newHighestValue = seq
}
}
}

a.sequences.Store(tableName, newHighestValue)

return nil
}
65 changes: 65 additions & 0 deletions go/libraries/doltcore/sqle/dsess/mutexmap/mutexmap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package mutexmap

import (
"sync"
)

// MutexMap holds a dynamic number of mutexes identified by keys. When a mutex is no longer needed, it's removed from
// the map.
type MutexMap struct {
mu sync.Mutex // Access to the map itself must be synchronized.
keyedMutexes map[interface{}]*mapMutex
}

type mapMutex struct {
key interface{}
mu sync.Mutex
parent *MutexMap
refcount int
}

func NewMutexMap() *MutexMap {
return &MutexMap{keyedMutexes: make(map[interface{}]*mapMutex)}
}

func (mm *MutexMap) Lock(key interface{}) func() {
mm.mu.Lock()
defer mm.mu.Unlock()

keyedMutex, hasKey := mm.keyedMutexes[key]
if !hasKey {
keyedMutex = &mapMutex{parent: mm, key: key}
mm.keyedMutexes[key] = keyedMutex
}
keyedMutex.refcount++

keyedMutex.mu.Lock()

return func() { keyedMutex.Unlock() }
}

func (mm *mapMutex) Unlock() {
mutexMap := mm.parent
mutexMap.mu.Lock()
defer mutexMap.mu.Unlock()

mm.refcount--
if mm.refcount < 1 {
delete(mutexMap.keyedMutexes, mm.key)
}
mm.mu.Unlock()
}

0 comments on commit 02b3213

Please sign in to comment.