From d7455fa5b1aa38c6f2df7483263730dd10473d95 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 14 Feb 2016 13:34:32 +0800 Subject: [PATCH] Refactor DataTypeOf for mysql --- dialect_mysql.go | 107 ++++++++++++++++++++++++--------------------- dialect_sqlite3.go | 4 +- 2 files changed, 59 insertions(+), 52 deletions(-) diff --git a/dialect_mysql.go b/dialect_mysql.go index 15849abc4..22f8b88de 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -3,7 +3,7 @@ package gorm import ( "fmt" "reflect" - "strconv" + "strings" "time" ) @@ -15,60 +15,69 @@ func (mysql) Quote(key string) string { return fmt.Sprintf("`%s`", key) } +// Get Data Type for MySQL Dialect func (mysql) DataTypeOf(field *StructField) string { - var ( - size int - dataValue = reflect.Indirect(reflect.New(field.Struct.Type)) - tagSettings = field.TagSettings - ) + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) - if num, ok := tagSettings["SIZE"]; ok { - size, _ = strconv.Atoi(num) - } - - switch dataValue.Kind() { - case reflect.Bool: - return "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: - if _, ok := tagSettings["AUTO_INCREMENT"]; ok { - return "int AUTO_INCREMENT" - } - return "int" - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := tagSettings["AUTO_INCREMENT"]; ok { - return "int unsigned AUTO_INCREMENT" - } - return "int unsigned" - case reflect.Int64: - if _, ok := tagSettings["AUTO_INCREMENT"]; ok { - return "bigint AUTO_INCREMENT" - } - return "bigint" - case reflect.Uint64: - if _, ok := tagSettings["AUTO_INCREMENT"]; ok { - return "bigint unsigned AUTO_INCREMENT" - } - return "bigint unsigned" - case reflect.Float32, reflect.Float64: - return "double" - case reflect.String: - if size > 0 && size < 65532 { - return fmt.Sprintf("varchar(%d)", size) - } - return "longtext" - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - return "timestamp NULL" - } - default: - if _, ok := dataValue.Interface().([]byte); ok { + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "int AUTO_INCREMENT" + } else { + sqlType = "int" + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "int unsigned AUTO_INCREMENT" + } else { + sqlType = "int unsigned" + } + case reflect.Int64: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "bigint AUTO_INCREMENT" + } else { + sqlType = "bigint" + } + case reflect.Uint64: + if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + sqlType = "bigint unsigned AUTO_INCREMENT" + } else { + sqlType = "bigint unsigned" + } + case reflect.Float32, reflect.Float64: + sqlType = "double" + case reflect.String: if size > 0 && size < 65532 { - return fmt.Sprintf("varbinary(%d)", size) + sqlType = fmt.Sprintf("varchar(%d)", size) + } else { + sqlType = "longtext" + } + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + sqlType = "timestamp NULL" + } + default: + if _, ok := dataValue.Interface().([]byte); ok { + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("varbinary(%d)", size) + } else { + sqlType = "longblob" + } } - return "longblob" } } - panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String())) + + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String())) + } + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + return fmt.Sprintf("%v %v", sqlType, additionalType) } func (s mysql) currentDatabase(scope *Scope) (name string) { diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index 0bf2aa8ca..d5ffb78d8 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -13,9 +13,7 @@ type sqlite3 struct { // Get Data Type for Sqlite Dialect func (sqlite3) DataTypeOf(field *StructField) string { - var ( - dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) - ) + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) if sqlType == "" { switch dataValue.Kind() {