From 94ea195dc19232dbf5499b479901a7bb0989afa0 Mon Sep 17 00:00:00 2001 From: Dimitrij Denissenko Date: Thu, 29 Jun 2017 22:43:19 +0100 Subject: [PATCH] Use node address instead of relying on loopback reported by redis --- cluster.go | 27 ++++++++++++++++++++++++--- commands_test.go | 4 ++-- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/cluster.go b/cluster.go index 51e94a752..62dfe4a5a 100644 --- a/cluster.go +++ b/cluster.go @@ -3,6 +3,7 @@ package redis import ( "fmt" "math/rand" + "net" "sync" "sync/atomic" "time" @@ -244,16 +245,22 @@ type clusterState struct { slots [][]*clusterNode } -func newClusterState(nodes *clusterNodes, slots []ClusterSlot) (*clusterState, error) { +func newClusterState(nodes *clusterNodes, slots []ClusterSlot, origin string) (*clusterState, error) { c := clusterState{ nodes: nodes, slots: make([][]*clusterNode, hashtag.SlotNumber), } + isLoopbackOrigin := isLoopbackAddr(origin) for _, slot := range slots { var nodes []*clusterNode for _, slotNode := range slot.Nodes { - node, err := c.nodes.Get(slotNode.Addr) + addr := slotNode.Addr + if !isLoopbackOrigin && isLoopbackAddr(addr) { + addr = origin + } + + node, err := c.nodes.Get(addr) if err != nil { return nil, err } @@ -661,7 +668,7 @@ func (c *ClusterClient) reloadSlots() (*clusterState, error) { return nil, err } - return newClusterState(c.nodes, slots) + return newClusterState(c.nodes, slots, node.Client.opt.Addr) } // reaper closes idle connections to the cluster. @@ -960,3 +967,17 @@ func (c *ClusterClient) txPipelineReadQueued( return firstErr } + +func isLoopbackAddr(addr string) bool { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return false + } + + ip := net.ParseIP(host) + if ip == nil { + return false + } + + return ip.IsLoopback() +} diff --git a/commands_test.go b/commands_test.go index 64a50b50d..e8cdb205e 100644 --- a/commands_test.go +++ b/commands_test.go @@ -2888,12 +2888,12 @@ var _ = Describe("Commands", func() { It("returns map of commands", func() { cmds, err := client.Command().Result() Expect(err).NotTo(HaveOccurred()) - Expect(len(cmds)).To(BeNumerically("~", 173, 5)) + Expect(len(cmds)).To(BeNumerically("~", 180, 10)) cmd := cmds["mget"] Expect(cmd.Name).To(Equal("mget")) Expect(cmd.Arity).To(Equal(int8(-2))) - Expect(cmd.Flags).To(Equal([]string{"readonly"})) + Expect(cmd.Flags).To(ContainElement("readonly")) Expect(cmd.FirstKeyPos).To(Equal(int8(1))) Expect(cmd.LastKeyPos).To(Equal(int8(-1))) Expect(cmd.StepCount).To(Equal(int8(1)))