Skip to content

Commit

Permalink
Refactor DataTypeOf
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Jan 29, 2016
1 parent d92c5db commit 2dfd76d
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 66 deletions.
2 changes: 1 addition & 1 deletion dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ type Dialect interface {
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
Quote(key string) string
// DataTypeOf return data's sql type
DataTypeOf(value reflect.Value, size int, autoIncrease bool) string
DataTypeOf(value reflect.Value, tagSettings map[string]string) string

// HasIndex check has index or not
HasIndex(scope *Scope, tableName string, indexName string) bool
Expand Down
20 changes: 13 additions & 7 deletions dialect_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gorm
import (
"fmt"
"reflect"
"strconv"
"time"
)

Expand All @@ -16,17 +17,22 @@ func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}

func (commonDialect) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
func (commonDialect) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}

switch dataValue.Kind() {
case reflect.Bool:
return "BOOLEAN"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "INTEGER AUTO_INCREMENT"
}
return "INTEGER"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "BIGINT AUTO_INCREMENT"
}
return "BIGINT"
Expand All @@ -38,18 +44,18 @@ func (commonDialect) DataTypeOf(value reflect.Value, size int, autoIncrease bool
}
return "VARCHAR(65532)"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
if _, ok := dataValue.Interface().(time.Time); ok {
return "TIMESTAMP"
}
default:
if _, ok := value.Interface().([]byte); ok {
if _, ok := dataValue.Interface().([]byte); ok {
if size > 0 && size < 65532 {
return fmt.Sprintf("BINARY(%d)", size)
}
return "BINARY(65532)"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
}

func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
Expand Down
20 changes: 13 additions & 7 deletions dialect_mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,30 @@ package gorm
import (
"fmt"
"reflect"
"strconv"
"time"
)

type mssql struct {
commonDialect
}

func (mssql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
func (mssql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}

switch dataValue.Kind() {
case reflect.Bool:
return "bit"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "int IDENTITY(1,1)"
}
return "int"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "bigint IDENTITY(1,1)"
}
return "bigint"
Expand All @@ -32,18 +38,18 @@ func (mssql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string
}
return "text"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
if _, ok := dataValue.Interface().(time.Time); ok {
return "datetime2"
}
default:
if _, ok := value.Interface().([]byte); ok {
if _, ok := dataValue.Interface().([]byte); ok {
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
return "text"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String()))
}

func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
Expand Down
24 changes: 15 additions & 9 deletions dialect_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gorm
import (
"fmt"
"reflect"
"strconv"
"time"
)

Expand All @@ -14,27 +15,32 @@ func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
}

func (mysql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
func (mysql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}

switch dataValue.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "int AUTO_INCREMENT"
}
return "int"
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "int unsigned AUTO_INCREMENT"
}
return "int unsigned"
case reflect.Int64:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "bigint AUTO_INCREMENT"
}
return "bigint"
case reflect.Uint64:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "bigint unsigned AUTO_INCREMENT"
}
return "bigint unsigned"
Expand All @@ -46,18 +52,18 @@ func (mysql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string
}
return "longtext"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
if _, ok := dataValue.Interface().(time.Time); ok {
return "timestamp NULL"
}
default:
if _, ok := value.Interface().([]byte); ok {
if _, ok := dataValue.Interface().([]byte); ok {
if size > 0 && size < 65532 {
return fmt.Sprintf("varbinary(%d)", size)
}
return "longblob"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
}

func (s mysql) currentDatabase(scope *Scope) (name string) {
Expand Down
24 changes: 15 additions & 9 deletions dialect_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql/driver"
"fmt"
"reflect"
"strconv"
"strings"
"time"

Expand All @@ -19,17 +20,22 @@ func (postgres) BindVar(i int) string {
return fmt.Sprintf("$%v", i)
}

func (postgres) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
func (postgres) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}

switch dataValue.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "serial"
}
return "integer"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "bigserial"
}
return "bigint"
Expand All @@ -41,21 +47,21 @@ func (postgres) DataTypeOf(value reflect.Value, size int, autoIncrease bool) str
}
return "text"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
if _, ok := dataValue.Interface().(time.Time); ok {
return "timestamp with time zone"
}
case reflect.Map:
if value.Type() == hstoreType {
if dataValue.Type() == hstoreType {
return "hstore"
}
default:
if isByteArrayOrSlice(value) {
if isByteArrayOrSlice(dataValue) {
return "bytea"
} else if isUUID(value) {
} else if isUUID(dataValue) {
return "uuid"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String()))
}

func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
Expand Down
20 changes: 13 additions & 7 deletions dialect_sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,30 @@ package gorm
import (
"fmt"
"reflect"
"strconv"
"time"
)

type sqlite3 struct {
commonDialect
}

func (sqlite3) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
func (sqlite3) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}

switch dataValue.Kind() {
case reflect.Bool:
return "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "integer primary key autoincrement"
}
return "integer"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "integer primary key autoincrement"
}
return "bigint"
Expand All @@ -32,15 +38,15 @@ func (sqlite3) DataTypeOf(value reflect.Value, size int, autoIncrease bool) stri
}
return "text"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
if _, ok := dataValue.Interface().(time.Time); ok {
return "datetime"
}
default:
if _, ok := value.Interface().([]byte); ok {
if _, ok := dataValue.Interface().([]byte); ok {
return "blob"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
}

func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
Expand Down
17 changes: 1 addition & 16 deletions model_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"go/ast"
"reflect"
"strconv"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -541,21 +540,7 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
}

if sqlType == "" {
var size = 255

if value, ok := field.TagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(value)
}

v, autoIncrease := field.TagSettings["AUTO_INCREMENT"]
if field.IsPrimaryKey {
autoIncrease = true
}
if v == "FALSE" {
autoIncrease = false
}

sqlType = scope.Dialect().DataTypeOf(reflectValue, size, autoIncrease)
sqlType = scope.Dialect().DataTypeOf(reflectValue, field.TagSettings)
}

if strings.TrimSpace(additionalType) == "" {
Expand Down
12 changes: 2 additions & 10 deletions scope_private.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,23 +516,15 @@ func (scope *Scope) createJoinTable(field *StructField) {
for idx, fieldName := range relationship.ForeignFieldNames {
if field, ok := scope.Fields()[fieldName]; ok {
value := reflect.Indirect(reflect.New(field.Struct.Type))
primaryKeySqlType := field.TagSettings["TYPE"]
if primaryKeySqlType == "" {
primaryKeySqlType = scope.Dialect().DataTypeOf(value, 255, false)
}
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType)
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings))
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
}
}

for idx, fieldName := range relationship.AssociationForeignFieldNames {
if field, ok := toScope.Fields()[fieldName]; ok {
value := reflect.Indirect(reflect.New(field.Struct.Type))
primaryKeySqlType := field.TagSettings["TYPE"]
if primaryKeySqlType == "" {
primaryKeySqlType = scope.Dialect().DataTypeOf(value, 255, false)
}
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType)
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings))
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
}
}
Expand Down

0 comments on commit 2dfd76d

Please sign in to comment.