Skip to content

Commit

Permalink
Refactor dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Jan 18, 2016
1 parent 896ee53 commit e159ca1
Show file tree
Hide file tree
Showing 11 changed files with 123 additions and 137 deletions.
16 changes: 9 additions & 7 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@ import (

type Dialect interface {
BinVar(i int) string
SupportLastInsertId() bool
HasTop() bool
SqlTag(value reflect.Value, size int, autoIncrease bool) string
ReturningStr(tableName, key string) string
SelectFromDummyTable() string
Quote(key string) string
HasTable(scope *Scope, tableName string) bool
HasColumn(scope *Scope, tableName string, columnName string) bool
SqlTag(value reflect.Value, size int, autoIncrease bool) string

HasIndex(scope *Scope, tableName string, indexName string) bool
RemoveIndex(scope *Scope, indexName string)
HasTable(scope *Scope, tableName string) bool
HasColumn(scope *Scope, tableName string, columnName string) bool
CurrentDatabase(scope *Scope) string

ReturningStr(tableName, key string) string
LimitAndOffsetSQL(limit, offset int) string
SelectFromDummyTable() string
SupportLastInsertId() bool
}

func NewDialect(driver string) Dialect {
Expand Down
60 changes: 33 additions & 27 deletions dialect_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,8 @@ func (commonDialect) BinVar(i int) string {
return "$$" // ?
}

func (commonDialect) SupportLastInsertId() bool {
return true
}

func (commonDialect) HasTop() bool {
return false
func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}

func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
Expand Down Expand Up @@ -56,16 +52,17 @@ func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) st
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
}

func (commonDialect) ReturningStr(tableName, key string) string {
return ""
}

func (commonDialect) SelectFromDummyTable() string {
return ""
func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
var (
count int
databaseName = c.CurrentDatabase(scope)
)
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName)
return count > 0
}

func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error)
}

func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
Expand All @@ -86,19 +83,6 @@ func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName stri
return count > 0
}

func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
var (
count int
databaseName = c.CurrentDatabase(scope)
)
c.RawScanInt(scope, &count, "SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", databaseName, tableName, indexName)
return count > 0
}

func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())).Error)
}

// RawScanInt scans the first column of the first row into the `scan' int pointer.
// This function captures raw query errors and propagates them to the original scope.
func (commonDialect) RawScanInt(scope *Scope, scanPtr *int, query string, args ...interface{}) {
Expand All @@ -115,3 +99,25 @@ func (commonDialect) CurrentDatabase(scope *Scope) (name string) {
scope.Err(scope.NewDB().Raw("SELECT DATABASE()").Row().Scan(&name))
return
}

func (commonDialect) ReturningStr(tableName, key string) string {
return ""
}

func (commonDialect) LimitAndOffsetSQL(limit, offset int) (sql string) {
if limit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", limit)
}
if offset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", offset)
}
return
}

func (commonDialect) SelectFromDummyTable() string {
return ""
}

func (commonDialect) SupportLastInsertId() bool {
return true
}
33 changes: 23 additions & 10 deletions dialect_mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@ type mssql struct {
commonDialect
}

func (mssql) HasTop() bool {
return true
}

func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
Expand Down Expand Up @@ -50,6 +46,12 @@ func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
}

func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName)
return count > 0
}

func (s mssql) HasTable(scope *Scope, tableName string) bool {
var (
count int
Expand All @@ -68,13 +70,24 @@ func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool
return count > 0
}

func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName)
return count > 0
}

func (s mssql) CurrentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT DB_NAME() AS [Current Database]")
return
}

func (mssql) LimitAndOffsetSQL(limit, offset int) (sql string) {
if limit < 0 && offset < 0 {
return
}

if offset < 0 {
offset = 0
}

sql += fmt.Sprintf(" OFFSET %d ROWS", offset)

if limit >= 0 {
sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", limit)
}
return
}
14 changes: 7 additions & 7 deletions dialect_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ type mysql struct {
commonDialect
}

func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
}

func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
Expand Down Expand Up @@ -56,15 +60,11 @@ func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
}

