Skip to content

Commit

Permalink
Review and Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Feb 17, 2015
1 parent 38cbff9 commit 0b32041
Show file tree
Hide file tree
Showing 18 changed files with 262 additions and 379 deletions.
15 changes: 7 additions & 8 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@ import (
)

type Association struct {
Scope *Scope
PrimaryKey interface{}
PrimaryType interface{}
Column string
Error error
Field *Field
Scope *Scope
PrimaryKey interface{}
Column string
Error error
Field *Field
}

func (association *Association) setErr(err error) *Association {
Expand Down Expand Up @@ -158,11 +157,11 @@ func (association *Association) Count() int {
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
countScope := scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, association.PrimaryKey)
if relationship.ForeignType != "" {
countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignType))), association.PrimaryType)
countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(ToSnake(relationship.ForeignType))), scope.TableName())
}
countScope.Count(&count)
} else if relationship.Kind == "belongs_to" {
if v, err := scope.FieldValueByName(association.Column); err == nil {
if v, ok := scope.FieldByName(association.Column); ok {
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
scope.db.Model("").Table(newScope.QuotedTableName()).Where(whereSql, v).Count(&count)
}
Expand Down
7 changes: 3 additions & 4 deletions callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@ func Create(scope *Scope) {
var sqls, columns []string
for _, field := range scope.Fields() {
if (field.IsNormal && !field.IsPrimaryKey) || (field.IsPrimaryKey && !field.IsBlank) {
if field.DefaultValue != nil && field.IsBlank {
continue
if !field.IsBlank || field.DefaultValue == nil {
columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
}
columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
}
}

Expand Down
13 changes: 5 additions & 8 deletions callback_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,16 @@ func Update(scope *Scope) {
if !scope.HasError() {
var sqls []string

updateAttrs, ok := scope.InstanceGet("gorm:update_attrs")
if ok {
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
for key, value := range updateAttrs.(map[string]interface{}) {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
}
} else {
for _, field := range scope.Fields() {
if !field.IsPrimaryKey && field.IsNormal && !field.IsIgnored {
if field.DefaultValue != nil && field.IsBlank {
continue
if !field.IsPrimaryKey && field.IsNormal {
if !field.IsBlank || field.DefaultValue == nil {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
}
}
Expand All @@ -68,8 +66,7 @@ func Update(scope *Scope) {
}

func AfterUpdate(scope *Scope) {
_, ok := scope.Get("gorm:update_column")
if !ok {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.CallMethod("AfterUpdate")
scope.CallMethod("AfterSave")
}
Expand Down
3 changes: 2 additions & 1 deletion common_dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"reflect"
"strings"
"time"
)

type commonDialect struct{}
Expand Down Expand Up @@ -36,7 +37,7 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int) string {
}
return "VARCHAR(65532)"
case reflect.Struct:
if value.Type() == timeType {
if _, ok := value.Interface().(time.Time); ok {
return "TIMESTAMP"
}
default:
Expand Down
3 changes: 0 additions & 3 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@ package gorm
import (
"fmt"
"reflect"
"time"
)

var timeType = reflect.TypeOf(time.Time{})

type Dialect interface {
BinVar(i int) string
SupportLastInsertId() bool
Expand Down
47 changes: 30 additions & 17 deletions field.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ type Field struct {
Field reflect.Value
}

func (field *Field) Set(value interface{}) (err error) {
func (field *Field) Set(value interface{}) error {
if !field.Field.IsValid() {
return errors.New("field value not valid")
}
Expand All @@ -26,16 +26,26 @@ func (field *Field) Set(value interface{}) (err error) {
}

if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
scanner.Scan(value)
} else if reflect.TypeOf(value).ConvertibleTo(field.Field.Type()) {
field.Field.Set(reflect.ValueOf(value).Convert(field.Field.Type()))
if v, ok := value.(reflect.Value); ok {
scanner.Scan(v.Interface())
} else {
scanner.Scan(value)
}
} else {
return errors.New("could not convert argument")
reflectValue, ok := value.(reflect.Value)
if !ok {
reflectValue = reflect.ValueOf(value)
}

if reflectValue.Type().ConvertibleTo(field.Field.Type()) {
field.Field.Set(reflectValue.Convert(field.Field.Type()))
} else {
return errors.New("could not convert argument")
}
}

field.IsBlank = isBlank(field.Field)

return
return nil
}

// Fields get value's fields
Expand All @@ -44,24 +54,27 @@ func (scope *Scope) Fields() map[string]*Field {
fields := map[string]*Field{}
structFields := scope.GetStructFields()

indirectValue := scope.IndirectValue()
isStruct := indirectValue.Kind() == reflect.Struct
for _, structField := range structFields {
fields[structField.DBName] = scope.getField(structField)
if isStruct {
fields[structField.DBName] = getField(indirectValue, structField)
} else {
fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
}
}

scope.fields = fields
}
return scope.fields
}

func (scope *Scope) getField(structField *StructField) *Field {
field := Field{StructField: structField}
indirectValue := scope.IndirectValue()
if len(structField.Names) > 0 && indirectValue.Kind() == reflect.Struct {
for _, name := range structField.Names {
indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
}
field.Field = indirectValue
func getField(indirectValue reflect.Value, structField *StructField) *Field {
field := &Field{StructField: structField}
for _, name := range structField.Names {
indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
}
field.Field = indirectValue
field.IsBlank = isBlank(indirectValue)
return &field
return field
}
14 changes: 7 additions & 7 deletions logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,21 @@ var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
// Format log
var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)

func (logger Logger) Print(v ...interface{}) {
if len(v) > 1 {
level := v[0]
func (logger Logger) Print(values ...interface{}) {
if len(values) > 1 {
level := values[0]
currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m"
source := fmt.Sprintf("\033[35m(%v)\033[0m", v[1])
source := fmt.Sprintf("\033[35m(%v)\033[0m", values[1])
messages := []interface{}{source, currentTime}

if level == "sql" {
// duration
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(v[2].(time.Duration).Nanoseconds()/1e4)/100.0))
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
// sql
messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(v[3].(string), "'%v'"), v[4].([]interface{})...))
messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(values[3].(string), "'%v'"), values[4].([]interface{})...))
} else {
messages = append(messages, "\033[31;1m")
messages = append(messages, v[2:]...)
messages = append(messages, values[2:]...)
messages = append(messages, "\033[0m")
}
logger.Println(messages...)
Expand Down
Loading

0 comments on commit 0b32041

Please sign in to comment.