Skip to content

Commit

Permalink
pool wire write buffer (#2799)
Browse files Browse the repository at this point in the history
* rough cut buf

* edits

* edits

* correctness for byte copying

* fix bugs

* simplify

* correct simplification

* page sized spool buffer

* fix build

* comments

* bump timeout

* bump timeout

* fix race

* try separate sleep error

* vitess bump

* see if sleep error masks a different error

* add sleeps back

* more error check where it won't hide other errors

* remove handler test race

* revert back to racy with sleeps

* zach comments

* [ga-format-pr] Run ./format_repo.sh to fix formatting

---------

Co-authored-by: max-hoffman <max-hoffman@users.noreply.github.com>
  • Loading branch information
max-hoffman and max-hoffman authored Dec 20, 2024
1 parent e44b780 commit 999a371
Show file tree
Hide file tree
Showing 16 changed files with 209 additions and 38 deletions.
3 changes: 2 additions & 1 deletion enginetest/enginetests.go
Original file line number Diff line number Diff line change
Expand Up @@ -5743,6 +5743,7 @@ func TestTypesOverWire(t *testing.T, harness ClientHarness, sessionBuilder serve
require.NoError(t, err)
expectedRowSet := script.Results[queryIdx]
expectedRowIdx := 0
buf := sql.NewByteBuffer(1000)
var engineRow sql.Row
for engineRow, err = engineIter.Next(ctx); err == nil; engineRow, err = engineIter.Next(ctx) {
if !assert.True(t, r.Next()) {
Expand All @@ -5760,7 +5761,7 @@ func TestTypesOverWire(t *testing.T, harness ClientHarness, sessionBuilder serve
break
}
expectedEngineRow := make([]*string, len(engineRow))
row, err := server.RowToSQL(ctx, sch, engineRow, nil)
row, err := server.RowToSQL(ctx, sch, engineRow, nil, buf)
if !assert.NoError(t, err) {
break
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require (
github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
github.com/dolthub/vitess v0.0.0-20241211024425-b00987f7ba54
github.com/dolthub/vitess v0.0.0-20241220202600-b18f18d0cde7
github.com/go-kit/kit v0.10.0
github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d
github.com/gocraft/dbr/v2 v2.7.2
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
github.com/dolthub/vitess v0.0.0-20241211024425-b00987f7ba54 h1:nzBnC0Rt1gFtscJEz4veYd/mazZEdbdmed+tujdaKOo=
github.com/dolthub/vitess v0.0.0-20241211024425-b00987f7ba54/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70=
github.com/dolthub/vitess v0.0.0-20241220202600-b18f18d0cde7 h1:w130WLeARGGNYWmhGPugsHXzJEelKKimt3kTWg6/Puk=
github.com/dolthub/vitess v0.0.0-20241220202600-b18f18d0cde7/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
Expand Down
42 changes: 25 additions & 17 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,15 +442,21 @@ func (h *Handler) doQuery(
var r *sqltypes.Result
var processedAtLeastOneBatch bool

buf := sql.ByteBufPool.Get().(*sql.ByteBuffer)
defer func() {
buf.Reset()
sql.ByteBufPool.Put(buf)
}()

// zero/single return schema use spooling shortcut
if types.IsOkResultSchema(schema) {
r, err = resultForOkIter(sqlCtx, rowIter)
} else if schema == nil {
r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields)
} else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) {
r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields)
r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf)
} else {
r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more)
r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf)
}
if err != nil {
return remainder, err
Expand Down Expand Up @@ -542,7 +548,7 @@ func GetDeferredProjections(iter sql.RowIter) (sql.RowIter, []sql.Expression) {
}

// resultForMax1RowIter ensures that an empty iterator returns at most one row
func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, resultFields []*querypb.Field) (*sqltypes.Result, error) {
func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, resultFields []*querypb.Field, buf *sql.ByteBuffer) (*sqltypes.Result, error) {
defer trace.StartRegion(ctx, "Handler.resultForMax1RowIter").End()
row, err := iter.Next(ctx)
if err == io.EOF {
Expand All @@ -557,7 +563,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,
if err := iter.Close(ctx); err != nil {
return nil, err
}
outputRow, err := RowToSQL(ctx, schema, row, nil)
outputRow, err := RowToSQL(ctx, schema, row, nil, buf)
if err != nil {
return nil, err
}
Expand All @@ -569,14 +575,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,

// resultForDefaultIter reads batches of rows from the iterator
// and writes results into the callback function.
func (h *Handler) resultForDefaultIter(
ctx *sql.Context,
c *mysql.Conn,
schema sql.Schema,
iter sql.RowIter,
callback func(*sqltypes.Result, bool) error,
resultFields []*querypb.Field,
more bool) (r *sqltypes.Result, processedAtLeastOneBatch bool, returnErr error) {
func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.RowIter, callback func(*sqltypes.Result, bool) error, resultFields []*querypb.Field, more bool, buf *sql.ByteBuffer) (r *sqltypes.Result, processedAtLeastOneBatch bool, returnErr error) {
defer trace.StartRegion(ctx, "Handler.resultForDefaultIter").End()

eg, ctx := ctx.NewErrgroup()
Expand Down Expand Up @@ -669,7 +668,7 @@ func (h *Handler) resultForDefaultIter(
continue
}

outputRow, err := RowToSQL(ctx, schema, row, projs)
outputRow, err := RowToSQL(ctx, schema, row, projs, buf)
if err != nil {
return err
}
Expand Down Expand Up @@ -932,21 +931,30 @@ func updateMaxUsedConnectionsStatusVariable() {
}()
}

func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Expression) ([]sqltypes.Value, error) {
// toSqlHelper converts val to its SQL wire representation via typ.SQL.
// When a pooled ByteBuffer is supplied, the conversion writes into the
// buffer's free tail and the consumed length is reported back through
// Grow so later writers get a fresh region; with a nil buf it falls
// back to letting typ.SQL allocate its own destination.
func toSqlHelper(ctx *sql.Context, typ sql.Type, buf *sql.ByteBuffer, val interface{}) (sqltypes.Value, error) {
	var dest []byte
	if buf != nil {
		dest = buf.Get()
	}
	ret, err := typ.SQL(ctx, dest, val)
	if buf != nil {
		// Reserve the bytes this value occupied inside the pooled buffer.
		buf.Grow(ret.Len())
	}
	return ret, err
}

func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Expression, buf *sql.ByteBuffer) ([]sqltypes.Value, error) {
// need to make sure the schema is not null as some plan schema is defined as null (e.g. IfElseBlock)
if len(sch) == 0 {
return []sqltypes.Value{}, nil
}

outVals := make([]sqltypes.Value, len(sch))
var err error
if len(projs) == 0 {
for i, col := range sch {
if row[i] == nil {
outVals[i] = sqltypes.NULL
continue
}
var err error
outVals[i], err = col.Type.SQL(ctx, nil, row[i])
outVals[i], err = toSqlHelper(ctx, col.Type, buf, row[i])
if err != nil {
return nil, err
}
Expand All @@ -963,7 +971,7 @@ func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Express
outVals[i] = sqltypes.NULL
continue
}
outVals[i], err = col.Type.SQL(ctx, nil, field)
outVals[i], err = toSqlHelper(ctx, col.Type, buf, field)
if err != nil {
return nil, err
}
Expand Down
16 changes: 12 additions & 4 deletions server/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"io"
"net"
"strconv"
"strings"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -789,13 +790,14 @@ func TestHandlerKillQuery(t *testing.T) {

var wg sync.WaitGroup
wg.Add(1)
sleepQuery := "SELECT SLEEP(1)"
sleepQuery := "SELECT SLEEP(100000)"
var sleepErr error
go func() {
defer wg.Done()
err = handler.ComQuery(context.Background(), conn1, sleepQuery, func(res *sqltypes.Result, more bool) error {
// need a local |err| variable to avoid being overwritten
sleepErr = handler.ComQuery(context.Background(), conn1, sleepQuery, func(res *sqltypes.Result, more bool) error {
return nil
})
require.Error(err)
}()

time.Sleep(100 * time.Millisecond)
Expand All @@ -805,12 +807,17 @@ func TestHandlerKillQuery(t *testing.T) {
// 2, , , test, Query, 0, running, SHOW PROCESSLIST
require.Equal(2, len(res.Rows))
hasSleepQuery := false
fmt.Println(res.Rows[0][0], res.Rows[0][4], res.Rows[0][7])
fmt.Println(res.Rows[1][0], res.Rows[1][4], res.Rows[1][7])
for _, row := range res.Rows {
if row[7].ToString() != sleepQuery {
continue
}
hasSleepQuery = true
sleepQueryID = row[0].ToString()
// the values inside a callback are generally only valid for the
// duration of the query, and need to be copied to avoid being
// overwritten
sleepQueryID = strings.Clone(row[0].ToString())
require.Equal("Query", row[4].ToString())
}
require.True(hasSleepQuery)
Expand All @@ -824,6 +831,7 @@ func TestHandlerKillQuery(t *testing.T) {
})
require.NoError(err)
wg.Wait()
require.Error(sleepErr)

time.Sleep(100 * time.Millisecond)
err = handler.ComQuery(context.Background(), conn2, "SHOW PROCESSLIST", func(res *sqltypes.Result, more bool) error {
Expand Down
75 changes: 75 additions & 0 deletions sql/byte_buffer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sql

import (
"sync"
)

const defaultByteBuffCap = 4096

// ByteBufPool hands out page-sized scratch buffers used to spool rows
// onto the wire. Callers must Reset a buffer before returning it to
// the pool.
var ByteBufPool = sync.Pool{
	New: func() any {
		return NewByteBuffer(defaultByteBuffCap)
	},
}

// ByteBuffer is an append-only scratch arena. Callers write into the
// zero-length tail returned by Get, then report how many bytes they
// used via Grow; Reset reclaims the whole arena for reuse.
type ByteBuffer struct {
	i   int    // index of the first free byte; buf[:i] is reserved
	buf []byte // backing storage shared by all values spooled so far
}

// NewByteBuffer returns a ByteBuffer backed by initCap bytes.
func NewByteBuffer(initCap int) *ByteBuffer {
	return &ByteBuffer{buf: make([]byte, initCap)}
}

// Grow records the latest used byte position. Callers are responsible
// for accurately reporting which bytes they expect to be protected.
//
// If the n bytes fit within the backing array, the cursor advances past
// them. If the write reached (or escaped) the end of the array, the
// array is doubled so future writers regain headroom. A write that
// escaped entirely landed in an externally allocated slice, so the
// cursor is deliberately NOT advanced in that case — the external
// allocation's sizing is based on the appended object and trusting it
// could lead to overall shrinking behavior.
func (b *ByteBuffer) Grow(n int) {
	end := b.i + n
	fits := end <= len(b.buf)
	if end >= len(b.buf) {
		// At or past capacity: double our own backing array.
		b.Double()
	}
	if fits {
		b.i = end
	}
}

// Double expands the backing array by 2x. We do this here because the
// runtime only doubles based on slice length.
func (b *ByteBuffer) Double() {
	newCap := 2 * len(b.buf)
	if newCap == 0 {
		// A zero-length buffer would otherwise never grow; fall back to
		// the pool's default capacity.
		newCap = defaultByteBuffCap
	}
	buf := make([]byte, newCap)
	copy(buf, b.buf)
	b.buf = buf
}

// Get returns a zero length slice beginning at a safe
// write position.
func (b *ByteBuffer) Get() []byte {
	return b.buf[b.i:b.i]
}

// Reset rewinds the write cursor, keeping the (possibly grown)
// backing array for reuse.
func (b *ByteBuffer) Reset() {
	b.i = 0
}
72 changes: 72 additions & 0 deletions sql/byte_buffer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sql

import (
"testing"

"github.com/stretchr/testify/require"
)

// TestGrowByteBuffer exercises ByteBuffer's cursor and capacity
// bookkeeping as writes land below, exactly on, and beyond the
// backing-array boundary, and checks previously spooled slices
// survive the doublings.
func TestGrowByteBuffer(t *testing.T) {
	bb := NewByteBuffer(10)

	// Below the boundary: cursor advances, no reallocation.
	in1 := []byte{1, 1, 1}
	got1 := append(bb.Get(), in1...)
	bb.Grow(len(in1))

	require.Equal(t, 10, len(bb.buf))
	require.Equal(t, 3, bb.i)
	require.Equal(t, 10, cap(got1))

	// Exactly on the boundary: cursor advances and the buffer doubles.
	in2 := []byte{0, 0, 0, 0, 0, 0, 0}
	got2 := append(bb.Get(), in2...)
	bb.Grow(len(in2))

	require.Equal(t, 20, len(bb.buf))
	require.Equal(t, 10, bb.i)
	require.Equal(t, 7, cap(got2))

	in3 := []byte{2, 2, 2, 2, 2}
	got3 := append(bb.Get(), in3...)
	bb.Grow(len(in3))

	require.Equal(t, 20, len(bb.buf))
	require.Equal(t, 15, bb.i)
	require.Equal(t, 10, cap(got3))

	// Past the boundary: the buffer doubles but the cursor stays put,
	// since append spilled into an external allocation.

	in4 := []byte{3, 3, 3, 3, 3, 3, 3, 3}
	got4 := append(bb.Get(), in4...)
	bb.Grow(len(in4))

	require.Equal(t, 40, len(bb.buf))
	require.Equal(t, 15, bb.i)
	require.Equal(t, 16, cap(got4))

	// Every spooled slice is still intact after the doublings.
	require.Equal(t, in1, got1)
	require.Equal(t, in2, got2)
	require.Equal(t, in3, got3)
	require.Equal(t, in4, got4)

	// Reset rewinds the cursor but keeps the grown backing array.
	bb.Reset()
	require.Equal(t, 40, len(bb.buf))
	require.Equal(t, 0, bb.i)
}
2 changes: 1 addition & 1 deletion sql/types/bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func (t BitType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va
for i, j := 0, len(data)-1; i < j; i, j = i+1, j-1 {
data[i], data[j] = data[j], data[i]
}
val := AppendAndSliceBytes(dest, data)
val := data

return sqltypes.MakeTrusted(sqltypes.Bit, val), nil
}
Expand Down
2 changes: 1 addition & 1 deletion sql/types/datetime.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ func (t datetimeType) SQL(_ *sql.Context, dest []byte, v interface{}) (sqltypes.
return sqltypes.Value{}, sql.ErrInvalidBaseType.New(t.baseType.String(), "datetime")
}

valBytes := AppendAndSliceBytes(dest, val)
valBytes := val

return sqltypes.MakeTrusted(typ, valBytes), nil
}
Expand Down
2 changes: 1 addition & 1 deletion sql/types/enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func (t EnumType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va
snippet = strings.ToValidUTF8(snippet, string(utf8.RuneError))
return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(resultCharset.Name(), utf8.ValidString(value), snippet)
}
val := AppendAndSliceBytes(dest, encodedBytes)
val := encodedBytes

return sqltypes.MakeTrusted(sqltypes.Enum, val), nil
}
Expand Down
2 changes: 1 addition & 1 deletion sql/types/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (t JsonType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va
if err != nil {
return sqltypes.NULL, err
}
val = AppendAndSliceBytes(dest, str)
val = str
} else {
// Convert to jsonType
jsVal, _, err := t.Convert(v)
Expand Down
2 changes: 2 additions & 0 deletions sql/types/number.go
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,8 @@ func (t NumberTypeImpl_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqlt
default:
return sqltypes.Value{}, err
}
} else if err != nil {
return sqltypes.Value{}, err
}

val := dest[stop:]
Expand Down
Loading

0 comments on commit 999a371

Please sign in to comment.