Skip to content

Commit

Permalink
Refactor DataTypeOf for sqlite
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Feb 14, 2016
1 parent dc435d2 commit 552d9bf
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 91 deletions.
42 changes: 41 additions & 1 deletion dialect.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package gorm

import (
"database/sql"
"fmt"
"reflect"
"strconv"
"strings"
)

// Dialect interface contains behaviors that differ across SQL database
Expand All @@ -12,7 +15,7 @@ type Dialect interface {
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
Quote(key string) string
// DataTypeOf return data's sql type
DataTypeOf(value reflect.Value, tagSettings map[string]string) string
DataTypeOf(field *StructField) string

// HasIndex check has index or not
HasIndex(scope *Scope, tableName string, indexName string) bool
Expand Down Expand Up @@ -48,3 +51,40 @@ func NewDialect(driver string) Dialect {
}
return d
}

// ParseFieldStructForDialect parse field struct for dialect
func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
// Get redirected field type
var reflectType = field.Struct.Type
for reflectType.Kind() == reflect.Ptr {
reflectType = reflectType.Elem()
}

// Get redirected field value
fieldValue = reflect.Indirect(reflect.New(reflectType))

// Get scanner's real value
var getScannerValue func(reflect.Value)
getScannerValue = func(value reflect.Value) {
fieldValue = value
if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct {
getScannerValue(fieldValue.Field(0))
}
}
getScannerValue(fieldValue)

// Default Size
if num, ok := field.TagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
} else {
size = 255
}

// Default type from tag setting
additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
if value, ok := field.TagSettings["DEFAULT"]; ok {
additionalType = additionalType + " DEFAULT " + value
}

return fieldValue, field.TagSettings["TYPE"], size, strings.TrimSpace(additionalType)
}
9 changes: 7 additions & 2 deletions dialect_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@ func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}

func (commonDialect) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
func (commonDialect) DataTypeOf(field *StructField) string {
var (
size int
dataValue = reflect.Indirect(reflect.New(field.Struct.Type))
tagSettings = field.TagSettings
)

if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}
Expand Down
9 changes: 7 additions & 2 deletions dialect_mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@ type mssql struct {
commonDialect
}

func (mssql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
func (mssql) DataTypeOf(field *StructField) string {
var (
size int
dataValue = reflect.Indirect(reflect.New(field.Struct.Type))
tagSettings = field.TagSettings
)

if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}
Expand Down
9 changes: 7 additions & 2 deletions dialect_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@ func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
}

func (mysql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
func (mysql) DataTypeOf(field *StructField) string {
var (
size int
dataValue = reflect.Indirect(reflect.New(field.Struct.Type))
tagSettings = field.TagSettings
)

if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}
Expand Down
9 changes: 7 additions & 2 deletions dialect_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@ func (postgres) BindVar(i int) string {
return fmt.Sprintf("$%v", i)
}

func (postgres) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
func (postgres) DataTypeOf(field *StructField) string {
var (
size int
dataValue = reflect.Indirect(reflect.New(field.Struct.Type))
tagSettings = field.TagSettings
)

if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}
Expand Down
81 changes: 47 additions & 34 deletions dialect_sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,50 +3,63 @@ package gorm
import (
"fmt"
"reflect"
"strconv"
"strings"
"time"
)

type sqlite3 struct {
commonDialect
}

func (sqlite3) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}
// Get Data Type for Sqlite Dialect
func (sqlite3) DataTypeOf(field *StructField) string {
var (
dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
)

switch dataValue.Kind() {
case reflect.Bool:
return "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "integer primary key autoincrement"
}
return "integer"
case reflect.Int64, reflect.Uint64:
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "integer primary key autoincrement"
}
return "bigint"
case reflect.Float32, reflect.Float64:
return "real"
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
return "text"
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
return "datetime"
}
default:
if _, ok := dataValue.Interface().([]byte); ok {
return "blob"
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if field.IsPrimaryKey {
sqlType = "integer primary key autoincrement"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint64:
if field.IsPrimaryKey {
sqlType = "integer primary key autoincrement"
} else {
sqlType = "bigint"
}
case reflect.Float32, reflect.Float64:
sqlType = "real"
case reflect.String:
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varchar(%d)", size)
} else {
sqlType = "text"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "datetime"
}
default:
if _, ok := dataValue.Interface().([]byte); ok {
sqlType = "blob"
}
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))

