Skip to content

Commit

Permalink
Merge pull request #130 from tursodatabase/streaming_batch
Browse files Browse the repository at this point in the history
Execute big batches in smaller chunks
  • Loading branch information
haaawk authored Jul 23, 2024
2 parents 9bc6b51 + f3e03d5 commit b944339
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 17 deletions.
147 changes: 133 additions & 14 deletions libsql/internal/http/hranaV2/hranaV2.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/tursodatabase/libsql-client-go/sqliteparserutils"
"io"
"net/http"
net_url "net/url"
Expand Down Expand Up @@ -286,10 +287,131 @@ func sendPipelineRequest(ctx context.Context, msg *hrana.PipelineRequest, url st
return result, false, nil
}

func (h *hranaV2Conn) executeMsg(ctx context.Context, msg *hrana.PipelineRequest) (*hrana.PipelineResponse, error) {
result, err := h.sendPipelineRequest(ctx, msg, false)
if err != nil {
return nil, err
}

for _, r := range result.Results {
if r.Error != nil {
return nil, errors.New(r.Error.Message)
}
if r.Response == nil {
return nil, errors.New("no response received")
}
}
return result, nil
}

type chunker struct {
chunk []string
iterator *sqliteparserutils.StatementIterator
limit int
}

func newChunker(iterator *sqliteparserutils.StatementIterator, limit int) *chunker {
return &chunker{iterator: iterator, chunk: make([]string, 0, limit), limit: limit}
}

func isTransactionStatement(stmt string) bool {
patterns := [][]byte{[]byte("begin"), []byte("commit"), []byte("end"), []byte("rollback")}
for _, p := range patterns {
if len(stmt) >= len(p) && bytes.Equal(bytes.ToLower([]byte(stmt[0:len(p)])), p) {
return true
}
}
return false
}

func (c *chunker) Next() (chunk []string, isEOF bool) {
c.chunk = c.chunk[:0]
var stmt string
for !isEOF && len(c.chunk) < c.limit {
stmt, _, isEOF = c.iterator.Next()
// We need to skip transaction statements. Chunks run in a transaction by default.
if stmt != "" && !isTransactionStatement(stmt) {
c.chunk = append(c.chunk, stmt)
}
}
return c.chunk, isEOF
}

func (h *hranaV2Conn) executeSingleStmt(ctx context.Context, stmt string, wantRows bool) (*hrana.PipelineResponse, error) {
msg := &hrana.PipelineRequest{}
executeStream, err := hrana.ExecuteStream(stmt, nil, wantRows)
if err != nil {
return nil, fmt.Errorf("failed to execute SQL: %s\n%w", stmt, err)
}
msg.Add(*executeStream)
res, err := h.executeMsg(ctx, msg)
if err != nil {
return nil, fmt.Errorf("failed to execute SQL: %s\n%w", stmt, err)
}
return res, nil
}

func (h *hranaV2Conn) executeInChunks(ctx context.Context, query string, wantRows bool) (*hrana.PipelineResponse, error) {
const chunkSize = 4096
iterator := sqliteparserutils.CreateStatementIterator(query)
chunker := newChunker(iterator, chunkSize)

chunk, isEOF := chunker.Next()
if isEOF && len(chunk) == 1 {
return h.executeSingleStmt(ctx, chunk[0], wantRows)
}

_, err := h.executeSingleStmt(ctx, "BEGIN", false)
if err != nil {
return nil, err
}

batch := &hrana.Batch{Steps: make([]hrana.BatchStep, chunkSize)}
msg := &hrana.PipelineRequest{}
msg.Add(hrana.StreamRequest{Type: "batch", Batch: batch})
for idx := range batch.Steps {
batch.Steps[idx].Stmt.WantRows = wantRows
}

result := &hrana.PipelineResponse{}
for {
for idx := range chunk {
batch.Steps[idx].Stmt.Sql = &chunk[idx]
}
if len(chunk) < chunkSize {
// We can trim batch.Steps because this is the last chunk anyway.
// isEOF has to be true at this point.
batch.Steps = batch.Steps[:len(chunk)]
}
res, err := h.executeMsg(ctx, msg)
if err != nil {
h.closeStream()
return nil, fmt.Errorf("failed to execute SQL:\n%w", err)
}
result.Baton = res.Baton
result.BaseUrl = res.BaseUrl
result.Results = append(result.Results, res.Results...)
if isEOF {
break
}
chunk, isEOF = chunker.Next()
}
_, err = h.executeSingleStmt(ctx, "COMMIT", false)
if err != nil {
h.closeStream()
return nil, err
}
return result, nil
}

