Skip to content

Commit

Permalink
fix: db can't reuse
Browse files Browse the repository at this point in the history
Signed-off-by: tangyang9464 <tangyang9464@163.com>
  • Loading branch information
tangyang9464 committed Jul 19, 2022
1 parent 3d2fa84 commit 368eb78
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package gormadapter

import (
"context"
"errors"
"fmt"
"runtime"
Expand Down Expand Up @@ -251,18 +252,30 @@ func NewAdapterByDB(db *gorm.DB) (*Adapter, error) {
}

func TurnOffAutoMigrate(db *gorm.DB) {
*db = *db.Set(disableMigrateKey, false)
ctx := db.Statement.Context
if ctx == nil {
ctx = context.Background()
}

ctx = context.WithValue(ctx, disableMigrateKey, false)

*db = *db.WithContext(ctx)
}

func NewAdapterByDBWithCustomTable(db *gorm.DB, t interface{}, tableName ...string) (*Adapter, error) {
*db = *db.Set(customTableKey, t)
ctx := db.Statement.Context
if ctx == nil {
ctx = context.Background()
}

ctx = context.WithValue(ctx, customTableKey, t)

curTableName := defaultTableName
if len(tableName) > 0 {
curTableName = tableName[0]
}

return NewAdapterByDBUseTableName(db, "", curTableName)
return NewAdapterByDBUseTableName(db.WithContext(ctx), "", curTableName)
}

func openDBConnection(driverName, dataSourceName string) (*gorm.DB, error) {
Expand Down Expand Up @@ -366,14 +379,14 @@ func (a *Adapter) casbinRuleTable() func(db *gorm.DB) *gorm.DB {
}

func (a *Adapter) createTable() error {
disableMigrate, ok := a.db.Get(disableMigrateKey)
if ok && disableMigrate != nil {
disableMigrate := a.db.Statement.Context.Value(disableMigrateKey)
if disableMigrate != nil {
return nil
}

t, ok := a.db.Get(customTableKey)
t := a.db.Statement.Context.Value(customTableKey)

if ok && t != nil {
if t != nil {
return a.db.AutoMigrate(t)
}

Expand All @@ -394,9 +407,8 @@ func (a *Adapter) createTable() error {
}

func (a *Adapter) dropTable() error {
t, ok := a.db.Get(customTableKey)

if !ok || t == nil {
t := a.db.Statement.Context.Value(customTableKey)
if t == nil {
return a.db.Migrator().DropTable(a.getTableInstance())
}

Expand Down

0 comments on commit 368eb78

Please sign in to comment.