-
Notifications
You must be signed in to change notification settings - Fork 1k
client,mysql: Add support for Query Attributes #976
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 18 commits
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
740b4c6
client,mysql: Add support for Query Attributes
dveeden 1390b58
fixup
dveeden c833ce4
Merge branch 'master' into query_attributes
lance6716 f562050
Add includeLine option
dveeden 4cbdd13
Add _client_version connattr
dveeden deb4492
Fix nil args
dveeden 2a29ba0
update
dveeden 6f41c36
update
dveeden 5ad9e06
fix flags
dveeden 3143ff8
update
dveeden a4f390b
updates
dveeden 2d288d4
server: accept PARAMETER_COUNT_AVAILABLE flag
dveeden 5b4d37f
fixup
dveeden f6fda23
Merge remote-tracking branch 'upstream/master' into query_attributes
dveeden 9e75208
Apply suggestions from code review
dveeden 7cad496
Allow setting of frame
dveeden d330e39
Apply suggestion
dveeden 25f3bae
fixup
dveeden 9865fd5
Update based on review
dveeden e1b8e65
Add tests and fix error
dveeden 16b678e
more tests
dveeden e939d95
Add another test
dveeden 622b0ba
Fix linter found issues
dveeden 7b93b16
Merge branch 'master' into query_attributes
lance6716 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,6 +57,11 @@ type Conn struct { | |
authPluginName string | ||
|
||
connectionID uint32 | ||
|
||
queryAttributes []QueryAttribute | ||
|
||
// Include the file + line as query attribute. The number set which frame in the stack should be used. | ||
includeLine int | ||
} | ||
|
||
// This function will be called for every row in resultset from ExecuteSelectStreaming. | ||
|
@@ -100,6 +105,7 @@ type Dialer func(ctx context.Context, network, address string) (net.Conn, error) | |
func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbName string, dialer Dialer, options ...Option) (*Conn, error) { | ||
c := new(Conn) | ||
|
||
c.includeLine = -1 | ||
c.BufferSize = defaultBufferSize | ||
c.attributes = map[string]string{ | ||
"_client_name": "go-mysql", | ||
|
@@ -310,7 +316,7 @@ func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) { | |
// flag set to signal the server multiple queries are executed. Handling the responses | ||
// is up to the implementation of perResultCallback. | ||
func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCallback) (*Result, error) { | ||
if err := c.writeCommandStr(COM_QUERY, query); err != nil { | ||
if err := c.execSend(query); err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
|
||
|
@@ -363,7 +369,7 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall | |
// | ||
// ExecuteSelectStreaming should be used only for SELECT queries with a large response resultset for memory preserving. | ||
func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error { | ||
if err := c.writeCommandStr(COM_QUERY, command); err != nil { | ||
if err := c.execSend(command); err != nil { | ||
return errors.Trace(err) | ||
} | ||
|
||
|
@@ -489,14 +495,64 @@ func (c *Conn) ReadOKPacket() (*Result, error) { | |
return c.readOK() | ||
} | ||
|
||
// Send COM_QUERY and read the result | ||
func (c *Conn) exec(query string) (*Result, error) { | ||
if err := c.writeCommandStr(COM_QUERY, query); err != nil { | ||
err := c.execSend(query) | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
|
||
return c.readResult(false) | ||
} | ||
|
||
// Sends COM_QUERY | ||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html | ||
func (c *Conn) execSend(query string) error { | ||
var buf bytes.Buffer | ||
|
||
if c.includeLine >= 0 { | ||
dveeden marked this conversation as resolved.
Show resolved
Hide resolved
|
||
_, file, line, ok := runtime.Caller(c.includeLine) | ||
if ok { | ||
lineAttr := QueryAttribute{ | ||
Name: "_line", | ||
Value: fmt.Sprintf("%s:%d", file, line), | ||
} | ||
c.queryAttributes = append(c.queryAttributes, lineAttr) | ||
} | ||
} | ||
|
||
if c.capability&CLIENT_QUERY_ATTRIBUTES > 0 { | ||
numParams := len(c.queryAttributes) | ||
buf.Write(PutLengthEncodedInt(uint64(numParams))) | ||
buf.WriteByte(0x1) // parameter_set_count, unused | ||
if numParams > 0 { | ||
// null_bitmap, length: (num_params+7)/8 | ||
for i := 0; i < (numParams+7)/8; i++ { | ||
buf.WriteByte(0x0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we allow NULL query attributes so the bitmap will have |
||
} | ||
buf.WriteByte(0x1) // new_params_bind_flag, unused | ||
for _, qa := range c.queryAttributes { | ||
buf.Write(qa.TypeAndFlag()) | ||
buf.Write(PutLengthEncodedString([]byte(qa.Name))) | ||
} | ||
for _, qa := range c.queryAttributes { | ||
buf.Write(qa.ValueBytes()) | ||
} | ||
} | ||
} | ||
|
||
_, err := buf.Write(utils.StringToByteSlice(query)) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
if err := c.writeCommandBuf(COM_QUERY, buf.Bytes()); err != nil { | ||
return errors.Trace(err) | ||
} | ||
c.queryAttributes = nil | ||
dveeden marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return nil | ||
} | ||
|
||
// CapabilityString is returning a string with the names of capability flags | ||
// separated by "|". Examples of capability names are CLIENT_DEPRECATE_EOF and CLIENT_PROTOCOL_41. | ||
// These are defined as constants in the mysql package. | ||
|
@@ -627,3 +683,17 @@ func (c *Conn) StatusString() string { | |
|
||
return strings.Join(stats, "|") | ||
} | ||
|
||
// SetQueryAttributes sets the query attributes to be send along with the next query | ||
func (c *Conn) SetQueryAttributes(attrs ...QueryAttribute) error { | ||
c.queryAttributes = attrs | ||
return nil | ||
} | ||
|
||
// IncludeLine can be passed as option when connecting to include the file name and line number | ||
// of the caller as query attribute `_line` when sending queries. | ||
// The argument is used the dept in the stack. The top level is go-mysql and then there are the | ||
// levels of the application. | ||
func (c *Conn) IncludeLine(frame int) { | ||
c.includeLine = frame | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ import ( | |
"encoding/json" | ||
"fmt" | ||
"math" | ||
"runtime" | ||
|
||
. "github.com/go-mysql-org/go-mysql/mysql" | ||
"github.com/go-mysql-org/go-mysql/utils" | ||
|
@@ -56,18 +57,33 @@ func (s *Stmt) Close() error { | |
return nil | ||
} | ||
|
||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html | ||
func (s *Stmt) write(args ...interface{}) error { | ||
paramsNum := s.params | ||
|
||
if len(args) != paramsNum { | ||
return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args)) | ||
} | ||
|
||
paramTypes := make([]byte, paramsNum<<1) | ||
paramValues := make([][]byte, paramsNum) | ||
if s.conn.includeLine >= 0 { | ||
dveeden marked this conversation as resolved.
Show resolved
Hide resolved
|
||
_, file, line, ok := runtime.Caller(s.conn.includeLine) | ||
if ok { | ||
lineAttr := QueryAttribute{ | ||
Name: "_line", | ||
Value: fmt.Sprintf("%s:%d", file, line), | ||
} | ||
s.conn.queryAttributes = append(s.conn.queryAttributes, lineAttr) | ||
} | ||
} | ||
|
||
qaLen := len(s.conn.queryAttributes) | ||
paramTypes := make([][]byte, paramsNum+qaLen) | ||
paramFlags := make([][]byte, paramsNum+qaLen) | ||
paramValues := make([][]byte, paramsNum+qaLen) | ||
paramNames := make([][]byte, paramsNum+qaLen) | ||
|
||
//NULL-bitmap, length: (num-params+7) | ||
nullBitmap := make([]byte, (paramsNum+7)>>3) | ||
nullBitmap := make([]byte, (paramsNum+qaLen+7)>>3) | ||
|
||
length := 1 + 4 + 1 + 4 + ((paramsNum + 7) >> 3) + 1 + (paramsNum << 1) | ||
|
||
|
@@ -76,76 +92,89 @@ func (s *Stmt) write(args ...interface{}) error { | |
for i := range args { | ||
if args[i] == nil { | ||
nullBitmap[i/8] |= 1 << (uint(i) % 8) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this part is a bit complicated, I'll review later 😂 |
||
paramTypes[i<<1] = MYSQL_TYPE_NULL | ||
paramTypes[i] = []byte{MYSQL_TYPE_NULL} | ||
paramNames[i] = []byte{0} // length encoded, no name | ||
paramFlags[i] = []byte{0} | ||
continue | ||
} | ||
|
||
newParamBoundFlag = 1 | ||
|
||
switch v := args[i].(type) { | ||
case int8: | ||
paramTypes[i<<1] = MYSQL_TYPE_TINY | ||
paramTypes[i] = []byte{MYSQL_TYPE_TINY} | ||
paramValues[i] = []byte{byte(v)} | ||
case int16: | ||
paramTypes[i<<1] = MYSQL_TYPE_SHORT | ||
paramTypes[i] = []byte{MYSQL_TYPE_SHORT} | ||
paramValues[i] = Uint16ToBytes(uint16(v)) | ||
case int32: | ||
paramTypes[i<<1] = MYSQL_TYPE_LONG | ||
paramTypes[i] = []byte{MYSQL_TYPE_LONG} | ||
paramValues[i] = Uint32ToBytes(uint32(v)) | ||
case int: | ||
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG | ||
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} | ||
paramValues[i] = Uint64ToBytes(uint64(v)) | ||
case int64: | ||
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG | ||
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} | ||
paramValues[i] = Uint64ToBytes(uint64(v)) | ||
case uint8: | ||
paramTypes[i<<1] = MYSQL_TYPE_TINY | ||
paramTypes[(i<<1)+1] = 0x80 | ||
paramTypes[i] = []byte{MYSQL_TYPE_TINY} | ||
paramFlags[i] = []byte{PARAM_UNSIGNED} | ||
paramValues[i] = []byte{v} | ||
case uint16: | ||
paramTypes[i<<1] = MYSQL_TYPE_SHORT | ||
paramTypes[(i<<1)+1] = 0x80 | ||
paramTypes[i] = []byte{MYSQL_TYPE_SHORT} | ||
paramFlags[i] = []byte{PARAM_UNSIGNED} | ||
paramValues[i] = Uint16ToBytes(v) | ||
case uint32: | ||
paramTypes[i<<1] = MYSQL_TYPE_LONG | ||
paramTypes[(i<<1)+1] = 0x80 | ||
paramTypes[i] = []byte{MYSQL_TYPE_LONG} | ||
paramFlags[i] = []byte{PARAM_UNSIGNED} | ||
paramValues[i] = Uint32ToBytes(v) | ||
case uint: | ||
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG | ||
paramTypes[(i<<1)+1] = 0x80 | ||
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} | ||
paramFlags[i] = []byte{PARAM_UNSIGNED} | ||
paramValues[i] = Uint64ToBytes(uint64(v)) | ||
case uint64: | ||
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG | ||
paramTypes[(i<<1)+1] = 0x80 | ||
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} | ||
paramFlags[i] = []byte{PARAM_UNSIGNED} | ||
paramValues[i] = Uint64ToBytes(v) | ||
case bool: | ||
paramTypes[i<<1] = MYSQL_TYPE_TINY | ||
paramTypes[i] = []byte{MYSQL_TYPE_TINY} | ||
if v { | ||
paramValues[i] = []byte{1} | ||
} else { | ||
paramValues[i] = []byte{0} | ||
} | ||
case float32: | ||
paramTypes[i<<1] = MYSQL_TYPE_FLOAT | ||
paramTypes[i] = []byte{MYSQL_TYPE_FLOAT} | ||
paramValues[i] = Uint32ToBytes(math.Float32bits(v)) | ||
case float64: | ||
paramTypes[i<<1] = MYSQL_TYPE_DOUBLE | ||
paramTypes[i] = []byte{MYSQL_TYPE_DOUBLE} | ||
paramValues[i] = Uint64ToBytes(math.Float64bits(v)) | ||
case string: | ||
paramTypes[i<<1] = MYSQL_TYPE_STRING | ||
paramTypes[i] = []byte{MYSQL_TYPE_STRING} | ||
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) | ||
case []byte: | ||
paramTypes[i<<1] = MYSQL_TYPE_STRING | ||
paramTypes[i] = []byte{MYSQL_TYPE_STRING} | ||
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) | ||
case json.RawMessage: | ||
paramTypes[i<<1] = MYSQL_TYPE_STRING | ||
paramTypes[i] = []byte{MYSQL_TYPE_STRING} | ||
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) | ||
default: | ||
return fmt.Errorf("invalid argument type %T", args[i]) | ||
} | ||
paramNames[i] = []byte{0} // length encoded, no name | ||
if paramFlags[i] == nil { | ||
paramFlags[i] = []byte{0} | ||
} | ||
|
||
length += len(paramValues[i]) | ||
} | ||
for i, qa := range s.conn.queryAttributes { | ||
tf := qa.TypeAndFlag() | ||
paramTypes[(i + paramsNum)] = []byte{tf[0]} | ||
paramFlags[i+paramsNum] = []byte{tf[1]} | ||
paramValues[i+paramsNum] = qa.ValueBytes() | ||
paramNames[i+paramsNum] = PutLengthEncodedString([]byte(qa.Name)) | ||
} | ||
|
||
data := utils.BytesBufferGet() | ||
defer func() { | ||
|
@@ -159,30 +188,46 @@ func (s *Stmt) write(args ...interface{}) error { | |
data.WriteByte(COM_STMT_EXECUTE) | ||
data.Write([]byte{byte(s.id), byte(s.id >> 8), byte(s.id >> 16), byte(s.id >> 24)}) | ||
|
||
//flag: CURSOR_TYPE_NO_CURSOR | ||
data.WriteByte(0x00) | ||
flags := CURSOR_TYPE_NO_CURSOR | ||
if paramsNum > 0 { | ||
flags |= PARAMETER_COUNT_AVAILABLE | ||
} | ||
data.WriteByte(flags) | ||
|
||
//iteration-count, always 1 | ||
data.Write([]byte{1, 0, 0, 0}) | ||
|
||
if s.params > 0 { | ||
data.Write(nullBitmap) | ||
|
||
//new-params-bound-flag | ||
data.WriteByte(newParamBoundFlag) | ||
|
||
if newParamBoundFlag == 1 { | ||
//type of each parameter, length: num-params * 2 | ||
data.Write(paramTypes) | ||
|
||
//value of each parameter | ||
for _, v := range paramValues { | ||
data.Write(v) | ||
if paramsNum > 0 || (s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 && (flags&PARAMETER_COUNT_AVAILABLE > 0)) { | ||
if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 { | ||
paramsNum += len(s.conn.queryAttributes) | ||
data.Write(PutLengthEncodedInt(uint64(paramsNum))) | ||
} | ||
if paramsNum > 0 { | ||
data.Write(nullBitmap) | ||
|
||
//new-params-bound-flag | ||
data.WriteByte(newParamBoundFlag) | ||
|
||
if newParamBoundFlag == 1 { | ||
for i := 0; i < paramsNum; i++ { | ||
data.Write(paramTypes[i]) | ||
data.Write(paramFlags[i]) | ||
|
||
if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 { | ||
data.Write(paramNames[i]) | ||
} | ||
} | ||
|
||
//value of each parameter | ||
for _, v := range paramValues { | ||
data.Write(v) | ||
} | ||
} | ||
} | ||
} | ||
|
||
s.conn.ResetSequence() | ||
s.conn.queryAttributes = nil | ||
|
||
return s.conn.WritePacket(data.Bytes()) | ||
} | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.