Skip to content

Commit

Permalink
Refactor dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Feb 15, 2016
1 parent 6546ec3 commit 4e8370e
Show file tree
Hide file tree
Showing 11 changed files with 94 additions and 104 deletions.
2 changes: 1 addition & 1 deletion customize_column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func TestCustomizeColumn(t *testing.T) {
DB.AutoMigrate(&CustomizeColumn{})

scope := DB.NewScope(&CustomizeColumn{})
if !scope.Dialect().HasColumn(scope, scope.TableName(), col) {
if !scope.Dialect().HasColumn(scope.TableName(), col) {
t.Errorf("CustomizeColumn should have column %s", col)
}

Expand Down
3 changes: 1 addition & 2 deletions ddl_errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ func TestDdlErrors(t *testing.T) {
}
}()

DB.HasTable("foobarbaz")
if DB.Error == nil {
if err := DB.Find(&User{}).Error; err == nil {
t.Errorf("Expected operation on closed db to produce an error, but err was nil")
}
}
22 changes: 15 additions & 7 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import (

// Dialect interface contains behaviors that differ across SQL database
type Dialect interface {
// SetDB set db for dialect
SetDB(db *sql.DB)

// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
BindVar(i int) string
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
Expand All @@ -18,13 +21,13 @@ type Dialect interface {
DataTypeOf(field *StructField) string

// HasIndex check has index or not
HasIndex(scope *Scope, tableName string, indexName string) bool
HasIndex(tableName string, indexName string) bool
// RemoveIndex remove index
RemoveIndex(scope *Scope, indexName string)
RemoveIndex(tableName string, indexName string) error
// HasTable check has table or not
HasTable(scope *Scope, tableName string) bool
HasTable(tableName string) bool
// HasColumn check has column or not
HasColumn(scope *Scope, tableName string, columnName string) bool
HasColumn(tableName string, columnName string) bool

// LimitAndOffsetSQL return generate SQL with limit and offset, as mssql has special case
LimitAndOffsetSQL(limit, offset int) string
Expand All @@ -36,12 +39,17 @@ type Dialect interface {

var dialectsMap = map[string]Dialect{}

func newDialect(name string) Dialect {
if dialect, ok := dialectsMap[name]; ok {
func newDialect(name string, db *sql.DB) Dialect {
if value, ok := dialectsMap[name]; ok {
dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
dialect.SetDB(db)
return dialect
}

fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
return &commonDialect{}
commontDialect := &commonDialect{}
commontDialect.SetDB(db)
return commontDialect
}

// RegisterDialect register new dialect
Expand Down
57 changes: 22 additions & 35 deletions dialect_common.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
package gorm

import (
"database/sql"
"fmt"
"reflect"
"strings"
"time"
)

type commonDialect struct{}
type commonDialect struct {
db *sql.DB
}

func init() {
RegisterDialect("common", &commonDialect{})
}

func (s *commonDialect) SetDB(db *sql.DB) {
s.db = db
}

func (commonDialect) BindVar(i int) string {
return "$$" // ?
}
Expand Down Expand Up @@ -73,51 +80,31 @@ func (commonDialect) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
var (
count int
databaseName = c.currentDatabase(scope)
)
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName)
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.currentDatabase(), tableName, indexName).Scan(&count)
return count > 0
}

func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error)
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
return err
}

func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
var (
count int
databaseName = c.currentDatabase(scope)
)
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", databaseName, tableName)
func (s commonDialect) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.currentDatabase(), tableName).Scan(&count)
return count > 0
}

func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
var (
count int
databaseName = c.currentDatabase(scope)
)
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count)
return count > 0
}

// RawScanInt scans the first column of the first row into the `scan' int pointer.
// This function captures raw query errors and propagates them to the original scope.
func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) {
scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
}

// RawScanString scans the first column of the first row into the `scan' string pointer.
// This function captures raw query errors and propagates them to the original scope.
func (commonDialect) RawScanString(scope *Scope, scanPtr *string, query string, args ...interface{}) {
scope.Err(scope.NewDB().Raw(query, args...).Row().Scan(scanPtr))
}

