Skip to content

Commit

Permalink
backend: add tests for prepared statements (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
djshow832 authored Aug 9, 2022
1 parent 3974260 commit 898d46a
Show file tree
Hide file tree
Showing 10 changed files with 811 additions and 108 deletions.
4 changes: 2 additions & 2 deletions pkg/proxy/backend/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ func TestTLSConnection(t *testing.T) {
cfgOverriders := getCfgCombinations(cfgs)
for _, cfgs := range cfgOverriders {
ts, clean := newTestSuite(t, tc, cfgs...)
ts.authenticateFirstTime(t, func(t *testing.T, _ *testSuite, _, _, perr error) {
ts.authenticateFirstTime(t, func(t *testing.T, _ *testSuite) {
if ts.mb.backendConfig.capability&mysql.ClientSSL == 0 {
require.ErrorContains(t, perr, "must enable TLS")
require.ErrorContains(t, ts.mp.err, "must enable TLS")
}
})
clean()
Expand Down
20 changes: 12 additions & 8 deletions pkg/proxy/backend/cmd_processor_exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ func (cp *CmdProcessor) forwardCommand(clientIO, backendIO *pnet.PacketIO, reque
return cp.forwardQueryCmd(clientIO, backendIO, request)
case mysql.ComStmtClose:
return cp.forwardCloseCmd(request)
case mysql.ComStmtSendLongData:
return cp.forwardSendLongDataCmd(request)
case mysql.ComChangeUser:
return cp.forwardChangeUserCmd(clientIO, backendIO, request)
case mysql.ComStatistics:
Expand Down Expand Up @@ -131,7 +133,7 @@ func (cp *CmdProcessor) forwardPrepareCmd(clientIO, backendIO *pnet.PacketIO) (s
succeed = true
}
for i := 0; i < expectedEOFNum; i++ {
// The server status in EOF packets is always 0, so ignore it.
// Ignore this status because PREPARE doesn't affect status.
if _, err = forwardUntilEOF(clientIO, backendIO); err != nil {
return
}
Expand Down Expand Up @@ -164,12 +166,8 @@ func (cp *CmdProcessor) forwardQueryCmd(clientIO, backendIO *pnet.PacketIO, requ
var serverStatus uint16
switch response[0] {
case mysql.OKHeader:
if err = clientIO.Flush(); err != nil {
return false, err
}
rs := cp.handleOKPacket(request, response)
serverStatus = rs.Status
succeed = true
serverStatus, succeed, err = rs.Status, true, clientIO.Flush()
case mysql.ErrHeader:
// Subsequent statements won't be executed even if it's a multi-statement.
return false, clientIO.Flush()
Expand Down Expand Up @@ -205,8 +203,7 @@ func (cp *CmdProcessor) forwardLoadInFile(clientIO, backendIO *pnet.PacketIO, re
}
}
var response []byte
response, err = forwardOnePacket(clientIO, backendIO, true)
if err != nil {
if response, err = forwardOnePacket(clientIO, backendIO, true); err != nil {
return
}
if response[0] == mysql.OKHeader {
Expand All @@ -233,6 +230,7 @@ func (cp *CmdProcessor) forwardResultSet(clientIO, backendIO *pnet.PacketIO, req
if response, err = forwardOnePacket(clientIO, backendIO, false); err != nil {
return
}
// An error may occur when the backend writes rows.
if response[0] == mysql.ErrHeader {
return 0, false, clientIO.Flush()
}
Expand All @@ -251,6 +249,12 @@ func (cp *CmdProcessor) forwardCloseCmd(request []byte) (succeed bool, err error
return true, nil
}

func (cp *CmdProcessor) forwardSendLongDataCmd(request []byte) (succeed bool, err error) {
// No packet is sent to the client for COM_STMT_SEND_LONG_DATA.
cp.updatePrepStmtStatus(request, 0)
return true, nil
}

func (cp *CmdProcessor) forwardChangeUserCmd(clientIO, backendIO *pnet.PacketIO, request []byte) (succeed bool, err error) {
// Currently, TiDB responses with an OK or Err packet. But according to the MySQL doc, the server may send a
// switch auth request.
Expand Down
2 changes: 2 additions & 0 deletions pkg/proxy/backend/cmd_processor_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func (cp *CmdProcessor) query(packetIO *pnet.PacketIO, sql string) (result *gomy
return
}

// readResultSet is only used for reading the results of `show session_states` currently.
func (cp *CmdProcessor) readResultSet(packetIO *pnet.PacketIO, data []byte) (*gomysql.Result, error) {
columnCount, _, n := pnet.ParseLengthEncodedInt(data)
if n-len(data) != 0 {
Expand Down Expand Up @@ -114,6 +115,7 @@ func (cp *CmdProcessor) readResultRows(packetIO *pnet.PacketIO, result *gomysql.
result.Status = binary.LittleEndian.Uint16(data[3:])
break
}
// An error may occur when the backend writes rows.
if data[0] == mysql.ErrHeader {
return cp.handleErrorPacket(data)
}
Expand Down
Loading

0 comments on commit 898d46a

Please sign in to comment.