Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: return error for new statement after opening cursor #40095

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions errno/errcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,7 @@ const (
ErrSetTTLEnableForNonTTLTable = 8150
ErrTempTableNotAllowedWithTTL = 8151
ErrUnsupportedTTLReferencedByFK = 8152
ErrNotAllowedWithActiveCursor = 8153

// Error codes used by TiDB ddl package
ErrUnsupportedDDLOperation = 8200
Expand Down
1 change: 1 addition & 0 deletions errno/errname.go
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,7 @@ var MySQLErrName = map[uint16]*mysql.ErrMessage{
ErrInvalidRequiresSingleReference: mysql.Message("In recursive query block of Recursive Common Table Expression '%s', the recursive table must be referenced only once, and not in any subquery", nil),
ErrCTEMaxRecursionDepth: mysql.Message("Recursive query aborted after %d iterations. Try increasing @@cte_max_recursion_depth to a larger value", nil),
ErrTableWithoutPrimaryKey: mysql.Message("Unable to create or change a table without a primary key, when the system variable 'sql_require_primary_key' is set. Add a primary key to the table or unset this variable to avoid this message. Note that tables without a primary key can cause performance problems in row-based replication, so please consult your DBA before changing this setting.", nil),
ErrNotAllowedWithActiveCursor: mysql.Message("Commands other than CLOSE and FETCH are not allowed with active cursor", nil),
// MariaDB errors.
ErrOnlyOneDefaultPartionAllowed: mysql.Message("Only one DEFAULT partition allowed", nil),
ErrWrongPartitionTypeExpectedSystemTime: mysql.Message("Wrong partitioning type, expected type: `SYSTEM_TIME`", nil),
Expand Down
19 changes: 19 additions & 0 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,13 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
if cmd < mysql.ComEnd {
cc.ctx.SetCommandValue(cmd)
}
// only allow "FETCH" and "CLOSE" with an active cursor
// NOTE: commit is also not allowed. Users have to close the active statement before commit
if vars.ActiveCursorStmtID != 0 {
if !cc.isAllowedWithActiveCursor(cmd, data) {
return errNotAllowedWithActiveCursor
}
}

dataStr := string(hack.String(data))
switch cmd {
Expand Down Expand Up @@ -2600,6 +2607,18 @@ func (cc *clientConn) handleRefresh(ctx context.Context, subCommand byte) error
return cc.writeOK(ctx)
}

// isAllowedWithActiveCursor returns whether the current command is allowed with an active cursor
func (cc *clientConn) isAllowedWithActiveCursor(cmd byte, data []byte) bool {
if cmd == mysql.ComStmtFetch || cmd == mysql.ComStmtClose {
stmtID := int(binary.LittleEndian.Uint32(data[0:4]))
if stmtID == cc.ctx.GetSessionVars().ActiveCursorStmtID {
return true
}
}

return false
}

var _ fmt.Stringer = getLastStmtInConn{}

type getLastStmtInConn struct {
Expand Down
6 changes: 6 additions & 0 deletions server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,9 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm
if err != nil {
return false, err
}
// set the active cursor
cc.ctx.GetSessionVars().ActiveCursorStmtID = stmt.ID()

return false, cc.flush(ctx)
}
defer terror.Call(rs.Close)
Expand Down Expand Up @@ -681,6 +684,9 @@ func (cc *clientConn) handleStmtClose(data []byte) (err error) {
}

stmtID := int(binary.LittleEndian.Uint32(data[0:4]))
if stmtID == cc.ctx.GetSessionVars().ActiveCursorStmtID {
cc.ctx.GetSessionVars().ActiveCursorStmtID = 0
}
stmt := cc.ctx.GetStatement(stmtID)
if stmt != nil {
return stmt.Close()
Expand Down
48 changes: 48 additions & 0 deletions server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1799,3 +1799,51 @@ func TestExtensionChangeUser(t *testing.T) {
require.Equal(t, expectedConnInfo.Error, logInfo.Error)
require.Equal(t, *(expectedConnInfo.ConnectionInfo), *(logInfo.ConnectionInfo))
}

func TestOnlyAllowFetchCloseWithActiveCursor(t *testing.T) {
store, dom := testkit.CreateMockStoreAndDomain(t)
srv := CreateMockServer(t, store)
srv.SetDomain(dom)
defer srv.Close()

ctx := context.Background()
c := CreateMockConn(t, srv)
tk := testkit.NewTestKitWithSession(t, store, c.Context().Session)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(id int)")
tk.MustExec("insert into t values (1)")

stmt, _, _, err := c.Context().Prepare("select * from t")
require.NoError(t, err)

// execute with cursor fetch will set ActiveCursorStmtID
err = c.Dispatch(ctx, append(
binary.LittleEndian.AppendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())),
mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0,
))
require.NoError(t, err)
require.Equal(t, c.Context().Session.GetSessionVars().ActiveCursorStmtID, stmt.ID())

// execute another statement is not allowed
err = c.Dispatch(ctx, append(
binary.LittleEndian.AppendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID()+1)),
mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0,
))
require.ErrorIs(t, err, errNotAllowedWithActiveCursor)

