-
Notifications
You must be signed in to change notification settings - Fork 1k
client,mysql: Add support for Query Attributes #976
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
740b4c6
1390b58
c833ce4
f562050
4cbdd13
deb4492
2a29ba0
6f41c36
5ad9e06
3143ff8
a4f390b
2d288d4
5b4d37f
f6fda23
9e75208
7cad496
d330e39
25f3bae
9865fd5
e1b8e65
16b678e
e939d95
622b0ba
7b93b16
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,6 +57,11 @@ type Conn struct { | |
authPluginName string | ||
|
||
connectionID uint32 | ||
|
||
queryAttributes []QueryAttribute | ||
|
||
// Include the file + line as query attribute. The number set which frame in the stack should be used. | ||
includeLine int | ||
} | ||
|
||
// This function will be called for every row in resultset from ExecuteSelectStreaming. | ||
|
@@ -100,6 +105,7 @@ type Dialer func(ctx context.Context, network, address string) (net.Conn, error) | |
func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbName string, dialer Dialer, options ...Option) (*Conn, error) { | ||
c := new(Conn) | ||
|
||
c.includeLine = -1 | ||
c.BufferSize = defaultBufferSize | ||
c.attributes = map[string]string{ | ||
"_client_name": "go-mysql", | ||
|
@@ -310,7 +316,7 @@ func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) { | |
// flag set to signal the server multiple queries are executed. Handling the responses | ||
// is up to the implementation of perResultCallback. | ||
func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCallback) (*Result, error) { | ||
if err := c.writeCommandStr(COM_QUERY, query); err != nil { | ||
if err := c.execSend(query); err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
|
||
|
@@ -363,7 +369,7 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall | |
// | ||
// ExecuteSelectStreaming should be used only for SELECT queries with a large response resultset for memory preserving. | ||
func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error { | ||
if err := c.writeCommandStr(COM_QUERY, command); err != nil { | ||
if err := c.execSend(command); err != nil { | ||
return errors.Trace(err) | ||
} | ||
|
||
|
@@ -489,14 +495,64 @@ func (c *Conn) ReadOKPacket() (*Result, error) { | |
return c.readOK() | ||
} | ||
|
||
// Send COM_QUERY and read the result | ||
func (c *Conn) exec(query string) (*Result, error) { | ||
if err := c.writeCommandStr(COM_QUERY, query); err != nil { | ||
err := c.execSend(query) | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
|
||
return c.readResult(false) | ||
} | ||
|
||
// Sends COM_QUERY | ||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html | ||
func (c *Conn) execSend(query string) error { | ||
var buf bytes.Buffer | ||
defer clear(c.queryAttributes) | ||
|
||
if c.capability&CLIENT_QUERY_ATTRIBUTES > 0 { | ||
if c.includeLine >= 0 { | ||
_, file, line, ok := runtime.Caller(c.includeLine) | ||
if ok { | ||
lineAttr := QueryAttribute{ | ||
Name: "_line", | ||
Value: fmt.Sprintf("%s:%d", file, line), | ||
} | ||
c.queryAttributes = append(c.queryAttributes, lineAttr) | ||
} | ||
} | ||
|
||
numParams := len(c.queryAttributes) | ||
buf.Write(PutLengthEncodedInt(uint64(numParams))) | ||
buf.WriteByte(0x1) // parameter_set_count, unused | ||
if numParams > 0 { | ||
// null_bitmap, length: (num_params+7)/8 | ||
for i := 0; i < (numParams+7)/8; i++ { | ||
buf.WriteByte(0x0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we allow NULL query attributes so the bitmap will have |
||
} | ||
buf.WriteByte(0x1) // new_params_bind_flag, unused | ||
for _, qa := range c.queryAttributes { | ||
buf.Write(qa.TypeAndFlag()) | ||
buf.Write(PutLengthEncodedString([]byte(qa.Name))) | ||
} | ||
for _, qa := range c.queryAttributes { | ||
buf.Write(qa.ValueBytes()) | ||
} | ||
} | ||
} | ||
|
||
_, err := buf.Write(utils.StringToByteSlice(query)) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
if err := c.writeCommandBuf(COM_QUERY, buf.Bytes()); err != nil { | ||
return errors.Trace(err) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// CapabilityString is returning a string with the names of capability flags | ||
// separated by "|". Examples of capability names are CLIENT_DEPRECATE_EOF and CLIENT_PROTOCOL_41. | ||
// These are defined as constants in the mysql package. | ||
|
@@ -627,3 +683,17 @@ func (c *Conn) StatusString() string { | |
|
||
return strings.Join(stats, "|") | ||
} | ||
|
||
// SetQueryAttributes sets the query attributes to be send along with the next query | ||
func (c *Conn) SetQueryAttributes(attrs ...QueryAttribute) error { | ||
c.queryAttributes = attrs | ||
return nil | ||
} | ||
|
||
// IncludeLine can be passed as option when connecting to include the file name and line number | ||
// of the caller as query attribute `_line` when sending queries. | ||
// The argument is used the dept in the stack. The top level is go-mysql and then there are the | ||
// levels of the application. | ||
func (c *Conn) IncludeLine(frame int) { | ||
c.includeLine = frame | ||
} |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -5,6 +5,7 @@ import ( | |||||
"encoding/json" | ||||||
"fmt" | ||||||
"math" | ||||||
"runtime" | ||||||
|
||||||
. "github.com/go-mysql-org/go-mysql/mysql" | ||||||
"github.com/go-mysql-org/go-mysql/utils" | ||||||
|
@@ -56,18 +57,34 @@ func (s *Stmt) Close() error { | |||||
return nil | ||||||
} | ||||||
|
||||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html | ||||||
func (s *Stmt) write(args ...interface{}) error { | ||||||
defer clear(s.conn.queryAttributes) | ||||||
paramsNum := s.params | ||||||
|
||||||
if len(args) != paramsNum { | ||||||
return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args)) | ||||||
} | ||||||
|
||||||
paramTypes := make([]byte, paramsNum<<1) | ||||||
paramValues := make([][]byte, paramsNum) | ||||||
if (s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0) && (s.conn.includeLine >= 0) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
the old code has clear priority. It's OK to keep it. |
||||||
_, file, line, ok := runtime.Caller(s.conn.includeLine) | ||||||
if ok { | ||||||
lineAttr := QueryAttribute{ | ||||||
Name: "_line", | ||||||
Value: fmt.Sprintf("%s:%d", file, line), | ||||||
} | ||||||
s.conn.queryAttributes = append(s.conn.queryAttributes, lineAttr) | ||||||
} | ||||||
} | ||||||
|
||||||
qaLen := len(s.conn.queryAttributes) | ||||||
paramTypes := make([][]byte, paramsNum+qaLen) | ||||||
paramFlags := make([][]byte, paramsNum+qaLen) | ||||||
paramValues := make([][]byte, paramsNum+qaLen) | ||||||
paramNames := make([][]byte, paramsNum+qaLen) | ||||||
|
||||||
//NULL-bitmap, length: (num-params+7) | ||||||
nullBitmap := make([]byte, (paramsNum+7)>>3) | ||||||
nullBitmap := make([]byte, (paramsNum+qaLen+7)>>3) | ||||||
|
||||||
length := 1 + 4 + 1 + 4 + ((paramsNum + 7) >> 3) + 1 + (paramsNum << 1) | ||||||
|
||||||
|
@@ -76,76 +93,89 @@ func (s *Stmt) write(args ...interface{}) error { | |||||
for i := range args { | ||||||
if args[i] == nil { | ||||||
nullBitmap[i/8] |= 1 << (uint(i) % 8) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this part is a bit complicated, I'll review later 😂 |
||||||
paramTypes[i<<1] = MYSQL_TYPE_NULL | ||||||
paramTypes[i] = []byte{MYSQL_TYPE_NULL} | ||||||
paramNames[i] = []byte{0} // length encoded, no name | ||||||
paramFlags[i] = []byte{0} | ||||||
continue | ||||||
} | ||||||
|
||||||
newParamBoundFlag = 1 | ||||||
|
||||||
switch v := args[i].(type) { | ||||||
case int8: | ||||||
paramTypes[i<<1] = MYSQL_TYPE_TINY | ||||||
paramTypes[i] = []byte{MYSQL_TYPE_TINY} | ||||||
paramValues[i] = []byte{byte(v)} | ||||||
case int16: | ||||||
paramTypes[i<<1] = MYSQL_TYPE_SHORT | ||||||
paramTypes[i] = []byte{MYSQL_TYPE_SHORT} | ||||||
paramValues[i] = Uint16ToBytes(uint16(v)) | ||||||
case int32: | ||||||
paramTypes[i<<1] = MYSQL_TYPE_LONG | ||||||
paramTypes[i] = []byte{MYSQL_TYPE_LONG} | ||||||
paramValues[i] = Uint32ToBytes(uint32(v)) | ||||||
case int: | ||||||
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG | ||||||
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} | ||||||
paramValues[i] = Uint64ToBytes(uint64(v)) | ||||||
case int64: | ||||||
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG | ||||||
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} | ||||||
paramValues[i] = Uint64ToBytes(uint64(v)) | ||||||
case uint8: | ||||||
paramTypes[i<<1] = MYSQL_TYPE_TINY | ||||||
paramTypes[(i<<1)+1] = 0x80 | ||||||
paramTypes[i] = []byte{MYSQL_TYPE_TINY} | ||||||
paramFlags[i] = []byte{PARAM_UNSIGNED} | ||||||
paramValues[i] = []byte{v} | ||||||
case uint16: | ||||||
paramTypes[i<<1] = MYSQL_TYPE_SHORT | ||||||
paramTypes[(i<<1)+1] = 0x80 | ||||||
paramTypes[i] = []byte{MYSQL_TYPE_SHORT} | ||||||
paramFlags[i] = []byte{PARAM_UNSIGNED} | ||||||
paramValues[i] = Uint16ToBytes(v) | ||||||
case uint32: | ||||||
paramTypes[i<<1] = MYSQL_TYPE_LONG | ||||||
paramTypes[(i<<1)+1] = 0x80 | ||||||
paramTypes[i] = []byte{MYSQL_TYPE_LONG} | ||||||
paramFlags[i] = []byte{PARAM_UNSIGNED} | ||||||
paramValues[i] = Uint32ToBytes(v) | ||||||
case uint: | ||||||
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG | ||||||
paramTypes[(i<<1)+1] = 0x80 | ||||||
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} | ||||||
paramFlags[i] = []byte{PARAM_UNSIGNED} | ||||||
paramValues[i] = Uint64ToBytes(uint64(v)) | ||||||
case uint64: | ||||||
paramTypes[i<<1] = MYSQL_TYPE_LONGLONG | ||||||
paramTypes[(i<<1)+1] = 0x80 | ||||||
paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} | ||||||
paramFlags[i] = []byte{PARAM_UNSIGNED} | ||||||
paramValues[i] = Uint64ToBytes(v) | ||||||
case bool: | ||||||
paramTypes[i<<1] = MYSQL_TYPE_TINY | ||||||
paramTypes[i] = []byte{MYSQL_TYPE_TINY} | ||||||
if v { | ||||||
paramValues[i] = []byte{1} | ||||||
} else { | ||||||
paramValues[i] = []byte{0} | ||||||
} | ||||||
case float32: | ||||||
paramTypes[i<<1] = MYSQL_TYPE_FLOAT | ||||||
paramTypes[i] = []byte{MYSQL_TYPE_FLOAT} | ||||||
paramValues[i] = Uint32ToBytes(math.Float32bits(v)) | ||||||
case float64: | ||||||
paramTypes[i<<1] = MYSQL_TYPE_DOUBLE | ||||||
paramTypes[i] = []byte{MYSQL_TYPE_DOUBLE} | ||||||
paramValues[i] = Uint64ToBytes(math.Float64bits(v)) | ||||||
case string: | ||||||
paramTypes[i<<1] = MYSQL_TYPE_STRING | ||||||
paramTypes[i] = []byte{MYSQL_TYPE_STRING} | ||||||
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) | ||||||
case []byte: | ||||||
paramTypes[i<<1] = MYSQL_TYPE_STRING | ||||||
paramTypes[i] = []byte{MYSQL_TYPE_STRING} | ||||||
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) | ||||||
case json.RawMessage: | ||||||
paramTypes[i<<1] = MYSQL_TYPE_STRING | ||||||
paramTypes[i] = []byte{MYSQL_TYPE_STRING} | ||||||
paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) | ||||||
default: | ||||||
return fmt.Errorf("invalid argument type %T", args[i]) | ||||||
} | ||||||
paramNames[i] = []byte{0} // length encoded, no name | ||||||
if paramFlags[i] == nil { | ||||||
paramFlags[i] = []byte{0} | ||||||
} | ||||||
|
||||||
length += len(paramValues[i]) | ||||||
} | ||||||
for i, qa := range s.conn.queryAttributes { | ||||||
tf := qa.TypeAndFlag() | ||||||
paramTypes[(i + paramsNum)] = []byte{tf[0]} | ||||||
paramFlags[i+paramsNum] = []byte{tf[1]} | ||||||
paramValues[i+paramsNum] = qa.ValueBytes() | ||||||
paramNames[i+paramsNum] = PutLengthEncodedString([]byte(qa.Name)) | ||||||
} | ||||||
|
||||||
data := utils.BytesBufferGet() | ||||||
defer func() { | ||||||
|
@@ -159,25 +189,40 @@ func (s *Stmt) write(args ...interface{}) error { | |||||
data.WriteByte(COM_STMT_EXECUTE) | ||||||
data.Write([]byte{byte(s.id), byte(s.id >> 8), byte(s.id >> 16), byte(s.id >> 24)}) | ||||||
|
||||||
//flag: CURSOR_TYPE_NO_CURSOR | ||||||
data.WriteByte(0x00) | ||||||
flags := CURSOR_TYPE_NO_CURSOR | ||||||
if paramsNum > 0 { | ||||||
flags |= PARAMETER_COUNT_AVAILABLE | ||||||
} | ||||||
data.WriteByte(flags) | ||||||
|
||||||
//iteration-count, always 1 | ||||||
data.Write([]byte{1, 0, 0, 0}) | ||||||
|
||||||
if s.params > 0 { | ||||||
data.Write(nullBitmap) | ||||||
|
||||||
//new-params-bound-flag | ||||||
data.WriteByte(newParamBoundFlag) | ||||||
|
||||||
if newParamBoundFlag == 1 { | ||||||
//type of each parameter, length: num-params * 2 | ||||||
data.Write(paramTypes) | ||||||
|
||||||
//value of each parameter | ||||||
for _, v := range paramValues { | ||||||
data.Write(v) | ||||||
if paramsNum > 0 || (s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 && (flags&PARAMETER_COUNT_AVAILABLE > 0)) { | ||||||
if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 { | ||||||
paramsNum += len(s.conn.queryAttributes) | ||||||
data.Write(PutLengthEncodedInt(uint64(paramsNum))) | ||||||
} | ||||||
if paramsNum > 0 { | ||||||
data.Write(nullBitmap) | ||||||
|
||||||
//new-params-bound-flag | ||||||
data.WriteByte(newParamBoundFlag) | ||||||
|
||||||
if newParamBoundFlag == 1 { | ||||||
for i := 0; i < paramsNum; i++ { | ||||||
data.Write(paramTypes[i]) | ||||||
data.Write(paramFlags[i]) | ||||||
|
||||||
if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 { | ||||||
data.Write(paramNames[i]) | ||||||
} | ||||||
} | ||||||
|
||||||
//value of each parameter | ||||||
for _, v := range paramValues { | ||||||
data.Write(v) | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
|
Uh oh!
There was an error while loading. Please reload this page.