if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
}

if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
Expand Down
39 changes: 0 additions & 39 deletions model_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package gorm
import (
"database/sql"
"errors"
"fmt"
"go/ast"
"reflect"
"strings"
Expand Down Expand Up @@ -511,44 +510,6 @@ func (scope *Scope) GetStructFields() (fields []*StructField) {
return scope.GetModelStruct().StructFields
}

func (scope *Scope) generateSqlTag(field *StructField) string {
var sqlType string
structType := field.Struct.Type
if structType.Kind() == reflect.Ptr {
structType = structType.Elem()
}
reflectValue := reflect.Indirect(reflect.New(structType))

if value, ok := field.TagSettings["TYPE"]; ok {
sqlType = value
}

additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
if value, ok := field.TagSettings["DEFAULT"]; ok {
additionalType = additionalType + " DEFAULT " + value
}

if field.IsScanner {
var getScannerValue func(reflect.Value)
getScannerValue = func(value reflect.Value) {
reflectValue = value
if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct {
getScannerValue(reflectValue.Field(0))
}
}
getScannerValue(reflectValue)
}

if sqlType == "" {
sqlType = scope.Dialect().DataTypeOf(reflectValue, field.TagSettings)
}

if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}

func parseTagSetting(tags reflect.StructTag) map[string]string {
setting := map[string]string{}
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
Expand Down
22 changes: 13 additions & 9 deletions scope_private.go
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ func (scope *Scope) getTableOptions() string {
return tableOptions.(string)
}

func (scope *Scope) createJoinTable(field *StructField) {
func (scope *Scope) createJoinTable(field *Field) {
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
joinTableHandler := relationship.JoinTableHandler
joinTable := joinTableHandler.Table(scope.db)
Expand All @@ -521,16 +521,20 @@ func (scope *Scope) createJoinTable(field *StructField) {
var sqlTypes, primaryKeys []string
for idx, fieldName := range relationship.ForeignFieldNames {
if field, ok := scope.Fields()[fieldName]; ok {
value := reflect.Indirect(reflect.New(field.Struct.Type))
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings))
foreignKeyStruct := field.StructField.clone()
foreignKeyStruct.IsPrimaryKey = false
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
}
}

for idx, fieldName := range relationship.AssociationForeignFieldNames {
if field, ok := toScope.Fields()[fieldName]; ok {
value := reflect.Indirect(reflect.New(field.Struct.Type))
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings))
foreignKeyStruct := field.StructField.clone()
foreignKeyStruct.IsPrimaryKey = false
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
}
}
Expand All @@ -545,9 +549,9 @@ func (scope *Scope) createTable() *Scope {
var tags []string
var primaryKeys []string
var primaryKeyInColumnType = false
for _, field := range scope.GetStructFields() {
for _, field := range scope.Fields() {
if field.IsNormal {
sqlTag := scope.generateSqlTag(field)
sqlTag := scope.Dialect().DataTypeOf(field.StructField)

// Check if the primary key constraint was specified as
// part of the column type. If so, we can only support
Expand Down Expand Up @@ -632,10 +636,10 @@ func (scope *Scope) autoMigrate() *Scope {
if !scope.Dialect().HasTable(scope, tableName) {
scope.createTable()
} else {
for _, field := range scope.GetStructFields() {
for _, field := range scope.Fields() {
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
if field.IsNormal {
sqlTag := scope.generateSqlTag(field)
sqlTag := scope.Dialect().DataTypeOf(field.StructField)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
}
}
Expand Down

0 comments on commit 552d9bf

Please sign in to comment.