Skip to content

Commit

Permalink
adding AddressTranslator interface and impl for use in ec2
Browse files Browse the repository at this point in the history
This change introduces the AddressTranslator interface, which is
intended to translate peer addresses just before creating a connection
to those nodes. The primary use -- which is driving the change -- is
to be able to translate public IPs to private IPs in ec2.

This solution is common among other CQL driver implementations. The
specific implementation here also follows the convention set by
HostFilter.

Signed-off-by: Justin "Gus" Knowlden <gus@gusg.us>
  • Loading branch information
Charles Frantz authored and gus committed Nov 3, 2016
1 parent 3a52a1d commit d93ce32
Show file tree
Hide file tree
Showing 13 changed files with 325 additions and 9 deletions.
26 changes: 26 additions & 0 deletions address_translators.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package gocql

import "net"

// AddressTranslator provides a way to translate node addresses (and ports) that are
// discovered or received as a node event. This can be useful in an ec2 environment,
// for instance, to translate public IPs to private IPs.
type AddressTranslator interface {
// Translate will translate the provided address and/or port to another
// address and/or port. If no translation is possible, Translate will return the
// address and port provided to it.
Translate(addr net.IP, port int) (net.IP, int)
}

type AddressTranslatorFunc func(addr net.IP, port int) (net.IP, int)

func (fn AddressTranslatorFunc) Translate(addr net.IP, port int) (net.IP, int) {
return fn(addr, port)
}

// IdentityTranslator will do nothing but return what it was provided. It is essentially a no-op.
func IdentityTranslator() AddressTranslator {
return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
return addr, port
})
}
34 changes: 34 additions & 0 deletions address_translators_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package gocql

import (
"net"
"testing"
)

func TestIdentityAddressTranslator_NilAddrAndZeroPort(t *testing.T) {
var tr AddressTranslator = IdentityTranslator()
hostIP := net.ParseIP("")
if hostIP != nil {
t.Errorf("expected host ip to be (nil) but was (%+v) instead", hostIP)
}

addr, port := tr.Translate(hostIP, 0)
if addr != nil {
t.Errorf("expected translated host to be (nil) but was (%+v) instead", addr)
}
assertEqual(t, "translated port", 0, port)
}

func TestIdentityAddressTranslator_HostProvided(t *testing.T) {
var tr AddressTranslator = IdentityTranslator()
hostIP := net.ParseIP("10.1.2.3")
if hostIP == nil {
t.Error("expected host ip not to be (nil)")
}

addr, port := tr.Translate(hostIP, 9042)
if !hostIP.Equal(addr) {
t.Errorf("expected translated addr to be (%+v) but was (%+v) instead", hostIP, addr)
}
assertEqual(t, "translated port", 9042, port)
}
21 changes: 21 additions & 0 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ package gocql

import (
"errors"
"log"
"net"
"time"
)

Expand Down Expand Up @@ -75,6 +77,10 @@ type ClusterConfig struct {
// via Discovery
HostFilter HostFilter

// AddressTranslator will translate addresses found on peer discovery and/or
// node change events.
AddressTranslator AddressTranslator

// If IgnorePeerAddr is true and the address in system.peers does not match
// the supplied host by either initial hosts or discovered via events then the
// host will be replaced with the supplied address.
Expand Down Expand Up @@ -146,6 +152,21 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
return NewSession(*cfg)
}

// translateAddressPort is a helper method that will use the given AddressTranslator
// if defined, to translate the given address and port into a possibly new address
// and port, If no AddressTranslator or if an error occurs, the given address and
// port will be returned.
func (cfg *ClusterConfig) translateAddressPort(addr net.IP, port int) (net.IP, int) {
if cfg.AddressTranslator == nil || len(addr) == 0 {
return addr, port
}
newAddr, newPort := cfg.AddressTranslator.Translate(addr, port)
if gocqlDebug {
log.Printf("gocql: translating address '%v:%d' to '%v:%d'", addr, port, newAddr, newPort)
}
return newAddr, newPort
}

