Skip to content

Commit

Permalink
max buffer size (samuel#166)
Browse files Browse the repository at this point in the history
* max buffer size

* add TestMaxBufferSize
  • Loading branch information
jhump authored and samuel committed Jul 26, 2017
1 parent 1d7be4e commit 789e9fa
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 1 deletion.
40 changes: 39 additions & 1 deletion zk/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ type Conn struct {
pingInterval time.Duration
recvTimeout time.Duration
connectTimeout time.Duration
maxBufferSize int

creds []authCreds
credsMu sync.Mutex // protects server
Expand Down Expand Up @@ -249,6 +250,36 @@ func WithEventCallback(cb EventCallback) connOption {
}
}

// WithMaxBufferSize sets the maximum buffer size used to read and decode
// packets received from the Zookeeper server. The standard Zookeeper client for
// Java defaults to a limit of 1mb. For backwards compatibility, this Go client
// defaults to unbounded unless overridden via this option. A value that is zero
// or negative indicates that no limit is enforced.
//
// This is meant to prevent resource exhaustion in the face of potentially
// malicious data in ZK. It should generally match the server setting (which
// also defaults to 1mb) so that clients and servers agree on the limits for
// things like the size of data in an individual znode and the total size of a
// transaction.
//
// For production systems, this should be set to a reasonable value (ideally
// that matches the server configuration). For ops tooling, it is handy to use a
// much larger limit, in order to do things like clean-up problematic state in
// the ZK tree. For example, if a single znode has a huge number of children, it
// is possible for the response to a "list children" operation to exceed this
// buffer size and cause errors in clients. The only way to subsequently clean
// up the tree (by removing superfluous children) is to use a client configured
// with a larger buffer size that can successfully query for all of the child
// names and then remove them. (Note there are other tools that can list all of
// the child names without an increased buffer size in the client, but they work
// by inspecting the servers' transaction logs to enumerate children instead of
// sending an online request to a server.)
func WithMaxBufferSize(maxBufferSize int) connOption {
	// The option captures the requested limit and records it on the
	// connection when applied; values <= 0 are stored unchanged.
	applyLimit := func(conn *Conn) {
		conn.maxBufferSize = maxBufferSize
	}
	return applyLimit
}

func (c *Conn) Close() {
close(c.shouldQuit)

Expand Down Expand Up @@ -676,7 +707,11 @@ func (c *Conn) sendLoop() error {
}

func (c *Conn) recvLoop(conn net.Conn) error {
buf := make([]byte, bufferSize)
sz := bufferSize
if c.maxBufferSize > 0 && sz > c.maxBufferSize {
sz = c.maxBufferSize
}
buf := make([]byte, sz)
for {
// package length
conn.SetReadDeadline(time.Now().Add(c.recvTimeout))
Expand All @@ -687,6 +722,9 @@ func (c *Conn) recvLoop(conn net.Conn) error {

blen := int(binary.BigEndian.Uint32(buf[:4]))
if cap(buf) < blen {
if c.maxBufferSize > 0 && blen > c.maxBufferSize {
return fmt.Errorf("received packet from server with length %d, which exceeds max buffer size %d", blen, c.maxBufferSize)
}
buf = make([]byte, blen)
}

Expand Down
159 changes: 159 additions & 0 deletions zk/zk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ import (
"fmt"
"io"
"net"
"reflect"
"regexp"
"sort"
"strings"
"sync"
"testing"
"time"
)
Expand Down Expand Up @@ -716,3 +720,158 @@ func startSlowProxy(t *testing.T, up, down Rate, upstream string, adj func(ln *L
}()
return ln.Addr().String(), stopCh, nil
}

// TestMaxBufferSize exercises the WithMaxBufferSize connection option against
// a single-server test cluster using two clients: one with no buffer limit and
// one limited to 1024 bytes.
//
// It checks that:
//   - small responses (a 4-byte node with 4 children) are readable by the
//     limited client;
//   - reading a node with 1024 bytes of data, or listing children whose names
//     total at least 1024 bytes, fails on the limited client with
//     ErrConnectionClosed and logs exactly one "exceeds max buffer size"
//     message;
//   - the unlimited client can still read the same large node and child list.
func TestMaxBufferSize(t *testing.T) {
ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "})
if err != nil {
t.Fatal(err)
}
defer ts.Stop()
// no buffer size
zk, _, err := ts.ConnectWithOptions(15 * time.Second)
// l collects the limited client's log output so the test can assert on the
// reason the connection was closed.
var l testLogger
if err != nil {
t.Fatalf("Connect returned error: %+v", err)
}
defer zk.Close()
// 1k buffer size, logs to custom test logger
zkLimited, _, err := ts.ConnectWithOptions(15*time.Second, WithMaxBufferSize(1024), func(conn *Conn) {
conn.SetLogger(&l)
})
if err != nil {
t.Fatalf("Connect returned error: %+v", err)
}
defer zkLimited.Close()

// With small node with small number of children
data := []byte{101, 102, 103, 103}
_, err = zk.Create("/foo", data, 0, WorldACL(PermAll))
if err != nil {
t.Fatalf("Create returned error: %+v", err)
}
var children []string
for i := 0; i < 4; i++ {
childName, err := zk.Create("/foo/child", nil, FlagEphemeral|FlagSequence, WorldACL(PermAll))
if err != nil {
t.Fatalf("Create returned error: %+v", err)
}
children = append(children, childName[len("/foo/"):]) // strip parent prefix from name
}
// Sorted so the later DeepEqual comparison does not depend on listing order.
sort.Strings(children)

// Limited client works fine
resultData, _, err := zkLimited.Get("/foo")
if err != nil {
t.Fatalf("Get returned error: %+v", err)
}
if !reflect.DeepEqual(resultData, data) {
t.Fatalf("Get returned unexpected data; expecting %+v, got %+v", data, resultData)
}
resultChildren, _, err := zkLimited.Children("/foo")
if err != nil {
t.Fatalf("Children returned error: %+v", err)
}
sort.Strings(resultChildren)
if !reflect.DeepEqual(resultChildren, children) {
t.Fatalf("Children returned unexpected names; expecting %+v, got %+v", children, resultChildren)
}

// With large node though...
// The payload alone is 1024 bytes, i.e. equal to the limit; presumably the
// response packet carries additional framing on top of the data, which is
// what pushes it past the limit (the test expects a failure below).
data = make([]byte, 1024)
for i := 0; i < 1024; i++ {
data[i] = byte(i)
}
_, err = zk.Create("/bar", data, 0, WorldACL(PermAll))
if err != nil {
t.Fatalf("Create returned error: %+v", err)
}
_, _, err = zkLimited.Get("/bar")
// NB: Sadly, without actually de-serializing the too-large response packet, we can't send the
// right error to the corresponding outstanding request. So the request just sees ErrConnectionClosed
// while the log will see the actual reason the connection was closed.
expectErr(t, err, ErrConnectionClosed)
// expectLogMessage also drains the logger, so the next check only sees
// messages logged after this point.
expectLogMessage(t, &l, "received packet from server with length .*, which exceeds max buffer size 1024")

// Or with large number of children...
// Keep creating children until the combined length of their names reaches
// 1024 bytes.
totalLen := 0
children = nil
for totalLen < 1024 {
childName, err := zk.Create("/bar/child", nil, FlagEphemeral|FlagSequence, WorldACL(PermAll))
if err != nil {
t.Fatalf("Create returned error: %+v", err)
}
n := childName[len("/bar/"):] // strip parent prefix from name
children = append(children, n)
totalLen += len(n)
}
sort.Strings(children)
_, _, err = zkLimited.Children("/bar")
expectErr(t, err, ErrConnectionClosed)
expectLogMessage(t, &l, "received packet from server with length .*, which exceeds max buffer size 1024")

// Other client (without buffer size limit) can successfully query the node and its children, of course
resultData, _, err = zk.Get("/bar")
if err != nil {
t.Fatalf("Get returned error: %+v", err)
}
if !reflect.DeepEqual(resultData, data) {
t.Fatalf("Get returned unexpected data; expecting %+v, got %+v", data, resultData)
}
resultChildren, _, err = zk.Children("/bar")
if err != nil {
t.Fatalf("Children returned error: %+v", err)
}
sort.Strings(resultChildren)
if !reflect.DeepEqual(resultChildren, children) {
t.Fatalf("Children returned unexpected names; expecting %+v, got %+v", children, resultChildren)
}
}

// expectErr aborts the test unless err is exactly the expected error value.
//
// The comparison uses ==, so it is intended for sentinel errors; an error
// that merely wraps expected does not match. A nil err always fails the test.
// The failure messages report the expected value rather than naming a
// specific operation, because the helper is shared by several different
// calls.
func expectErr(t *testing.T, err error, expected error) {
	if err == nil {
		t.Fatalf("operation should have returned error %+v, but it succeeded", expected)
	}
	if err != expected {
		t.Fatalf("operation returned wrong error; expecting %+v, got %+v", expected, err)
	}
}

// expectLogMessage drains the events captured by logger and aborts the test
// unless exactly one of them matches the regular expression pattern.
//
// The logger is reset as a side effect, so consecutive calls only see events
// logged since the previous call. An invalid pattern panics via
// regexp.MustCompile.
func expectLogMessage(t *testing.T, logger *testLogger, pattern string) {
	matcher := regexp.MustCompile(pattern)

	var matched []string
	for _, event := range logger.Reset() {
		if matcher.MatchString(event) {
			matched = append(matched, event)
		}
	}

	// No events at all and no matching events are reported identically.
	switch count := len(matched); {
	case count == 0:
		t.Fatalf("Failed to log error; expecting message that matches pattern: %s", pattern)
	case count > 1:
		t.Fatalf("Logged error redundantly %d times:\n%+v", count, matched)
	}
}

// testLogger is a logger for tests that records every formatted message so
// assertions can be made about what was logged. The mutex guards events.
type testLogger struct {
	mu     sync.Mutex
	events []string
}

// Printf formats the message, echoes it to standard output, and appends it to
// the recorded events.
func (l *testLogger) Printf(msgFormat string, args ...interface{}) {
	line := fmt.Sprintf(msgFormat, args...)
	// Echo before taking the lock so printing never happens while it is held.
	fmt.Println(line)

	l.mu.Lock()
	l.events = append(l.events, line)
	l.mu.Unlock()
}

// Reset returns all events recorded so far and clears the recorded list.
func (l *testLogger) Reset() []string {
	l.mu.Lock()
	defer l.mu.Unlock()

	var drained []string
	drained, l.events = l.events, nil
	return drained
}

0 comments on commit 789e9fa

Please sign in to comment.