Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 56 additions & 4 deletions pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ type Poolable[T any] interface {
IncrementUsage()
// ResetUsage resets the usage count to 0
ResetUsage()
// SetShardIndex sets the shard index for this object
SetShardIndex(index int)
// GetShardIndex returns the shard index for this object
GetShardIndex() int
}

// Fields provides intrusive fields and logic for poolable objects.
Expand All @@ -121,6 +125,7 @@ type Poolable[T any] interface {
type Fields[T any] struct {
usageCount atomic.Int64
next atomic.Pointer[T]
shardIndex int // Track which shard this object belongs to
}

// GetNext implements interface function
Expand Down Expand Up @@ -148,6 +153,16 @@ func (p *Fields[T]) ResetUsage() {
p.usageCount.Store(0)
}

// SetShardIndex sets the shard index for this object
func (p *Fields[T]) SetShardIndex(index int) {
p.shardIndex = index
}

// GetShardIndex returns the shard index for this object
func (p *Fields[T]) GetShardIndex() int {
return p.shardIndex
}

// Config holds configuration options for the pool.
type Config[T any, P Poolable[T]] struct {
// Cleanup defines the cleanup policy for the pool
Expand Down Expand Up @@ -301,7 +316,8 @@ func initShards[T any, P Poolable[T]](p *ShardedPool[T, P]) {
// Get returns an object from the pool or creates a new one.
// Returns nil if MaxPoolSize is set, reached, and no reusable objects are available.
func (p *ShardedPool[T, P]) Get() P {
shard := p.Shards[runtimeProcPin()]
shardID := runtimeProcPin()
shard := p.Shards[shardID]
runtimeProcUnpin()

// Fast path: check single object first
Expand Down Expand Up @@ -330,6 +346,7 @@ func (p *ShardedPool[T, P]) Get() P {
// Direct allocation path
if !p.cfg.Growth.Enable {
obj := P(p.cfg.Allocator())
obj.SetShardIndex(shardID)
obj.IncrementUsage()
p.CurrentPoolLength.Add(1)
return obj
Expand All @@ -340,6 +357,7 @@ func (p *ShardedPool[T, P]) Get() P {
}

obj := P(p.cfg.Allocator())
obj.SetShardIndex(shardID)
obj.IncrementUsage()
p.CurrentPoolLength.Add(1)
return obj
Expand All @@ -362,6 +380,7 @@ func (p *ShardedPool[T, P]) GetBlock() P {
// Try to allocate new one if allowed
if !p.cfg.Growth.Enable || p.CurrentPoolLength.Load() < p.cfg.Growth.MaxPoolSize {
obj := P(p.cfg.Allocator())
obj.SetShardIndex(shardID)
obj.IncrementUsage()
p.CurrentPoolLength.Add(1)
return obj
Expand All @@ -388,6 +407,16 @@ func (p *ShardedPool[T, P]) PutBlock(obj P) {
p.cfg.Cleaner(obj)
shard := p.getMostBlockedShard()

// Find the shard index for the most blocked shard
var shardID int
for i, s := range p.Shards {
if s == shard {
shardID = i
break
}
}
obj.SetShardIndex(shardID)

for {
oldHead := P(shard.Head.Load())

Expand All @@ -403,9 +432,16 @@ func (p *ShardedPool[T, P]) PutBlock(obj P) {
// This implementation creates memory, don't use it in the hot path,
// "make" always makes things much slower.
func (p *ShardedPool[T, P]) GetN(n int) []P {
shardID := runtimeProcPin()
runtimeProcUnpin()

objs := make([]P, n) // WARNING
for i := range n {
objs[i] = p.Get()
obj := p.Get()
if obj != nil {
obj.SetShardIndex(shardID)
}
objs[i] = obj
}

return objs
Expand All @@ -415,8 +451,8 @@ func (p *ShardedPool[T, P]) GetN(n int) []P {
func (p *ShardedPool[T, P]) Put(obj P) {
p.cfg.Cleaner(obj)

shard := p.Shards[runtimeProcPin()]
runtimeProcUnpin()
shardID := obj.GetShardIndex()
shard := p.Shards[shardID]

// Fast path: try single object first
if shard.Single.CompareAndSwap(nil, obj) {
Expand Down Expand Up @@ -574,6 +610,22 @@ func (p *ShardedPool[T, P]) filterUsableObjects(head P) (keptHead, keptTail P, e
}

func (p *ShardedPool[T, P]) reinsertKeptObjects(shard *Shard[T, P], keptHead, keptTail P) {
// Find the shard index for this shard
var shardID int
for i, s := range p.Shards {
if s == shard {
shardID = i
break
}
}

// Set the correct shard index for all kept objects
current := keptHead
for current != nil {
current.SetShardIndex(shardID)
current = current.GetNext()
}

for {
currentHead := P(shard.Head.Load())
if currentHead != nil {
Expand Down
Loading