receive: do not leak grpc connection #75

Merged · 1 commit · Jan 4, 2024
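This PR fixes a gRPC connection leak in the receive handler: when the hashring was reloaded, connections to peers that were no longer part of the ring stayed open and leaked file descriptors. The handler now caches raw *grpc.ClientConn values behind a small peersContainer interface, computes which nodes disappeared from the ring on every Hashring() call, and closes their connections. To make that diff possible, every Hashring implementation gains a Nodes() method returning a sorted list of member addresses.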
77 changes: 68 additions & 9 deletions pkg/receive/handler.go
@@ -106,7 +106,7 @@ type Handler struct {

mtx sync.RWMutex
hashring Hashring
peers *peerGroup
peers peersContainer
expBackoff backoff.Backoff
peerStates map[string]*retryState
receiverMode ReceiverMode
@@ -252,11 +252,49 @@ func (h *Handler) Hashring(hashring Hashring) {
h.mtx.Lock()
defer h.mtx.Unlock()

if h.hashring != nil {
previousNodes := h.hashring.Nodes()
newNodes := hashring.Nodes()

disappearedNodes := getSortedStringSliceDiff(previousNodes, newNodes)
for _, node := range disappearedNodes {
if err := h.peers.close(node); err != nil {
level.Error(h.logger).Log("msg", "closing gRPC connection failed, we might have leaked a file descriptor", "addr", node, "err", err.Error())
}
}
}

h.hashring = hashring
h.expBackoff.Reset()
h.peerStates = make(map[string]*retryState)
}

// getSortedStringSliceDiff returns items which are in slice1 but not in slice2.
// The returned slice contains only unique items, i.e. it is a set.
func getSortedStringSliceDiff(slice1, slice2 []string) []string {
slice1Items := make(map[string]struct{}, len(slice1))
slice2Items := make(map[string]struct{}, len(slice2))

for _, s1 := range slice1 {
slice1Items[s1] = struct{}{}
}
for _, s2 := range slice2 {
slice2Items[s2] = struct{}{}
}

var difference = make([]string, 0)
for s1 := range slice1Items {
_, s2Contains := slice2Items[s1]
if s2Contains {
continue
}
difference = append(difference, s1)
}
sort.Strings(difference)

return difference
}

// Verifies whether the server is ready or not.
func (h *Handler) isReady() bool {
h.mtx.RLock()
@@ -1122,46 +1160,67 @@ func newReplicationErrors(threshold, numErrors int) []*replicationErrors {
return errs
}

func newPeerGroup(dialOpts ...grpc.DialOption) *peerGroup {
func newPeerGroup(dialOpts ...grpc.DialOption) peersContainer {
return &peerGroup{
dialOpts: dialOpts,
cache: map[string]storepb.WriteableStoreClient{},
cache: map[string]*grpc.ClientConn{},
m: sync.RWMutex{},
dialer: grpc.DialContext,
}
}

type peersContainer interface {
close(string) error
get(context.Context, string) (storepb.WriteableStoreClient, error)
}

type peerGroup struct {
dialOpts []grpc.DialOption
cache map[string]storepb.WriteableStoreClient
cache map[string]*grpc.ClientConn
m sync.RWMutex

// dialer is used for testing.
dialer func(ctx context.Context, target string, opts ...grpc.DialOption) (conn *grpc.ClientConn, err error)
}

func (p *peerGroup) close(addr string) error {
p.m.Lock()
defer p.m.Unlock()

c, ok := p.cache[addr]
if !ok {
return nil
}

if err := c.Close(); err != nil {
return fmt.Errorf("closing connection for %s: %w", addr, err)
}

delete(p.cache, addr)
return nil
}

func (p *peerGroup) get(ctx context.Context, addr string) (storepb.WriteableStoreClient, error) {
// use an RLock first to prevent blocking if we don't need to.
p.m.RLock()
c, ok := p.cache[addr]
p.m.RUnlock()
if ok {
return c, nil
return storepb.NewWriteableStoreClient(c), nil
}

p.m.Lock()
defer p.m.Unlock()
// Make sure that another caller hasn't created the connection since obtaining the write lock.
c, ok = p.cache[addr]
if ok {
return c, nil
return storepb.NewWriteableStoreClient(c), nil
}
conn, err := p.dialer(ctx, addr, p.dialOpts...)
if err != nil {
return nil, errors.Wrap(err, "failed to dial peer")
}

client := storepb.NewWriteableStoreClient(conn)
p.cache[addr] = client
return client, nil
p.cache[addr] = conn
return storepb.NewWriteableStoreClient(conn), nil
}
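
The get path above uses double-checked locking: a read lock for the common cache-hit case, then a re-check under the write lock before dialing, so two concurrent callers never open duplicate connections to the same peer. Below is a minimal runnable sketch of that pattern, not the PR's code itself: a plain string stands in for *grpc.ClientConn, and the address is made up.

package main

import (
	"fmt"
	"sync"
)

// connCache mirrors peerGroup's locking scheme: take the read lock for
// the fast path, then re-check under the write lock so that concurrent
// callers never create the entry twice.
type connCache struct {
	mu    sync.RWMutex
	cache map[string]string // stands in for map[string]*grpc.ClientConn
}

func (c *connCache) get(addr string) string {
	c.mu.RLock()
	v, ok := c.cache[addr]
	c.mu.RUnlock()
	if ok {
		return v
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	// Another goroutine may have created the entry while we were
	// waiting for the write lock.
	if v, ok := c.cache[addr]; ok {
		return v
	}
	v = "conn-to-" + addr // the real code dials gRPC here
	c.cache[addr] = v
	return v
}

// close drops an entry, as peerGroup.close does for nodes that left the ring.
func (c *connCache) close(addr string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.cache, addr)
}

func main() {
	c := &connCache{cache: map[string]string{}}
	fmt.Println(c.get("10.0.0.1:10901")) // hypothetical address
	c.close("10.0.0.1:10901")
}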
74 changes: 60 additions & 14 deletions pkg/receive/handler_test.go
@@ -162,24 +162,38 @@ func (f *fakeAppender) Rollback() error {
return f.rollbackErr()
}

type fakePeersGroup struct {
clients map[string]storepb.WriteableStoreClient

closeCalled map[string]bool
}

func (g *fakePeersGroup) close(addr string) error {
if g.closeCalled == nil {
g.closeCalled = map[string]bool{}
}
g.closeCalled[addr] = true
return nil
}

func (g *fakePeersGroup) get(_ context.Context, addr string) (storepb.WriteableStoreClient, error) {
c, ok := g.clients[addr]
if !ok {
return nil, fmt.Errorf("client %s not found", addr)
}
return c, nil
}

var _ = (peersContainer)(&fakePeersGroup{})

func newTestHandlerHashring(appendables []*fakeAppendable, replicationFactor uint64, hashringAlgo HashringAlgorithm) ([]*Handler, Hashring, error) {
var (
cfg = []HashringConfig{{Hashring: "test"}}
handlers []*Handler
wOpts = &WriterOptions{}
)
// create a fake peer group where we manually fill the cache with fake addresses pointed to our handlers
// This removes the network from the tests and creates a more consistent testing harness.
peers := &peerGroup{
dialOpts: nil,
m: sync.RWMutex{},
cache: map[string]storepb.WriteableStoreClient{},
dialer: func(context.Context, string, ...grpc.DialOption) (*grpc.ClientConn, error) {
// dialer should never be called since we are creating fake clients with fake addresses
// this protects against some leaking test that may attempt to dial random IP addresses
// which may pose a security risk.
return nil, errors.New("unexpected dial called in testing")
},
fakePeers := &fakePeersGroup{
clients: map[string]storepb.WriteableStoreClient{},
}

ag := addrGen{}
@@ -194,11 +208,11 @@ func newTestHandlerHashring(appendables []*fakeAppendable, replicationFactor uin
Limiter: limiter,
})
handlers = append(handlers, h)
h.peers = peers
addr := ag.newAddr()
h.peers = fakePeers
fakePeers.clients[addr] = &fakeRemoteWriteGRPCServer{h: h}
h.options.Endpoint = addr
cfg[0].Endpoints = append(cfg[0].Endpoints, Endpoint{Address: h.options.Endpoint})
peers.cache[addr] = &fakeRemoteWriteGRPCServer{h: h}
}
// Use hashmod as default.
if hashringAlgo == "" {
@@ -1569,3 +1583,35 @@ func TestGetStatsLimitParameter(t *testing.T) {
testutil.Equals(t, limit, givenLimit)
})
}

func TestSortedSliceDiff(t *testing.T) {
testutil.Equals(t, []string{"a"}, getSortedStringSliceDiff([]string{"a", "a", "foo"}, []string{"b", "b", "foo"}))
testutil.Equals(t, []string{}, getSortedStringSliceDiff([]string{}, []string{"b", "b", "foo"}))
testutil.Equals(t, []string{}, getSortedStringSliceDiff([]string{}, []string{}))
}

func TestHashringChangeCallsClose(t *testing.T) {
appendables := []*fakeAppendable{
{
appender: newFakeAppender(nil, nil, nil),
},
{
appender: newFakeAppender(nil, nil, nil),
},
{
appender: newFakeAppender(nil, nil, nil),
},
}
allHandlers, _, err := newTestHandlerHashring(appendables, 3, AlgorithmHashmod)
testutil.Ok(t, err)

appendables = appendables[1:]

_, smallHashring, err := newTestHandlerHashring(appendables, 2, AlgorithmHashmod)
testutil.Ok(t, err)

allHandlers[0].Hashring(smallHashring)

pg := allHandlers[0].peers.(*fakePeersGroup)
testutil.Assert(t, len(pg.closeCalled) > 0)
}
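
For illustration, here is a self-contained sketch of the set-difference step the tests above exercise. setDiff is a condensed stand-in for getSortedStringSliceDiff, and the addresses are made up:

package main

import (
	"fmt"
	"sort"
)

// setDiff returns items in a but not in b, deduplicated and sorted,
// matching the behavior of getSortedStringSliceDiff in this PR.
func setDiff(a, b []string) []string {
	inB := make(map[string]struct{}, len(b))
	for _, s := range b {
		inB[s] = struct{}{}
	}
	seen := make(map[string]struct{}, len(a))
	out := make([]string, 0)
	for _, s := range a {
		if _, ok := inB[s]; ok {
			continue
		}
		if _, ok := seen[s]; ok {
			continue
		}
		seen[s] = struct{}{}
		out = append(out, s)
	}
	sort.Strings(out)
	return out
}

func main() {
	previous := []string{"10.0.0.1:10901", "10.0.0.1:10901", "10.0.0.2:10901"}
	current := []string{"10.0.0.2:10901", "10.0.0.3:10901"}
	// Prints [10.0.0.1:10901]: the node that left the ring and whose
	// connection must be closed. 10.0.0.3 is new and is dialed lazily
	// on the next forwarded request.
	fmt.Println(setDiff(previous, current))
}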
31 changes: 31 additions & 0 deletions pkg/receive/hashring.go
@@ -55,6 +55,9 @@ type Hashring interface {
Get(tenant string, timeSeries *prompb.TimeSeries) (string, error)
// GetN returns the nth node that should handle the given tenant and time series.
GetN(tenant string, timeSeries *prompb.TimeSeries, n uint64) (string, error)
// Nodes returns a sorted slice of nodes that are in this hashring. Addresses could be duplicated
// if, for example, the same address is used for multiple tenants in the multi-hashring.
Nodes() []string
}

// SingleNodeHashring always returns the same node.
@@ -65,6 +68,10 @@ func (s SingleNodeHashring) Get(tenant string, ts *prompb.TimeSeries) (string, e
return s.GetN(tenant, ts, 0)
}

func (s SingleNodeHashring) Nodes() []string {
return []string{string(s)}
}

// GetN implements the Hashring interface.
func (s SingleNodeHashring) GetN(_ string, _ *prompb.TimeSeries, n uint64) (string, error) {
if n > 0 {
@@ -84,9 +91,15 @@ func newSimpleHashring(endpoints []Endpoint) (Hashring, error) {
}
addresses[i] = endpoints[i].Address
}
sort.Strings(addresses)

return simpleHashring(addresses), nil
}

func (s simpleHashring) Nodes() []string {
return s
}

// Get returns a target to handle the given tenant and time series.
func (s simpleHashring) Get(tenant string, ts *prompb.TimeSeries) (string, error) {
return s.GetN(tenant, ts, 0)
Expand Down Expand Up @@ -120,6 +133,7 @@ type ketamaHashring struct {
endpoints []Endpoint
sections sections
numEndpoints uint64
nodes []string
}

func newKetamaHashring(endpoints []Endpoint, sectionsPerNode int, replicationFactor uint64) (*ketamaHashring, error) {
@@ -132,8 +146,11 @@ func newKetamaHashring(endpoints []Endpoint, sectionsPerNode int, replicationFac
hash := xxhash.New()
availabilityZones := make(map[string]struct{})
ringSections := make(sections, 0, numSections)

nodes := []string{}
for endpointIndex, endpoint := range endpoints {
availabilityZones[endpoint.AZ] = struct{}{}
nodes = append(nodes, endpoint.Address)
for i := 1; i <= sectionsPerNode; i++ {
_, _ = hash.Write([]byte(endpoint.Address + ":" + strconv.Itoa(i)))
n := &section{
@@ -148,15 +165,21 @@ }
}
}
sort.Sort(ringSections)
sort.Strings(nodes)
calculateSectionReplicas(ringSections, replicationFactor, availabilityZones)

return &ketamaHashring{
endpoints: endpoints,
sections: ringSections,
numEndpoints: uint64(len(endpoints)),
nodes: nodes,
}, nil
}

func (k *ketamaHashring) Nodes() []string {
return k.nodes
}

func sizeOfLeastOccupiedAZ(azSpread map[string]int64) int64 {
minValue := int64(math.MaxInt64)
for _, value := range azSpread {
@@ -232,6 +255,8 @@ type multiHashring struct {
// to the cache map, as this is both written to
// and read from.
mu sync.RWMutex

nodes []string
}

// Get returns a target to handle the given tenant and time series.
@@ -269,6 +294,10 @@ func (m *multiHashring) GetN(tenant string, ts *prompb.TimeSeries, n uint64) (st
return "", errors.New("no matching hashring to handle tenant")
}

func (m *multiHashring) Nodes() []string {
return m.nodes
}

// newMultiHashring creates a multi-tenant hashring for a given slice of
// groups.
// Which hashring to use for a tenant is determined
@@ -289,6 +318,7 @@ func NewMultiHashring(algorithm HashringAlgorithm, replicationFactor uint64, cfg
if err != nil {
return nil, err
}
m.nodes = append(m.nodes, hashring.Nodes()...)
m.hashrings = append(m.hashrings, hashring)
var t map[string]struct{}
if len(h.Tenants) != 0 {
@@ -299,6 +329,7 @@ }
}
m.tenantSets = append(m.tenantSets, t)
}
sort.Strings(m.nodes)
return m, nil
}