var (
ErrNoHosts = errors.New("no hosts provided")
ErrNoConnectionsStarted = errors.New("no connections were made when creating the session")
Expand Down
53 changes: 53 additions & 0 deletions cluster_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package gocql

import (
"testing"
"time"
"net"
)

func TestNewCluster_Defaults(t *testing.T) {
cfg := NewCluster()
assertEqual(t, "cluster config cql version", "3.0.0", cfg.CQLVersion)
assertEqual(t, "cluster config timeout", 600*time.Millisecond, cfg.Timeout)
assertEqual(t, "cluster config port", 9042, cfg.Port)
assertEqual(t, "cluster config num-conns", 2, cfg.NumConns)
assertEqual(t, "cluster config consistency", Quorum, cfg.Consistency)
assertEqual(t, "cluster config max prepared statements", defaultMaxPreparedStmts, cfg.MaxPreparedStmts)
assertEqual(t, "cluster config max routing key info", 1000, cfg.MaxRoutingKeyInfo)
assertEqual(t, "cluster config page-size", 5000, cfg.PageSize)
assertEqual(t, "cluster config default timestamp", true, cfg.DefaultTimestamp)
assertEqual(t, "cluster config max wait schema agreement", 60*time.Second, cfg.MaxWaitSchemaAgreement)
assertEqual(t, "cluster config reconnect interval", 60*time.Second, cfg.ReconnectInterval)
}

func TestNewCluster_WithHosts(t *testing.T) {
cfg := NewCluster("addr1", "addr2")
assertEqual(t, "cluster config hosts length", 2, len(cfg.Hosts))
assertEqual(t, "cluster config host 0", "addr1", cfg.Hosts[0])
assertEqual(t, "cluster config host 1", "addr2", cfg.Hosts[1])
}

func TestClusterConfig_translateAddressAndPort_NilTranslator(t *testing.T) {
cfg := NewCluster()
assertNil(t, "cluster config address translator", cfg.AddressTranslator)
newAddr, newPort := cfg.translateAddressPort(net.ParseIP("10.0.0.1"), 1234)
assertTrue(t, "same address as provided", net.ParseIP("10.0.0.1").Equal(newAddr))
assertEqual(t, "translated host and port", 1234, newPort)
}

func TestClusterConfig_translateAddressAndPort_EmptyAddr(t *testing.T) {
cfg := NewCluster()
cfg.AddressTranslator = staticAddressTranslator(net.ParseIP("10.10.10.10"), 5432)
newAddr, newPort := cfg.translateAddressPort(net.IP([]byte{}), 0)
assertTrue(t, "translated address is still empty", len(newAddr) == 0)
assertEqual(t, "translated port", 0, newPort)
}

func TestClusterConfig_translateAddressAndPort_Success(t *testing.T) {
cfg := NewCluster()
cfg.AddressTranslator = staticAddressTranslator(net.ParseIP("10.10.10.10"), 5432)
newAddr, newPort := cfg.translateAddressPort(net.ParseIP("10.0.0.1"), 2345)
assertTrue(t, "translated address", net.ParseIP("10.10.10.10").Equal(newAddr))
assertEqual(t, "translated port", 5432, newPort)
}
51 changes: 51 additions & 0 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"sync"
"testing"
"time"
"net"
)

var (
Expand Down Expand Up @@ -143,3 +144,53 @@ func createSession(tb testing.TB) *Session {
cluster := createCluster()
return createSessionFromCluster(cluster, tb)
}

// createTestSession is hopefully moderately useful in actual unit tests
func createTestSession() *Session {
config := NewCluster()
config.NumConns = 1
config.Timeout = 0
config.DisableInitialHostLookup = true
config.IgnorePeerAddr = true
config.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
session := &Session{
cfg: *config,
connCfg: &ConnConfig{
Timeout: 10*time.Millisecond,
Keepalive: 0,
},
policy: config.PoolConfig.HostSelectionPolicy,
}
session.pool = config.PoolConfig.buildPool(session)
return session
}

func staticAddressTranslator(newAddr net.IP, newPort int) AddressTranslator {
return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
return newAddr, newPort
})
}

func assertTrue(t *testing.T, description string, value bool) {
if !value {
t.Errorf("expected %s to be true", description)
}
}

func assertEqual(t *testing.T, description string, expected, actual interface{}) {
if expected != actual {
t.Errorf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual)
}
}

func assertNil(t *testing.T, description string, actual interface{}) {
if actual != nil {
t.Errorf("expected %s to be (nil) but was (%+v) instead", description, actual)
}
}

func assertNotNil(t *testing.T, description string, actual interface{}) {
if actual == nil {
t.Errorf("expected %s not to be (nil)", description)
}
}
4 changes: 3 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ func Connect(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler, ses
}

// TODO(zariel): handle ipv6 zone
addr := (&net.TCPAddr{IP: host.Peer(), Port: host.Port()}).String()
translatedPeer, translatedPort := session.cfg.translateAddressPort(host.Peer(), host.Port())
addr := (&net.TCPAddr{IP: translatedPeer, Port: translatedPort}).String()
//addr := (&net.TCPAddr{IP: host.Peer(), Port: host.Port()}).String()

if cfg.tlsConfig != nil {
// the TLS config is safe to be reused by connections but it must not
Expand Down
4 changes: 2 additions & 2 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ func TestStream0(t *testing.T) {
}
})

conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, createTestSession())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -509,7 +509,7 @@ func TestConnClosedBlocked(t *testing.T) {
t.Log(err)
})

conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, createTestSession())
if err != nil {
t.Fatal(err)
}
Expand Down
4 changes: 2 additions & 2 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,8 @@ func (c *controlConn) reconnect(refreshring bool) {
}
}

// TODO: should have our own roundrobbin for hosts so that we can try each
// in succession and guantee that we get a different host each time.
// TODO: should have our own round-robin for hosts so that we can try each
// in succession and guarantee that we get a different host each time.
if newConn == nil {
host := c.session.ring.rrHost()
if host == nil {
Expand Down
1 change: 0 additions & 1 deletion events.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
s.pool.addHost(hostInfo)
s.policy.AddHost(hostInfo)
hostInfo.setState(NodeUp)

if s.control != nil && !s.cfg.IgnorePeerAddr {
s.hostSource.refreshRing()
}
Expand Down
2 changes: 1 addition & 1 deletion integration.sh
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ function run_tests() {

local args="-gocql.timeout=60s -runssl -proto=$proto -rf=3 -clusterSize=$clusterSize -autowait=2000ms -compressor=snappy -gocql.cversion=$version -cluster=$(ccm liveset) ./..."

go test -v -tags unit
go test -v -tags unit

if [ "$auth" = true ]
then
Expand Down
Loading

0 comments on commit d93ce32

Please sign in to comment.