diff --git a/adapter.go b/adapter.go index c1ee251..f32f007 100755 --- a/adapter.go +++ b/adapter.go @@ -15,6 +15,7 @@ package gormadapter import ( + "context" "errors" "fmt" "runtime" @@ -251,18 +252,30 @@ func NewAdapterByDB(db *gorm.DB) (*Adapter, error) { } func TurnOffAutoMigrate(db *gorm.DB) { - *db = *db.Set(disableMigrateKey, false) + ctx := db.Statement.Context + if ctx == nil { + ctx = context.Background() + } + + ctx = context.WithValue(ctx, disableMigrateKey, false) + + *db = *db.WithContext(ctx) } func NewAdapterByDBWithCustomTable(db *gorm.DB, t interface{}, tableName ...string) (*Adapter, error) { - *db = *db.Set(customTableKey, t) + ctx := db.Statement.Context + if ctx == nil { + ctx = context.Background() + } + + ctx = context.WithValue(ctx, customTableKey, t) curTableName := defaultTableName if len(tableName) > 0 { curTableName = tableName[0] } - return NewAdapterByDBUseTableName(db, "", curTableName) + return NewAdapterByDBUseTableName(db.WithContext(ctx), "", curTableName) } func openDBConnection(driverName, dataSourceName string) (*gorm.DB, error) { @@ -366,14 +379,14 @@ func (a *Adapter) casbinRuleTable() func(db *gorm.DB) *gorm.DB { } func (a *Adapter) createTable() error { - disableMigrate, ok := a.db.Get(disableMigrateKey) - if ok && disableMigrate != nil { + disableMigrate := a.db.Statement.Context.Value(disableMigrateKey) + if disableMigrate != nil { return nil } - t, ok := a.db.Get(customTableKey) + t := a.db.Statement.Context.Value(customTableKey) - if ok && t != nil { + if t != nil { return a.db.AutoMigrate(t) } @@ -394,9 +407,8 @@ func (a *Adapter) createTable() error { } func (a *Adapter) dropTable() error { - t, ok := a.db.Get(customTableKey) - - if !ok || t == nil { + t := a.db.Statement.Context.Value(customTableKey) + if t == nil { return a.db.Migrator().DropTable(a.getTableInstance()) }