Skip to content

Commit

Permalink
Allow customize data type via ParseFieldStructForDialect
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Jan 15, 2017
1 parent c62e9bc commit a3b8b33
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 14 deletions.
18 changes: 14 additions & 4 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,27 @@ func RegisterDialect(name string, dialect Dialect) {
dialectsMap[name] = dialect
}

// ParseFieldStructForDialect parse field struct for dialect
func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
// ParseFieldStructForDialect get field's sql data type
var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
// Get redirected field type
var reflectType = field.Struct.Type
var (
reflectType = field.Struct.Type
dataType = field.TagSettings["TYPE"]
)

for reflectType.Kind() == reflect.Ptr {
reflectType = reflectType.Elem()
}

// Get redirected field value
fieldValue = reflect.Indirect(reflect.New(reflectType))

if gormDataType, ok := fieldValue.Interface().(interface {
GormDataType(Dialect) string
}); ok {
dataType = gormDataType.GormDataType(dialect)
}

// Get scanner's real value
var getScannerValue func(reflect.Value)
getScannerValue = func(value reflect.Value) {
Expand All @@ -102,5 +112,5 @@ func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, s
additionalType = additionalType + " DEFAULT " + value
}

return fieldValue, field.TagSettings["TYPE"], size, strings.TrimSpace(additionalType)
return fieldValue, dataType, size, strings.TrimSpace(additionalType)
}
4 changes: 2 additions & 2 deletions dialect_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}

func (commonDialect) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
func (s *commonDialect) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)

if sqlType == "" {
switch dataValue.Kind() {
Expand Down
4 changes: 2 additions & 2 deletions dialect_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ func (mysql) Quote(key string) string {
}

// Get Data Type for MySQL Dialect
func (mysql) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
func (s *mysql) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)

// MySQL allows only one auto increment column per table, and it must
// be a KEY column.
Expand Down
4 changes: 2 additions & 2 deletions dialect_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ func (postgres) BindVar(i int) string {
return fmt.Sprintf("$%v", i)
}

func (postgres) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
func (s *postgres) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)

if sqlType == "" {
switch dataValue.Kind() {
Expand Down
4 changes: 2 additions & 2 deletions dialect_sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ func (sqlite3) GetName() string {
}

// Get Data Type for Sqlite Dialect
func (sqlite3) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
func (s *sqlite3) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)

if sqlType == "" {
switch dataValue.Kind() {
Expand Down
4 changes: 2 additions & 2 deletions dialects/mssql/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ func (mssql) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}

func (mssql) DataTypeOf(field *gorm.StructField) string {
var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field)
func (s *mssql) DataTypeOf(field *gorm.StructField) string {
var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s)

if sqlType == "" {
switch dataValue.Kind() {
Expand Down

0 comments on commit a3b8b33

Please sign in to comment.