rwmutexmap.go (new file, 136 additions, 0 deletions)
@@ -0,0 +1,136 @@
package locker

import (
	"sync"
	"sync/atomic"
)

// RWMutexMap is a more convenient map[T]sync.RWMutex. It automatically makes
// and deletes mutexes as needed. Unlocked mutexes consume no memory.
//
// The zero value is a valid RWMutexMap.
type RWMutexMap[T comparable] struct {
	mu    sync.Mutex
	locks map[T]*rwlockCtr
}

// rwlockCtr is used by RWMutexMap to represent a lock with a given key.
type rwlockCtr struct {
	sync.RWMutex
	waiters atomic.Int32 // Number of callers waiting to acquire the lock
	readers atomic.Int32 // Number of readers currently holding the lock
}

var rwlockCtrPool = sync.Pool{New: func() any { return new(rwlockCtr) }}
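
// Usage sketch (illustrative addition, not part of the original change):
//
//	var m RWMutexMap[string] // the zero value is ready to use
//	m.RLock("config")        // any number of readers may hold "config" at once
//	m.RUnlock("config")
//	m.Lock("config")         // a writer gets the key exclusively
//	m.Unlock("config")       // the final unlock frees the entry's memory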

func (l *RWMutexMap[T]) get(key T) *rwlockCtr {
	if l.locks == nil {
		l.locks = make(map[T]*rwlockCtr)
	}

	nameLock, exists := l.locks[key]
	if !exists {
		nameLock = rwlockCtrPool.Get().(*rwlockCtr)
		l.locks[key] = nameLock
	}
	return nameLock
}

// Lock locks the RWMutex identified by key for writing.
func (l *RWMutexMap[T]) Lock(key T) {
	l.mu.Lock()
	nameLock := l.get(key)

	// Increment the nameLock waiters while inside the main mutex.
	// This makes sure that the lock isn't deleted if `Lock` and `Unlock` are called concurrently.
	nameLock.waiters.Add(1)
	l.mu.Unlock()

	// Lock the nameLock outside the main mutex so we don't block other operations.
	// Once locked, we can decrement the number of waiters for this lock.
	nameLock.Lock()
	nameLock.waiters.Add(-1)
}
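
// Illustrative note (added, not in the original change): if goroutine A holds
// Lock("k") while goroutine B blocks inside Lock("k"), B has already
// incremented waiters under l.mu, so A's Unlock("k") observes waiters > 0 and
// keeps the map entry alive rather than recycling it out from under B.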

// RLock locks the RWMutex identified by key for reading.
func (l *RWMutexMap[T]) RLock(key T) {
	l.mu.Lock()
	nameLock := l.get(key)

	nameLock.waiters.Add(1)
	l.mu.Unlock()

	nameLock.RLock()
	// Increment the number of readers before decrementing the waiters
	// so concurrent calls to RUnlock will not see a glitch where both
	// waiters and readers are 0.
	nameLock.readers.Add(1)
	nameLock.waiters.Add(-1)
}
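
// Illustrative note (added, not in the original change): with the opposite
// ordering, another goroutine's RUnlock could run between waiters.Add(-1) and
// readers.Add(1), observe waiters == 0 && readers == 0, and free the entry
// while this reader still holds it.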

// Unlock unlocks the RWMutex identified by key.
//
// It is a run-time error if the lock is not locked for writing on entry to Unlock.
func (l *RWMutexMap[T]) Unlock(key T) {
	l.mu.Lock()
	defer l.mu.Unlock()
	nameLock := l.get(key)
	// We don't have to do anything special to handle the error case:
	// l.get(key) will return an unlocked mutex.

	if nameLock.waiters.Load() <= 0 && nameLock.readers.Load() <= 0 {
		delete(l.locks, key)
		defer rwlockCtrPool.Put(nameLock)
	}
	nameLock.Unlock()
}

// RUnlock unlocks the RWMutex identified by key for reading.
//
// It is a run-time error if the lock is not locked for reading on entry to RUnlock.
func (l *RWMutexMap[T]) RUnlock(key T) {
	l.mu.Lock()
	defer l.mu.Unlock()
	nameLock := l.get(key)
	nameLock.readers.Add(-1)

	if nameLock.waiters.Load() <= 0 && nameLock.readers.Load() <= 0 {
		delete(l.locks, key)
		defer rwlockCtrPool.Put(nameLock)
	}
	nameLock.RUnlock()
}
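
// Cleanup sketch (illustrative addition): once the last holder of a key
// unlocks it with no one waiting, the entry is deleted and its rwlockCtr is
// returned to the pool, so idle keys consume no memory:
//
//	var m RWMutexMap[string]
//	m.Lock("tmp")
//	m.Unlock("tmp") // len(m.locks) == 0 again after this call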

// Locker returns a [sync.Locker] interface that implements
// the [sync.Locker.Lock] and [sync.Locker.Unlock] methods
// by calling l.Lock(key) and l.Unlock(key).
func (l *RWMutexMap[T]) Locker(key T) sync.Locker {
	return nameRWLocker[T]{l: l, key: key}
}

// RLocker returns a [sync.Locker] interface that implements
// the [sync.Locker.Lock] and [sync.Locker.Unlock] methods
// by calling l.RLock(key) and l.RUnlock(key).
func (l *RWMutexMap[T]) RLocker(key T) sync.Locker {
	return nameRLocker[T]{l: l, key: key}
}

type nameRWLocker[T comparable] struct {
	l   *RWMutexMap[T]
	key T
}
type nameRLocker[T comparable] nameRWLocker[T]

func (n nameRWLocker[T]) Lock() {
	n.l.Lock(n.key)
}
func (n nameRWLocker[T]) Unlock() {
	n.l.Unlock(n.key)
}

func (n nameRLocker[T]) Lock() {
	n.l.RLock(n.key)
}
func (n nameRLocker[T]) Unlock() {
	n.l.RUnlock(n.key)
}
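
A minimal sketch of how Locker and RLocker plug into APIs that accept a sync.Locker, such as sync.Cond; the helper name and the "queue" key below are illustrative assumptions, not part of this change:

// exampleCondUsage is a hypothetical helper showing the per-key lockers in use.
func exampleCondUsage(m *RWMutexMap[string]) {
	cond := sync.NewCond(m.Locker("queue")) // any sync.Locker works here

	cond.L.Lock() // acquires m.Lock("queue") under the hood
	// ... wait on or signal a condition guarded by the "queue" key ...
	cond.L.Unlock() // releases via m.Unlock("queue")

	rl := m.RLocker("queue")
	rl.Lock()   // m.RLock("queue")
	rl.Unlock() // m.RUnlock("queue")
}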
rwmutexmap_test.go (new file, 189 additions, 0 deletions)
@@ -0,0 +1,189 @@
package locker

import (
	"math/rand"
	"strconv"
	"sync"
	"testing"
	"time"
)

func TestRWMutex_Lock(t *testing.T) {
	var l RWMutexMap[string]
	l.Lock("test")
	ctr := l.locks["test"]

	if w := ctr.waiters.Load(); w != 0 {
		t.Fatalf("expected waiters to be 0, got %d", w)
	}

	chDone := make(chan struct{})
	go func() {
		l.Lock("test")
		close(chDone)
	}()

	chWaiting := make(chan struct{})
	go func() {
		for range time.Tick(1 * time.Millisecond) {
			if ctr.waiters.Load() == 1 {
				close(chWaiting)
				break
			}
		}
	}()

	select {
	case <-chWaiting:
	case <-time.After(3 * time.Second):
		t.Fatal("timed out waiting for lock waiters to be incremented")
	}

	select {
	case <-chDone:
		t.Fatal("lock should not have returned while it was still held")
	default:
	}

	l.Unlock("test")

	select {
	case <-chDone:
	case <-time.After(3 * time.Second):
		t.Fatalf("lock should have completed")
	}

	if w := ctr.waiters.Load(); w != 0 {
		t.Fatalf("expected waiters to be 0, got %d", w)
	}
}

func TestRWMutex_Unlock(t *testing.T) {
	var l RWMutexMap[string]

	l.Lock("test")
	l.Unlock("test")

	chDone := make(chan struct{})
	go func() {
		l.Lock("test")
		close(chDone)
	}()

	select {
	case <-chDone:
	case <-time.After(3 * time.Second):
		t.Fatalf("lock should not be blocked")
	}
}

func TestRWMutex_RLock(t *testing.T) {
	var l RWMutexMap[string]
	rlocked := make(chan bool, 1)
	wlocked := make(chan bool, 1)
	n := 10

	go func() {
		for i := 0; i < n; i++ {
			l.RLock("test")
			l.RLock("test")
			rlocked <- true
			l.Lock("test")
			wlocked <- true
		}
	}()

	for i := 0; i < n; i++ {
		<-rlocked
		l.RUnlock("test")
		select {
		case <-wlocked:
			t.Fatal("RLock() didn't block Lock()")
		default:
		}
		l.RUnlock("test")
		<-wlocked
		select {
		case <-rlocked:
			t.Fatal("Lock() didn't block RLock()")
		default:
		}
		l.Unlock("test")
	}

	if len(l.locks) != 0 {
		t.Fatalf("expected no locks to be present in the map, got %d", len(l.locks))
	}
}

func TestRWMutex_Concurrency(t *testing.T) {
	var l RWMutexMap[string]

	var wg sync.WaitGroup
	for i := 0; i <= 10000; i++ {
		wg.Add(1)
		go func() {
			l.Lock("test")
			// If there is a concurrency issue, this will very likely panic here.
			l.Unlock("test")
			l.RLock("test")
			l.RUnlock("test")
			wg.Done()
		}()
	}

	chDone := make(chan struct{})
	go func() {
		wg.Wait()
		close(chDone)
	}()

	select {
	case <-chDone:
	case <-time.After(10 * time.Second):
		t.Fatal("timeout waiting for locks to complete")
	}

	// Since everything has unlocked, this entry should not exist anymore.
	if ctr, exists := l.locks["test"]; exists {
		t.Fatalf("lock should not exist: %v", ctr)
	}
}

func BenchmarkRWMutex(b *testing.B) {
	var l RWMutexMap[string]
	b.ReportAllocs()
	for i := 0; i < b.N; i++ {
		l.Lock("test")
		l.Lock(strconv.Itoa(i))
		l.Unlock(strconv.Itoa(i))
		l.Unlock("test")
	}
}

func BenchmarkRWMutex_Parallel(b *testing.B) {
	var l RWMutexMap[string]
	b.SetParallelism(128)
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			l.Lock("test")
			l.Unlock("test")
		}
	})
}

func BenchmarkRWMutex_MoreKeys(b *testing.B) {
	var l RWMutexMap[string]
	var keys []string
	for i := 0; i < 64; i++ {
		keys = append(keys, strconv.Itoa(i))
	}
	b.SetParallelism(128)
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			k := keys[rand.Intn(len(keys))]
			l.Lock(k)
			l.Unlock(k)
		}
	})
}