Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .env
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
export REDIS_PASSWORD=redispassword
export TIMEOUT=1s
export REDIS_PASSWORD="${REDIS_PASSWORD:-redispassword}"
export REDIS_PORT="${REDIS_PORT:-6379}"
export REDIS_ADDR="${REDIS_ADDR:-localhost:${REDIS_PORT}}"
export TIMEOUT="${TIMEOUT:-1s}"
2 changes: 1 addition & 1 deletion docker-compose/docker-compose-redis.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ services:
container_name: "pulse-redis"
command: redis-server --save "" --loglevel warning --requirepass ${REDIS_PASSWORD}
ports:
- "6379:6379"
- "${REDIS_PORT:-6379}:6379"
99 changes: 81 additions & 18 deletions rmap/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package rmap
import (
"context"
"encoding/binary"
"errors"
"fmt"
"math/rand"
"regexp"
Expand All @@ -26,6 +27,8 @@ type (
msgch <-chan *redis.Message // channel to receive map updates
chans []chan EventKind // channels to send notifications
done chan struct{} // channel to signal shutdown
closectx context.Context // context canceled by Close
closer context.CancelFunc // cancels closectx
wait sync.WaitGroup // wait for read goroutine to exit
logger pulse.Logger // logger
sub *redis.PubSub // subscription to map updates
Expand Down Expand Up @@ -97,12 +100,21 @@ func Join(ctx context.Context, name string, rdb *redis.Client, opts ...MapOption
if ctx.Err() != nil {
return nil, ctx.Err()
}
if rdb == nil {
return nil, fmt.Errorf("pulse map: %s Redis client cannot be nil", name)
}
o := parseOptions(opts...)
if o.Logger == nil {
o.Logger = pulse.NoopLogger()
}
closectx, closer := context.WithCancel(context.Background())
sm := &Map{
Name: name,
chankey: fmt.Sprintf("map:%s:updates", name),
hashkey: fmt.Sprintf("map:%s:content", name),
done: make(chan struct{}),
closectx: closectx,
closer: closer,
logger: o.Logger.WithPrefix("map", name),
rdb: rdb,
content: make(map[string]string),
Expand Down Expand Up @@ -508,6 +520,18 @@ func (sm *Map) Reset(ctx context.Context) error {
// only if the "color" key has value "blue" and the "size" key has value "large",
// and return true if the map was cleared.
func (sm *Map) TestAndReset(ctx context.Context, keys, tests []string) (bool, error) {
if len(keys) != len(tests) {
return false, fmt.Errorf("pulse map: %s TestAndReset requires len(keys) == len(tests)", sm.Name)
}
for _, k := range keys {
if len(k) == 0 {
return false, fmt.Errorf("pulse map: %s key cannot be empty in %q", sm.Name, "testAndReset")
}
if strings.Contains(k, "=") {
return false, fmt.Errorf("pulse map: %s key %q cannot contain '=' in %q", sm.Name, k, "testAndReset")
}
}

args := make([]any, 1+len(keys)+len(tests))
args[0] = "*"
for i, k := range keys {
Expand All @@ -534,6 +558,10 @@ func (sm *Map) Close() {
sm.closing = true
sm.lock.Unlock()

if sm.closer != nil {
sm.closer()
}

// Signal run() to stop and wait for it to complete
close(sm.done)
sm.wait.Wait()
Expand Down Expand Up @@ -578,6 +606,7 @@ func (sm *Map) init(ctx context.Context) error {
sm.sub = sm.rdb.Subscribe(ctx, sm.chankey)
_, err := sm.sub.Receive(ctx) // Fail fast if we can't subscribe.
if err != nil {
_ = sm.sub.Close()
return fmt.Errorf("pulse map: %s failed to join: %w", sm.Name, err)
}
sm.msgch = sm.sub.Channel()
Expand All @@ -589,6 +618,8 @@ func (sm *Map) init(ctx context.Context) error {
// local copy with the same data.
cmd := sm.rdb.HGetAll(ctx, sm.hashkey)
if err := cmd.Err(); err != nil {
_ = sm.sub.Unsubscribe(ctx, sm.chankey)
_ = sm.sub.Close()
return fmt.Errorf("pulse map: %s failed to read initial content: %w", sm.Name, err)
}
sm.content = cmd.Val()
Expand All @@ -599,6 +630,7 @@ func (sm *Map) init(ctx context.Context) error {
// run updates the local copy of the replicated map whenever a remote update is
// received and sends notifications when needed.
func (sm *Map) run() {
defer sm.wait.Done()
for {
select {
case msg, ok := <-sm.msgch:
Expand Down Expand Up @@ -672,8 +704,6 @@ func (sm *Map) run() {
if waiter.key == notification.key && waiter.value == notification.value {
select {
case waiter.ch <- *notification:
case <-sm.done:
return
case <-waiter.ctx.Done():
// Waiter was cancelled or timed out
continue
Expand All @@ -691,13 +721,16 @@ func (sm *Map) run() {
for _, c := range sm.chans {
close(c)
}
if err := sm.sub.Unsubscribe(context.Background(), sm.chankey); err != nil {
sm.logger.Error(fmt.Errorf("failed to unsubscribe: %w", err))
}
if err := sm.sub.Close(); err != nil {
sm.logger.Error(fmt.Errorf("failed to close subscription: %w", err))
if sm.sub != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
if err := sm.sub.Unsubscribe(ctx, sm.chankey); err != nil {
sm.logger.Error(fmt.Errorf("failed to unsubscribe: %w", err))
}
cancel()
if err := sm.sub.Close(); err != nil {
sm.logger.Error(fmt.Errorf("failed to close subscription: %w", err))
}
}
sm.wait.Done()
return
}
}
Expand All @@ -719,7 +752,7 @@ func (sm *Map) runLuaScript(ctx context.Context, name string, script *redis.Scri
if strings.Contains(key, "=") {
return nil, fmt.Errorf("pulse map: %s key %q cannot contain '=' in %q", sm.Name, key, name)
}
res, err := script.EvalSha(ctx, sm.rdb, []string{sm.hashkey, sm.chankey}, args...).Result()
res, err := script.Run(ctx, sm.rdb, []string{sm.hashkey, sm.chankey}, args...).Result()
if err != nil && err != redis.Nil {
return nil, fmt.Errorf("pulse map: %s failed to run %q for key %s: %w", sm.Name, name, key, err)
}
Expand All @@ -733,23 +766,53 @@ func (sm *Map) reconnect() {
for {
count++
sm.logger.Info("reconnect", "attempt", count)
sm.lock.Lock()
if sm.closing {
sm.lock.Unlock()
if sm.closectx.Err() != nil {
return
}
sm.lock.RLock()
closing := sm.closing
sm.lock.RUnlock()
if closing {
return
}
sm.sub = sm.rdb.Subscribe(context.Background(), sm.chankey)
_, err := sm.sub.Receive(context.Background())

sub := sm.rdb.Subscribe(sm.closectx, sm.chankey)
_, err := sub.Receive(sm.closectx)
if err != nil {
sm.lock.Unlock()
_ = sub.Close()
if errors.Is(err, context.Canceled) {
return
}
sm.logger.Error(fmt.Errorf("failed to reconnect: %w", err), "attempt", count)
time.Sleep(time.Duration(rand.Float64()*5+1) * time.Second)
sleep := time.Duration(rand.Float64()*5+1) * time.Second
timer := time.NewTimer(sleep)
select {
case <-timer.C:
case <-sm.closectx.Done():
timer.Stop()
return
}
continue
}
sm.msgch = sm.sub.Channel()

msgch := sub.Channel()

sm.lock.Lock()
if sm.closing {
sm.lock.Unlock()
_ = sub.Close()
return
}
oldSub := sm.sub
sm.sub = sub
sm.msgch = msgch
sm.lock.Unlock()
if oldSub != nil {
_ = oldSub.Close()
}

sm.logger.Info("reconnected")
break
return
}
}

Expand Down
40 changes: 22 additions & 18 deletions rmap/map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,24 @@ import (
)

var (
redisPwd = "redispassword"
wf = time.Second
tck = time.Millisecond
redisPwd = "redispassword"
redisAddr = "localhost:6379"
wf = time.Second
tck = time.Millisecond
)

func init() {
if p := os.Getenv("REDIS_PASSWORD"); p != "" {
redisPwd = p
}
if a := os.Getenv("REDIS_ADDR"); a != "" {
redisAddr = a
}
}

func TestMapLocal(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Addr: redisAddr,
Password: redisPwd,
})
var buf Buffer
Expand Down Expand Up @@ -139,7 +143,7 @@ func TestMapLocal(t *testing.T) {

func TestSetAndWait(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Addr: redisAddr,
Password: redisPwd,
})
var buf Buffer
Expand Down Expand Up @@ -200,7 +204,7 @@ func TestSetAndWait(t *testing.T) {

func TestReadAfterClose(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Addr: redisAddr,
Password: redisPwd,
})
var buf Buffer
Expand Down Expand Up @@ -265,7 +269,7 @@ func TestReadAfterClose(t *testing.T) {
}

func TestWriteEmptyString(t *testing.T) {
rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Password: redisPwd})
rdb := redis.NewClient(&redis.Options{Addr: redisAddr, Password: redisPwd})
ctx := context.Background()
m, err := Join(ctx, "test", rdb)
require.NoError(t, err)
Expand All @@ -285,7 +289,7 @@ func TestWriteEmptyString(t *testing.T) {
}

func TestAppendUniqueValues(t *testing.T) {
rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Password: redisPwd})
rdb := redis.NewClient(&redis.Options{Addr: redisAddr, Password: redisPwd})
ctx := context.Background()
m, err := Join(ctx, "test", rdb)
require.NoError(t, err)
Expand Down Expand Up @@ -329,7 +333,7 @@ func TestAppendUniqueValues(t *testing.T) {
}

func TestTestAndDelete(t *testing.T) {
rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Password: redisPwd})
rdb := redis.NewClient(&redis.Options{Addr: redisAddr, Password: redisPwd})
ctx := context.Background()
m, err := Join(ctx, "test", rdb)
require.NoError(t, err)
Expand Down Expand Up @@ -367,7 +371,7 @@ func TestTestAndDelete(t *testing.T) {
}

func TestTestAndSet(t *testing.T) {
rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Password: redisPwd})
rdb := redis.NewClient(&redis.Options{Addr: redisAddr, Password: redisPwd})
ctx := context.Background()
m, err := Join(ctx, "test", rdb)
require.NoError(t, err)
Expand All @@ -392,7 +396,7 @@ func TestTestAndSet(t *testing.T) {
}

func TestTestAndReset(t *testing.T) {
rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Password: redisPwd})
rdb := redis.NewClient(&redis.Options{Addr: redisAddr, Password: redisPwd})
ctx := context.Background()
m, err := Join(ctx, "test", rdb)
require.NoError(t, err)
Expand Down Expand Up @@ -448,7 +452,7 @@ func TestTestAndReset(t *testing.T) {
}

func TestArrays(t *testing.T) {
rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Password: redisPwd})
rdb := redis.NewClient(&redis.Options{Addr: redisAddr, Password: redisPwd})
ctx := context.Background()
m, err := Join(ctx, "test", rdb)
require.NoError(t, err)
Expand Down Expand Up @@ -484,7 +488,7 @@ func TestArrays(t *testing.T) {
}

func TestIncrement(t *testing.T) {
rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Password: redisPwd})
rdb := redis.NewClient(&redis.Options{Addr: redisAddr, Password: redisPwd})
ctx := context.Background()
m, err := Join(ctx, "test", rdb)
require.NoError(t, err)
Expand All @@ -510,7 +514,7 @@ func TestIncrement(t *testing.T) {
}

func TestLogs(t *testing.T) {
rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Password: redisPwd})
rdb := redis.NewClient(&redis.Options{Addr: redisAddr, Password: redisPwd})
var buf Buffer
ctx := context.Background()
ctx = log.Context(ctx, log.WithOutput(&buf), log.WithDebug(), log.WithFormat(log.FormatText))
Expand Down Expand Up @@ -562,7 +566,7 @@ func TestJoinErrors(t *testing.T) {
}

func TestSetErrors(t *testing.T) {
rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Password: redisPwd})
rdb := redis.NewClient(&redis.Options{Addr: redisAddr, Password: redisPwd})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand All @@ -580,7 +584,7 @@ func TestSetErrors(t *testing.T) {
}

func TestAppendValuesErrors(t *testing.T) {
rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Password: redisPwd})
rdb := redis.NewClient(&redis.Options{Addr: redisAddr, Password: redisPwd})
ctx := context.Background()

m, err := Join(ctx, "test", rdb)
Expand All @@ -596,7 +600,7 @@ func TestAppendValuesErrors(t *testing.T) {
}

func TestRemoveValuesErrors(t *testing.T) {
rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Password: redisPwd})
rdb := redis.NewClient(&redis.Options{Addr: redisAddr, Password: redisPwd})
ctx := context.Background()

m, err := Join(ctx, "test", rdb)
Expand All @@ -614,7 +618,7 @@ func TestRemoveValuesErrors(t *testing.T) {
}

func TestReconnect(t *testing.T) {
rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Password: redisPwd})
rdb := redis.NewClient(&redis.Options{Addr: redisAddr, Password: redisPwd})
ctx := context.Background()
var buf Buffer
ctx = log.Context(ctx, log.WithOutput(&buf))
Expand Down
2 changes: 2 additions & 0 deletions scripts/test
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ echo "Running tests..."
GIT_ROOT=$(git rev-parse --show-toplevel)
pushd ${GIT_ROOT}

source .env

staticcheck ./...

# If --force is passed, add --count=1 to the go test command
Expand Down
Loading