Skip to content

Commit

Permalink
Prevent redis client from incorrectly choosing cluster mode with loca…
Browse files Browse the repository at this point in the history
…l address (grafana#9185)
  • Loading branch information
Danny Kopping authored Apr 19, 2023
1 parent 422560b commit 5265570
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
* [9099](https://github.com/grafana/loki/pull/9099) **salvacorts**: Fix the estimated size of chunks when writing a new TSDB file during compaction.
* [9130](https://github.com/grafana/loki/pull/9130) **salvacorts**: Pass LogQL engine options down to the _split by range_, _sharding_, and _query size limiter_ middlewares.
* [9156](https://github.com/grafana/loki/pull/9156) **ashwanthgoli**: Expiration: do not drop index if period is a zero value
* [9185](https://github.com/grafana/loki/pull/9185) **dannykopping**: Prevent redis client from incorrectly choosing cluster mode with local address.

#### Promtail

Expand Down
66 changes: 49 additions & 17 deletions pkg/storage/chunk/cache/redis_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,11 @@ type RedisClient struct {

// NewRedisClient creates Redis client
func NewRedisClient(cfg *RedisConfig) (*RedisClient, error) {
endpoints := strings.Split(cfg.Endpoint, ",")
// Handle single configuration endpoint which resolves multiple nodes.
if len(endpoints) == 1 {
host, port, err := net.SplitHostPort(endpoints[0])
if err != nil {
return nil, err
}
addrs, err := net.LookupHost(host)
if err != nil {
return nil, err
}
if len(addrs) > 1 {
endpoints = nil
for _, addr := range addrs {
endpoints = append(endpoints, net.JoinHostPort(addr, port))
}
}
endpoints, err := deriveEndpoints(cfg.Endpoint, net.LookupHost)
if err != nil {
return nil, fmt.Errorf("failed to derive endpoints: %w", err)
}

opt := &redis.UniversalOptions{
Addrs: endpoints,
MasterName: cfg.MasterName,
Expand All @@ -96,6 +83,51 @@ func NewRedisClient(cfg *RedisConfig) (*RedisClient, error) {
}, nil
}

func deriveEndpoints(endpoint string, lookup func(host string) ([]string, error)) ([]string, error) {
if lookup == nil {
return nil, fmt.Errorf("lookup function is nil")
}

endpoints := strings.Split(endpoint, ",")

// no endpoints or multiple endpoints will not need derivation
if len(endpoints) != 1 {
return endpoints, nil
}

// Handle single configuration endpoint which resolves multiple nodes.
host, port, err := net.SplitHostPort(endpoints[0])
if err != nil {
return nil, fmt.Errorf("splitting host:port failed :%w", err)
}
addrs, err := lookup(host)
if err != nil {
return nil, fmt.Errorf("could not lookup host: %w", err)
}

// only use the resolved addresses if they are not all loopback addresses;
// multiple addresses invokes cluster mode
allLoopback := allAddrsAreLoopback(addrs)
if len(addrs) > 1 && !allLoopback {
endpoints = nil
for _, addr := range addrs {
endpoints = append(endpoints, net.JoinHostPort(addr, port))
}
}

return endpoints, nil
}

func allAddrsAreLoopback(addrs []string) bool {
for _, addr := range addrs {
if !net.ParseIP(addr).IsLoopback() {
return false
}
}

return true
}

func (c *RedisClient) Ping(ctx context.Context) error {
var cancel context.CancelFunc
if c.timeout > 0 {
Expand Down
76 changes: 76 additions & 0 deletions pkg/storage/chunk/cache/redis_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package cache

import (
"context"
"fmt"
"testing"
"time"

"github.com/alicebob/miniredis/v2"
"github.com/go-redis/redis/v8"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -98,3 +100,77 @@ func mockRedisClientCluster() (*RedisClient, error) {
}),
}, nil
}

func Test_deriveEndpoints(t *testing.T) {
const (
upstream = "upstream"
downstream = "downstream"
lookback = "localhost"
)

tests := []struct {
name string
endpoints string
lookup func(host string) ([]string, error)
want []string
wantErr string
}{
{
name: "single endpoint",
endpoints: fmt.Sprintf("%s:6379", upstream),
lookup: func(host string) ([]string, error) {
return []string{upstream}, nil
},
want: []string{fmt.Sprintf("%s:6379", upstream)},
wantErr: "",
},
{
name: "multiple endpoints",
endpoints: fmt.Sprintf("%s:6379,%s:6379", upstream, downstream), // note the space
lookup: func(host string) ([]string, error) {
return []string{host}, nil
},
want: []string{fmt.Sprintf("%s:6379", upstream), fmt.Sprintf("%s:6379", downstream)},
wantErr: "",
},
{
name: "all loopback",
endpoints: fmt.Sprintf("%s:6379", lookback),
lookup: func(host string) ([]string, error) {
return []string{"::1", "127.0.0.1"}, nil
},
want: []string{fmt.Sprintf("%s:6379", lookback)},
wantErr: "",
},
{
name: "non-loopback address resolving to multiple addresses",
endpoints: fmt.Sprintf("%s:6379", upstream),
lookup: func(host string) ([]string, error) {
return []string{upstream, downstream}, nil
},
want: []string{fmt.Sprintf("%s:6379", upstream), fmt.Sprintf("%s:6379", downstream)},
wantErr: "",
},
{
name: "no such host",
endpoints: fmt.Sprintf("%s:6379", upstream),
lookup: func(host string) ([]string, error) {
return nil, fmt.Errorf("no such host")
},
want: nil,
wantErr: "no such host",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := deriveEndpoints(tt.endpoints, tt.lookup)
if tt.wantErr != "" {
require.ErrorContains(t, err, tt.wantErr)
} else {
require.NoError(t, err)
}
assert.Equalf(t, tt.want, got, "failed to derive correct endpoints from %v", tt.endpoints)
})
}
}

0 comments on commit 5265570

Please sign in to comment.