func (commonDialect) currentDatabase(scope *Scope) (name string) {
scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
func (s commonDialect) currentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
}

Expand Down
31 changes: 15 additions & 16 deletions dialect_mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,32 +67,31 @@ func (mssql) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
func (s mssql) HasIndex(tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName)
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
return count > 0
}

func (s mssql) HasTable(scope *Scope, tableName string) bool {
var (
count int
databaseName = s.currentDatabase(scope)
)
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, databaseName)
func (s mssql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}

func (s mssql) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.currentDatabase()).Scan(&count)
return count > 0
}

func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
var (
count int
databaseName = s.currentDatabase(scope)
)
s.RawScanInt(scope, &count, "SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", databaseName, tableName, columnName)
func (s mssql) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count)
return count > 0
}

func (s mssql) currentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]")
func (s mssql) currentDatabase() (name string) {
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
return
}

Expand Down
9 changes: 7 additions & 2 deletions dialect_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,13 @@ func (mysql) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func (s mysql) currentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT DATABASE()")
func (s mysql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}

func (s mysql) currentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
}

Expand Down
20 changes: 8 additions & 12 deletions dialect_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,30 +77,26 @@ func (postgres) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
func (s postgres) HasIndex(tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName)
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count)
return count > 0
}

func (postgres) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
}

func (s postgres) HasTable(scope *Scope, tableName string) bool {
func (s postgres) HasTable(tableName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count)
return count > 0
}

func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) bool {
func (s postgres) HasColumn(tableName string, columnName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count)
return count > 0
}

func (s postgres) currentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()")
func (s postgres) currentDatabase() (name string) {
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
return
}

Expand Down
20 changes: 8 additions & 12 deletions dialect_sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,29 +65,25 @@ func (sqlite3) DataTypeOf(field *StructField) string {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
return count > 0
}

func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
}

func (s sqlite3) HasTable(scope *Scope, tableName string) bool {
func (s sqlite3) HasTable(tableName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName)
s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
return count > 0
}

func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
var count int
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%, \"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%' OR sql LIKE '%%,%v %%');\n", columnName, columnName, columnName, columnName, columnName, columnName), tableName)
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%, \"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%' OR sql LIKE '%%,%v %%');\n", columnName, columnName, columnName, columnName, columnName, columnName), tableName).Scan(&count)
return count > 0
}

func (sqlite3) currentDatabase(scope *Scope) (name string) {
func (s sqlite3) currentDatabase() (name string) {
var (
ifaces = make([]interface{}, 3)
pointers = make([]*string, 3)
Expand All @@ -96,7 +92,7 @@ func (sqlite3) currentDatabase(scope *Scope) (name string) {
for i = 0; i < 3; i++ {
ifaces[i] = &pointers[i]
}
if err := scope.NewDB().Raw("PRAGMA database_list").Row().Scan(ifaces...); scope.Err(err) != nil {
if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
return
}
if pointers[1] != nil {
Expand Down
6 changes: 3 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
}

db = DB{
dialect: newDialect(dialect),
dialect: newDialect(dialect, dbSql.(*sql.DB)),
logger: defaultLogger,
callbacks: defaultCallback,
source: source,
Expand Down Expand Up @@ -430,7 +430,7 @@ func (s *DB) HasTable(value interface{}) bool {
tableName = scope.TableName()
}

has := scope.Dialect().HasTable(scope, tableName)
has := scope.Dialect().HasTable(tableName)
s.AddError(scope.db.Error)
return has
}
Expand Down Expand Up @@ -531,7 +531,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
handler.Setup(field.Relationship, many2many, source, destination)
field.Relationship.JoinTableHandler = handler
if table := handler.Table(s); scope.Dialect().HasTable(scope, table) {
if table := handler.Table(s); scope.Dialect().HasTable(table) {
s.Table(table).AutoMigrate(handler)
}
}
Expand Down
Loading

0 comments on commit 4e8370e

Please sign in to comment.