Skip to content
26 changes: 16 additions & 10 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,17 @@ func (e *Engine) QueryNodeWithBindings(
}
}

_, err = e.beginTransaction(ctx, parsed)
// Before we begin a transaction, we need to know if the database being operated on is not the one
// currently selected
transactionDatabase := analyzer.GetTransactionDatabase(ctx, parsed)

// This validates that we have valid working set and branch regardless of autocommit status
err = ctx.Session.ValidateSession(ctx, transactionDatabase)
if err != nil {
return nil, nil, err
}

err = e.beginTransaction(ctx, transactionDatabase)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -449,11 +459,7 @@ func init() {
}
}

func (e *Engine) beginTransaction(ctx *sql.Context, parsed sql.Node) (string, error) {
// Before we begin a transaction, we need to know if the database being operated on is not the one
// currently selected
transactionDatabase := analyzer.GetTransactionDatabase(ctx, parsed)

func (e *Engine) beginTransaction(ctx *sql.Context, transactionDatabase string) error {
// TODO: this won't work with transactions that cross database boundaries, we need to detect that and error out
beginNewTransaction := ctx.GetTransaction() == nil || plan.ReadCommitted(ctx)
if beginNewTransaction {
Expand All @@ -462,9 +468,9 @@ func (e *Engine) beginTransaction(ctx *sql.Context, parsed sql.Node) (string, er
database, err := e.Analyzer.Catalog.Database(ctx, transactionDatabase)
// if the database doesn't exist, just don't start a transaction on it, let other layers complain
if sql.ErrDatabaseNotFound.Is(err) || sql.ErrDatabaseAccessDeniedForUser.Is(err) {
return "", nil
return nil
} else if err != nil {
return "", err
return err
}

if privilegedDatabase, ok := database.(mysql_db.PrivilegedDatabase); ok {
Expand All @@ -474,15 +480,15 @@ func (e *Engine) beginTransaction(ctx *sql.Context, parsed sql.Node) (string, er
if ok {
tx, err := tdb.StartTransaction(ctx, sql.ReadWrite)
if err != nil {
return "", err
return err
}

ctx.SetTransaction(tx)
}
}
}

return transactionDatabase, nil
return nil
}

func (e *Engine) Close() error {
Expand Down
19 changes: 19 additions & 0 deletions enginetest/enginetests.go
Original file line number Diff line number Diff line change
Expand Up @@ -5716,6 +5716,25 @@ func TestPersist(t *testing.T, harness Harness, newPersistableSess func(ctx *sql
}
}

func TestValidateSession(t *testing.T, harness Harness, newSessFunc func(ctx *sql.Context) sql.PersistableSession, count *int) {
queries := []string{"SHOW TABLES;", "SELECT i from mytable;"}
harness.Setup(setup.MydbData, setup.MytableData)
e := mustNewEngine(t, harness)
defer e.Close()

sql.InitSystemVariables()
ctx := NewContext(harness)
ctx.Session = newSessFunc(ctx)

for _, q := range queries {
t.Run("test running queries to check callbacks on ValidateSession()", func(t *testing.T) {
RunQueryWithContext(t, e, harness, ctx, q)
})
}
// This asserts that ValidateSession() method was called once for every statement.
require.Equal(t, len(queries), *count)
}

func TestPrepared(t *testing.T, harness Harness) {
qtests := []queries.QueryTest{
{
Expand Down
13 changes: 13 additions & 0 deletions enginetest/memory_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,19 @@ func TestPersist(t *testing.T) {
enginetest.TestPersist(t, enginetest.NewDefaultMemoryHarness(), newSess)
}

func TestValidateSession(t *testing.T) {
count := 0
incrementValidateCb := func() {
count++
}

newSess := func(ctx *sql.Context) sql.PersistableSession {
sess := memory.NewInMemoryPersistedSessionWithValidationCallback(ctx.Session, incrementValidateCb)
return sess
}
enginetest.TestValidateSession(t, enginetest.NewDefaultMemoryHarness(), newSess, &count)
}

func TestPrepared(t *testing.T) {
enginetest.TestPrepared(t, enginetest.NewDefaultMemoryHarness())
}
Expand Down
16 changes: 15 additions & 1 deletion memory/persisted_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,19 @@ type GlobalsMap = map[string]interface{}
type InMemoryPersistedSession struct {
sql.Session
persistedGlobals GlobalsMap
validateCallback func()
}

// NewInMemoryPersistedSession is a sql.PersistableSession that writes global variables to an im-memory map
func NewInMemoryPersistedSession(sess sql.Session, persistedGlobals GlobalsMap) *InMemoryPersistedSession {
return &InMemoryPersistedSession{Session: sess, persistedGlobals: persistedGlobals}
}

// NewInMemoryPersistedSessionWithValidationCallback is a sql.PersistableSession that defines increment function to count number of calls on ValidateSession().
func NewInMemoryPersistedSessionWithValidationCallback(sess sql.Session, validateCb func()) *InMemoryPersistedSession {
return &InMemoryPersistedSession{Session: sess, validateCallback: validateCb}
}

// PersistGlobal implements sql.PersistableSession
func (s *InMemoryPersistedSession) PersistGlobal(sysVarName string, value interface{}) error {
sysVar, _, ok := sql.SystemVariables.GetGlobal(sysVarName)
Expand Down Expand Up @@ -56,7 +62,15 @@ func (s *InMemoryPersistedSession) RemoveAllPersistedGlobals() error {
return nil
}

// RemoveAllPersistedGlobals implements sql.PersistableSession
// GetPersistedValue implements sql.PersistableSession
func (s *InMemoryPersistedSession) GetPersistedValue(k string) (interface{}, error) {
return s.persistedGlobals[k], nil
}

// ValidateSession counts the number of times this method is called.
func (s *InMemoryPersistedSession) ValidateSession(ctx *sql.Context, dbName string) error {
if s.validateCallback != nil {
s.validateCallback()
}
return s.Session.ValidateSession(ctx, dbName)
}
15 changes: 11 additions & 4 deletions sql/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,18 +139,20 @@ type Session interface {
GetCharacterSetResults() CharacterSetID
// GetCollation returns the collation for this session (defined by the system variable `collation_connection`).
GetCollation() CollationID
// ValidateSession provides integrators a chance to do any custom validation of this session before any query is executed in it. For example, Dolt uses this hook to validate that the session's working set is valid.
ValidateSession(ctx *Context, dbName string) error
}

// PersistableSession supports serializing/deserializing global system variables/
type PersistableSession interface {
Session
// PersistGlobal writes to the persisted global system variables file
PersistGlobal(sysVarName string, value interface{}) error
// RemovePersisted deletes a variable from the persisted globals file
// RemovePersistedGlobal deletes a variable from the persisted globals file
RemovePersistedGlobal(sysVarName string) error
// RemoveAllPersisted clears the contents of the persisted globals file
// RemoveAllPersistedGlobals clears the contents of the persisted globals file
RemoveAllPersistedGlobals() error
// GetPersistedValue
// GetPersistedValue returns persisted value for a global system variable
GetPersistedValue(k string) (interface{}, error)
}

Expand Down Expand Up @@ -224,7 +226,7 @@ func (s *BaseSession) Address() string { return s.addr }
// Client returns session's client information.
func (s *BaseSession) Client() Client { return s.client }

// WithClient implements Session.
// SetClient implements the Session interface.
func (s *BaseSession) SetClient(c Client) {
s.client = c
return
Expand Down Expand Up @@ -364,6 +366,11 @@ func (s *BaseSession) GetCollation() CollationID {
return collation
}

// ValidateSession provides integrators a chance to do any custom validation of this session before any query is executed in it.
func (s *BaseSession) ValidateSession(ctx *Context, dbName string) error {
return nil
}

// GetCurrentDatabase gets the current database for this session
func (s *BaseSession) GetCurrentDatabase() string {
s.mu.RLock()
Expand Down