// close another statement is also not allowed
err = c.Dispatch(ctx, append(
binary.LittleEndian.AppendUint32([]byte{mysql.ComStmtClose}, uint32(stmt.ID()+1)),
))
require.ErrorIs(t, err, errNotAllowedWithActiveCursor)

// close this statement is allowed
err = c.Dispatch(ctx, append(
binary.LittleEndian.AppendUint32([]byte{mysql.ComStmtClose}, uint32(stmt.ID())),
))
require.NoError(t, err)

// after closing the statement, other statement is allowed
tk.MustQuery("select * from t").Check(testkit.Rows("1"))
}
27 changes: 14 additions & 13 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,20 @@ func init() {
}

var (
errUnknownFieldType = dbterror.ClassServer.NewStd(errno.ErrUnknownFieldType)
errInvalidSequence = dbterror.ClassServer.NewStd(errno.ErrInvalidSequence)
errInvalidType = dbterror.ClassServer.NewStd(errno.ErrInvalidType)
errNotAllowedCommand = dbterror.ClassServer.NewStd(errno.ErrNotAllowedCommand)
errAccessDenied = dbterror.ClassServer.NewStd(errno.ErrAccessDenied)
errAccessDeniedNoPassword = dbterror.ClassServer.NewStd(errno.ErrAccessDeniedNoPassword)
errConCount = dbterror.ClassServer.NewStd(errno.ErrConCount)
errSecureTransportRequired = dbterror.ClassServer.NewStd(errno.ErrSecureTransportRequired)
errMultiStatementDisabled = dbterror.ClassServer.NewStd(errno.ErrMultiStatementDisabled)
errNewAbortingConnection = dbterror.ClassServer.NewStd(errno.ErrNewAbortingConnection)
errNotSupportedAuthMode = dbterror.ClassServer.NewStd(errno.ErrNotSupportedAuthMode)
errNetPacketTooLarge = dbterror.ClassServer.NewStd(errno.ErrNetPacketTooLarge)
errMustChangePassword = dbterror.ClassServer.NewStd(errno.ErrMustChangePassword)
errUnknownFieldType = dbterror.ClassServer.NewStd(errno.ErrUnknownFieldType)
errInvalidSequence = dbterror.ClassServer.NewStd(errno.ErrInvalidSequence)
errInvalidType = dbterror.ClassServer.NewStd(errno.ErrInvalidType)
errNotAllowedCommand = dbterror.ClassServer.NewStd(errno.ErrNotAllowedCommand)
errAccessDenied = dbterror.ClassServer.NewStd(errno.ErrAccessDenied)
errAccessDeniedNoPassword = dbterror.ClassServer.NewStd(errno.ErrAccessDeniedNoPassword)
errConCount = dbterror.ClassServer.NewStd(errno.ErrConCount)
errSecureTransportRequired = dbterror.ClassServer.NewStd(errno.ErrSecureTransportRequired)
errMultiStatementDisabled = dbterror.ClassServer.NewStd(errno.ErrMultiStatementDisabled)
errNewAbortingConnection = dbterror.ClassServer.NewStd(errno.ErrNewAbortingConnection)
errNotSupportedAuthMode = dbterror.ClassServer.NewStd(errno.ErrNotSupportedAuthMode)
errNetPacketTooLarge = dbterror.ClassServer.NewStd(errno.ErrNetPacketTooLarge)
errMustChangePassword = dbterror.ClassServer.NewStd(errno.ErrMustChangePassword)
errNotAllowedWithActiveCursor = dbterror.ClassServer.NewStd(errno.ErrNotAllowedWithActiveCursor)
)

// DefaultCapability is the capability of the server when it is created using the default configuration.
Expand Down
3 changes: 3 additions & 0 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1339,6 +1339,9 @@ type SessionVars struct {

// ProtectedTSList holds a list of timestamps that should delay GC.
ProtectedTSList protectedTSList

// ActiveCursorStmtID indicates the stmtID of the active cursor. If it's 0, there is no active cursor
ActiveCursorStmtID int
}

// planReplayerSessionFinishedTaskKeyLen is used to control the max size for the finished plan replayer task key in session
Expand Down