From 789e9fa80849e831763605fd0bb2b0476f4e759a Mon Sep 17 00:00:00 2001
From: Joshua Humphries
Date: Wed, 26 Jul 2017 13:47:03 -0400
Subject: [PATCH] max buffer size (#166)

* max buffer size

* add TestMaxBufferSize
---
 zk/conn.go    |  40 ++++++++++++-
 zk/zk_test.go | 159 ++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 198 insertions(+), 1 deletion(-)

diff --git a/zk/conn.go b/zk/conn.go
index b6b8dbc1..89dbf5fa 100644
--- a/zk/conn.go
+++ b/zk/conn.go
@@ -85,6 +85,7 @@ type Conn struct {
 	pingInterval   time.Duration
 	recvTimeout    time.Duration
 	connectTimeout time.Duration
+	maxBufferSize  int
 
 	creds   []authCreds
 	credsMu sync.Mutex // protects server
@@ -249,6 +250,36 @@ func WithEventCallback(cb EventCallback) connOption {
 	}
 }
 
+// WithMaxBufferSize sets the maximum buffer size used to read and decode
+// packets received from the Zookeeper server. The standard Zookeeper client for
+// Java defaults to a limit of 1mb. For backwards compatibility, this Go client
+// defaults to unbounded unless overridden via this option. A value that is zero
+// or negative indicates that no limit is enforced.
+//
+// This is meant to prevent resource exhaustion in the face of potentially
+// malicious data in ZK. It should generally match the server setting (which
+// also defaults to 1mb) so that clients and servers agree on the limits for
+// things like the size of data in an individual znode and the total size of a
+// transaction.
+//
+// For production systems, this should be set to a reasonable value (ideally
+// that matches the server configuration). For ops tooling, it is handy to use a
+// much larger limit, in order to do things like clean up problematic state in
+// the ZK tree. For example, if a single znode has a huge number of children, it
+// is possible for the response to a "list children" operation to exceed this
+// buffer size and cause errors in clients. The only way to subsequently clean
+// up the tree (by removing superfluous children) is to use a client configured
+// with a larger buffer size that can successfully query for all of the child
+// names and then remove them. (Note there are other tools that can list all of
+// the child names without an increased buffer size in the client, but they work
+// by inspecting the servers' transaction logs to enumerate children instead of
+// sending an online request to a server.)
+func WithMaxBufferSize(maxBufferSize int) connOption {
+	return func(c *Conn) {
+		c.maxBufferSize = maxBufferSize
+	}
+}
+
 func (c *Conn) Close() {
 	close(c.shouldQuit)
 
@@ -676,7 +707,11 @@ func (c *Conn) sendLoop() error {
 }
 
 func (c *Conn) recvLoop(conn net.Conn) error {
-	buf := make([]byte, bufferSize)
+	sz := bufferSize
+	if c.maxBufferSize > 0 && sz > c.maxBufferSize {
+		sz = c.maxBufferSize
+	}
+	buf := make([]byte, sz)
 	for {
 		// package length
 		conn.SetReadDeadline(time.Now().Add(c.recvTimeout))
@@ -687,6 +722,9 @@ func (c *Conn) recvLoop(conn net.Conn) error {
 
 		blen := int(binary.BigEndian.Uint32(buf[:4]))
 		if cap(buf) < blen {
+			if c.maxBufferSize > 0 && blen > c.maxBufferSize {
+				return fmt.Errorf("received packet from server with length %d, which exceeds max buffer size %d", blen, c.maxBufferSize)
+			}
 			buf = make([]byte, blen)
 		}
 
diff --git a/zk/zk_test.go b/zk/zk_test.go
index 781cdd41..4b8f78d4 100644
--- a/zk/zk_test.go
+++ b/zk/zk_test.go
@@ -6,7 +6,11 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"reflect"
+	"regexp"
+	"sort"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 )
@@ -716,3 +720,158 @@ func startSlowProxy(t *testing.T, up, down Rate, upstream string, adj func(ln *L
 	}()
 	return ln.Addr().String(), stopCh, nil
 }
+
+func TestMaxBufferSize(t *testing.T) {
+	ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "})
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer ts.Stop()
+	// no buffer size
+	zk, _, err := ts.ConnectWithOptions(15 * time.Second)
+	var l testLogger
+	if err != nil {
+		t.Fatalf("Connect returned error: %+v", err)
+	}
+	defer zk.Close()
+	// 1k buffer size, logs to custom test logger
+	zkLimited, _, err := ts.ConnectWithOptions(15*time.Second, WithMaxBufferSize(1024), func(conn *Conn) {
+		conn.SetLogger(&l)
+	})
+	if err != nil {
+		t.Fatalf("Connect returned error: %+v", err)
+	}
+	defer zkLimited.Close()
+
+	// With a small node and a small number of children
+	data := []byte{101, 102, 103, 103}
+	_, err = zk.Create("/foo", data, 0, WorldACL(PermAll))
+	if err != nil {
+		t.Fatalf("Create returned error: %+v", err)
+	}
+	var children []string
+	for i := 0; i < 4; i++ {
+		childName, err := zk.Create("/foo/child", nil, FlagEphemeral|FlagSequence, WorldACL(PermAll))
+		if err != nil {
+			t.Fatalf("Create returned error: %+v", err)
+		}
+		children = append(children, childName[len("/foo/"):]) // strip parent prefix from name
+	}
+	sort.Strings(children)
+
+	// Limited client works fine
+	resultData, _, err := zkLimited.Get("/foo")
+	if err != nil {
+		t.Fatalf("Get returned error: %+v", err)
+	}
+	if !reflect.DeepEqual(resultData, data) {
+		t.Fatalf("Get returned unexpected data; expecting %+v, got %+v", data, resultData)
+	}
+	resultChildren, _, err := zkLimited.Children("/foo")
+	if err != nil {
+		t.Fatalf("Children returned error: %+v", err)
+	}
+	sort.Strings(resultChildren)
+	if !reflect.DeepEqual(resultChildren, children) {
+		t.Fatalf("Children returned unexpected names; expecting %+v, got %+v", children, resultChildren)
+	}
+
+	// With a large node though...
+	data = make([]byte, 1024)
+	for i := 0; i < 1024; i++ {
+		data[i] = byte(i)
+	}
+	_, err = zk.Create("/bar", data, 0, WorldACL(PermAll))
+	if err != nil {
+		t.Fatalf("Create returned error: %+v", err)
+	}
+	_, _, err = zkLimited.Get("/bar")
+	// NB: Sadly, without actually de-serializing the too-large response packet, we can't send the
+	// right error to the corresponding outstanding request. So the request just sees ErrConnectionClosed
+	// while the log will see the actual reason the connection was closed.
+	expectErr(t, err, ErrConnectionClosed)
+	expectLogMessage(t, &l, "received packet from server with length .*, which exceeds max buffer size 1024")
+
+	// Or with a large number of children...
+	totalLen := 0
+	children = nil
+	for totalLen < 1024 {
+		childName, err := zk.Create("/bar/child", nil, FlagEphemeral|FlagSequence, WorldACL(PermAll))
+		if err != nil {
+			t.Fatalf("Create returned error: %+v", err)
+		}
+		n := childName[len("/bar/"):] // strip parent prefix from name
+		children = append(children, n)
+		totalLen += len(n)
+	}
+	sort.Strings(children)
+	_, _, err = zkLimited.Children("/bar")
+	expectErr(t, err, ErrConnectionClosed)
+	expectLogMessage(t, &l, "received packet from server with length .*, which exceeds max buffer size 1024")
+
+	// Other client (without buffer size limit) can successfully query the node and its children, of course
+	resultData, _, err = zk.Get("/bar")
+	if err != nil {
+		t.Fatalf("Get returned error: %+v", err)
+	}
+	if !reflect.DeepEqual(resultData, data) {
+		t.Fatalf("Get returned unexpected data; expecting %+v, got %+v", data, resultData)
+	}
+	resultChildren, _, err = zk.Children("/bar")
+	if err != nil {
+		t.Fatalf("Children returned error: %+v", err)
+	}
+	sort.Strings(resultChildren)
+	if !reflect.DeepEqual(resultChildren, children) {
+		t.Fatalf("Children returned unexpected names; expecting %+v, got %+v", children, resultChildren)
+	}
+}
+
+func expectErr(t *testing.T, err error, expected error) {
+	if err == nil {
+		t.Fatalf("request whose response exceeds the max buffer size should have returned an error")
+	}
+	if err != expected {
+		t.Fatalf("request returned wrong error; expecting %+v, got %+v", expected, err)
+	}
+}
+
+func expectLogMessage(t *testing.T, logger *testLogger, pattern string) {
+	re := regexp.MustCompile(pattern)
+	events := logger.Reset()
+	if len(events) == 0 {
+		t.Fatalf("Failed to log error; expecting message that matches pattern: %s", pattern)
+	}
+	var found []string
+	for _, e := range events {
+		if re.Match([]byte(e)) {
+			found = append(found, e)
+		}
+	}
+	if len(found) == 0 {
+		t.Fatalf("Failed to log error; expecting message that matches pattern: %s", pattern)
+	} else if len(found) > 1 {
+		t.Fatalf("Logged error redundantly %d times:\n%+v", len(found), found)
+	}
+}
+
+type testLogger struct {
+	mu     sync.Mutex
+	events []string
+}
+
+func (l *testLogger) Printf(msgFormat string, args ...interface{}) {
+	msg := fmt.Sprintf(msgFormat, args...)
+	fmt.Println(msg)
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	l.events = append(l.events, msg)
+}
+
+func (l *testLogger) Reset() []string {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	ret := l.events
+	l.events = nil
+	return ret
+}
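
A minimal caller-side sketch of the new option, assuming the github.com/samuel/go-zookeeper/zk import path and that zk.Connect accepts variadic connOption values (as ts.ConnectWithOptions does in the test above); the 1mb cap and the server address are illustrative:

package main

import (
	"fmt"
	"time"

	"github.com/samuel/go-zookeeper/zk" // assumed import path
)

func main() {
	// Cap read buffers at 1mb to match the ZK server's default jute.maxbuffer,
	// so client and server agree on the maximum packet size.
	conn, _, err := zk.Connect(
		[]string{"127.0.0.1:2181"}, // illustrative server address
		15*time.Second,
		zk.WithMaxBufferSize(1024*1024),
	)
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	// A response larger than the cap surfaces as zk.ErrConnectionClosed on the
	// request; the real cause is written to the connection's logger.
	children, _, err := conn.Children("/")
	if err != nil {
		panic(err)
	}
	fmt.Println(children)
}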