func (h *hranaV2Conn) executeStmt(ctx context.Context, query string, args []driver.NamedValue, wantRows bool) (*hrana.PipelineResponse, error) {
const querySizeLimitForChunking = 20 * 1024 * 1024
if len(args) == 0 && len(query) > querySizeLimitForChunking && !h.schemaDb {
return h.executeInChunks(ctx, query, wantRows)
}
stmts, params, err := shared.ParseStatementAndArgs(query, args)
if err != nil {
return nil, fmt.Errorf("failed to execute SQL: %s\n%w", query, err)
return nil, fmt.Errorf("failed to execute SQL:\n%w", err)
}
msg := &hrana.PipelineRequest{}
if len(stmts) == 1 {
Expand All @@ -299,29 +421,22 @@ func (h *hranaV2Conn) executeStmt(ctx context.Context, query string, args []driv
}
executeStream, err := hrana.ExecuteStream(stmts[0], p, wantRows)
if err != nil {
return nil, fmt.Errorf("failed to execute SQL: %s\n%w", query, err)
return nil, fmt.Errorf("failed to execute SQL:\n%w", err)
}
msg.Add(*executeStream)
} else {
batchStream, err := hrana.BatchStream(stmts, params, wantRows, !h.schemaDb)
if err != nil {
return nil, fmt.Errorf("failed to execute SQL: %s\n%w", query, err)
return nil, fmt.Errorf("failed to execute SQL:\n%w", err)
}
msg.Add(*batchStream)
}

result, err := h.sendPipelineRequest(ctx, msg, false)
resp, err := h.executeMsg(ctx, msg)
if err != nil {
return nil, fmt.Errorf("failed to execute SQL: %s\n%w", query, err)
return nil, fmt.Errorf("failed to execute SQL:\n%w", err)
}

if result.Results[0].Error != nil {
return nil, fmt.Errorf("failed to execute SQL: %s\n%s", query, result.Results[0].Error.Message)
}
if result.Results[0].Response == nil {
return nil, fmt.Errorf("failed to execute SQL: %s\n%s", query, "no response received")
}
return result, nil
return resp, nil
}

func (h *hranaV2Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
Expand Down Expand Up @@ -477,7 +592,7 @@ func (h *hranaV2Conn) QueryContext(ctx context.Context, query string, args []dri
}
}

func (h *hranaV2Conn) ResetSession(ctx context.Context) error {
func (h *hranaV2Conn) closeStream() {
if h.baton != "" {
go func(baton, url, jwt, host string) {
msg := hrana.PipelineRequest{Baton: baton}
Expand All @@ -486,5 +601,9 @@ func (h *hranaV2Conn) ResetSession(ctx context.Context) error {
}(h.baton, h.url, h.jwt, h.host)
h.baton = ""
}
}

func (h *hranaV2Conn) ResetSession(ctx context.Context) error {
h.closeStream()
return nil
}
5 changes: 2 additions & 3 deletions sqliteparserutils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ type StatementIterator struct {
currentToken antlr.Token
}

// keep createStatementIterator here for the future uses but do not expose it for now as we will not use it immediately
func createStatementIterator(statement string) *StatementIterator {
func CreateStatementIterator(statement string) *StatementIterator {
return &StatementIterator{tokenizer: createStringTokenizer(statement)}
}

Expand Down Expand Up @@ -77,7 +76,7 @@ func (iterator *StatementIterator) Next() (statement string, extraInfo SplitStat
}

func SplitStatement(statement string) (statements []string, extraInfo SplitStatementExtraInfo) {
iterator := createStatementIterator(statement)
iterator := CreateStatementIterator(statement)

statements = make([]string, 0)
for {
Expand Down

0 comments on commit b944339

Please sign in to comment.