func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
func (s mysql) CurrentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT DATABASE()")
return
}

func (mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}

func (s mysql) CurrentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT DATABASE()")
return
}
54 changes: 26 additions & 28 deletions dialect_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@ func (postgres) BinVar(i int) string {
return fmt.Sprintf("$%v", i)
}

func (postgres) SupportLastInsertId() bool {
return false
}

func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
Expand Down Expand Up @@ -62,23 +58,14 @@ func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
}

var byteType = reflect.TypeOf(uint8(0))

func isByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == byteType
}

func isUUID(value reflect.Value) bool {
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
return false
}
typename := value.Type().Name()
lower := strings.ToLower(typename)
return "uuid" == lower || "guid" == lower
func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName)
return count > 0
}

func (s postgres) ReturningStr(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
func (postgres) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
}

func (s postgres) HasTable(scope *Scope, tableName string) bool {
Expand All @@ -93,19 +80,17 @@ func (s postgres) HasColumn(scope *Scope, tableName string, columnName string) b
return count > 0
}

func (postgres) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
func (s postgres) CurrentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()")
return
}

func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName)
return count > 0
func (s postgres) ReturningStr(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
}

func (s postgres) CurrentDatabase(scope *Scope) (name string) {
s.RawScanString(scope, &name, "SELECT CURRENT_DATABASE()")
return
func (postgres) SupportLastInsertId() bool {
return false
}

var hstoreType = reflect.TypeOf(Hstore{})
Expand Down Expand Up @@ -152,3 +137,16 @@ func (h *Hstore) Scan(value interface{}) error {

return nil
}

func isByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
}

func isUUID(value reflect.Value) bool {
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
return false
}
typename := value.Type().Name()
lower := strings.ToLower(typename)
return "uuid" == lower || "guid" == lower
}
20 changes: 10 additions & 10 deletions dialect_sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,26 @@ func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
}

func (s sqlite3) HasTable(scope *Scope, tableName string) bool {
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName)
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
return count > 0
}

func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%, \"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%' OR sql LIKE '%%,%v %%');\n", columnName, columnName, columnName, columnName, columnName, columnName), tableName)
return count > 0
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
}

func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
func (s sqlite3) HasTable(scope *Scope, tableName string) bool {
var count int
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
s.RawScanInt(scope, &count, "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName)
return count > 0
}

func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
scope.Err(scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)).Error)
func (s sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%, \"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%' OR sql LIKE '%%,%v %%');\n", columnName, columnName, columnName, columnName, columnName, columnName), tableName)
return count > 0
}

func (sqlite3) CurrentDatabase(scope *Scope) (name string) {
Expand Down
8 changes: 4 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,12 @@ func (s *DB) Not(query interface{}, args ...interface{}) *DB {
return s.clone().search.Not(query, args...).db
}

func (s *DB) Limit(value interface{}) *DB {
return s.clone().search.Limit(value).db
func (s *DB) Limit(limit int) *DB {
return s.clone().search.Limit(limit).db
}

func (s *DB) Offset(value interface{}) *DB {
return s.clone().search.Offset(value).db
func (s *DB) Offset(offset int) *DB {
return s.clone().search.Offset(offset).db
}

func (s *DB) Order(value string, reorder ...bool) *DB {
Expand Down
2 changes: 1 addition & 1 deletion main_private.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ func (s *DB) clone() *DB {
}

if s.search == nil {
db.search = &search{}
db.search = &search{limit: -1, offset: -1}
} else {
db.search = s.search.clone()
}
Expand Down
2 changes: 1 addition & 1 deletion scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ func (scope *Scope) QuotedTableName() (name string) {
// CombinedConditionSql get combined condition sql
func (scope *Scope) CombinedConditionSql() string {
return scope.joinsSql() + scope.whereSql() + scope.groupSql() +
scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql()
scope.havingSql() + scope.orderSql() + scope.limitAndOffsetSql()
}

// FieldByName find gorm.Field with name and db name
Expand Down
Loading

0 comments on commit e159ca1

Please sign in to comment.