Skip to content

Commit

Permalink
Refactor based on golint
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Mar 7, 2016
1 parent 3055bad commit ec11065
Show file tree
Hide file tree
Showing 20 changed files with 185 additions and 109 deletions.
2 changes: 1 addition & 1 deletion callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"fmt"
)

// defaultCallback hold default callbacks defined by gorm
// DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{}

// Callback contains callbacks that used when CURD objects
Expand Down
12 changes: 6 additions & 6 deletions callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ func createCallback(scope *Scope) {
returningColumn = scope.Quote(primaryField.DBName)
}

lastInsertIdReturningSuffix := scope.Dialect().LastInsertIdReturningSuffix(quotedTableName, returningColumn)
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)

if len(columns) == 0 {
scope.Raw(fmt.Sprintf(
"INSERT INTO %v DEFAULT VALUES%v%v",
quotedTableName,
addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIdReturningSuffix),
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
))
} else {
scope.Raw(fmt.Sprintf(
Expand All @@ -101,13 +101,13 @@ func createCallback(scope *Scope) {
strings.Join(columns, ","),
strings.Join(placeholders, ","),
addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIdReturningSuffix),
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
))
}

// execute create sql
if lastInsertIdReturningSuffix == "" || primaryField == nil {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if lastInsertIDReturningSuffix == "" || primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()

Expand All @@ -119,7 +119,7 @@ func createCallback(scope *Scope) {
}
}
} else {
if err := scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
scope.db.RowsAffected = 1
}
}
Expand Down
8 changes: 4 additions & 4 deletions callback_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ func queryCallback(scope *Scope) {
return
}

scope.prepareQuerySql()
scope.prepareQuerySQL()

if !scope.HasError() {
scope.db.RowsAffected = 0
if str, ok := scope.Get("gorm:query_option"); ok {
scope.Sql += addExtraSpaceIfExist(fmt.Sprint(str))
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
}

if rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
defer rows.Close()

columns, _ := rows.Columns()
Expand All @@ -80,7 +80,7 @@ func queryCallback(scope *Scope) {
}

if scope.db.RowsAffected == 0 && !isSlice {
scope.Err(RecordNotFound)
scope.Err(ErrRecordNotFound)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type Dialect interface {
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
SelectFromDummyTable() string
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIdReturningSuffix(tableName, columnName string) string
LastInsertIDReturningSuffix(tableName, columnName string) string
}

var dialectsMap = map[string]Dialect{}
Expand Down
2 changes: 1 addition & 1 deletion dialect_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,6 @@ func (commonDialect) SelectFromDummyTable() string {
return ""
}

func (commonDialect) LastInsertIdReturningSuffix(tableName, columnName string) string {
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}
4 changes: 2 additions & 2 deletions dialect_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ func (s postgres) currentDatabase() (name string) {
return
}

func (s postgres) LastInsertIdReturningSuffix(tableName, key string) string {
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
}

func (postgres) SupportLastInsertId() bool {
func (postgres) SupportLastInsertID() bool {
return false
}

Expand Down
2 changes: 1 addition & 1 deletion dialects/mssql/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,6 @@ func (mssql) SelectFromDummyTable() string {
return ""
}

func (mssql) LastInsertIdReturningSuffix(tableName, columnName string) string {
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}
16 changes: 12 additions & 4 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,31 @@ import (
)

var (
RecordNotFound = errors.New("record not found")
InvalidSql = errors.New("invalid sql")
NoValidTransaction = errors.New("no valid transaction")
CantStartTransaction = errors.New("can't start transaction")
// ErrRecordNotFound record not found, happens when you are looking up with a struct, and haven't find any matched data
ErrRecordNotFound = errors.New("record not found")
// ErrInvalidSQL invalid SQL, happens when you passed invalid SQL
ErrInvalidSQL = errors.New("invalid SQL")
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
ErrInvalidTransaction = errors.New("no valid transaction")
// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
ErrCantStartTransaction = errors.New("can't start transaction")
)

type errorsInterface interface {
GetErrors() []error
}

// Errors contains all happened errors
type Errors struct {
errors []error
}

// GetErrors get all happened errors
func (errs Errors) GetErrors() []error {
return errs.errors
}

// Add add an error
func (errs *Errors) Add(err error) {
if errors, ok := err.(errorsInterface); ok {
for _, err := range errors.GetErrors() {
Expand All @@ -39,6 +46,7 @@ func (errs *Errors) Add(err error) {
}
}

// Error format happened errors
func (errs Errors) Error() string {
var errors = []string{}
for _, e := range errs.errors {
Expand Down
2 changes: 2 additions & 0 deletions field.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ import (
"reflect"
)

// Field model field definition
type Field struct {
*StructField
IsBlank bool
Field reflect.Value
}

// Set set a value to the field
func (field *Field) Set(value interface{}) (err error) {
if !field.Field.IsValid() {
return errors.New("field value not valid")
Expand Down
26 changes: 22 additions & 4 deletions join_table_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,54 @@ import (
"strings"
)

// JoinTableHandlerInterface is an interface for how to handle many2many relations
type JoinTableHandlerInterface interface {
// initialize join table handler
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
// Table return join table's table name
Table(db *DB) string
// Add create relationship in join table for source and destination
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
// Delete delete relationship in join table for sources
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
// JoinWith query with `Join` conditions
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
// SourceForeignKeys return source foreign keys
SourceForeignKeys() []JoinTableForeignKey
// DestinationForeignKeys return destination foreign keys
DestinationForeignKeys() []JoinTableForeignKey
}

// JoinTableForeignKey join table foreign key struct
type JoinTableForeignKey struct {
DBName string
AssociationDBName string
}

// JoinTableSource is a struct that contains model type and foreign keys
type JoinTableSource struct {
ModelType reflect.Type
ForeignKeys []JoinTableForeignKey
}

// JoinTableHandler default join table handler
type JoinTableHandler struct {
TableName string `sql:"-"`
Source JoinTableSource `sql:"-"`
Destination JoinTableSource `sql:"-"`
}

// SourceForeignKeys return source foreign keys
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
return s.Source.ForeignKeys
}

// DestinationForeignKeys return destination foreign keys
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
return s.Destination.ForeignKeys
}

// Setup initialize a default join table handler
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
s.TableName = tableName

Expand All @@ -61,11 +75,12 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
}
}

// Table return join table's table name
func (s JoinTableHandler) Table(db *DB) string {
return s.TableName
}

func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
values := map[string]interface{}{}

for _, source := range sources {
Expand All @@ -89,9 +104,10 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin
return values
}

func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error {
// Add create relationship in join table for source and destination
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
scope := db.NewScope("")
searchMap := s.GetSearchMap(db, source1, source2)
searchMap := s.getSearchMap(db, source, destination)

var assignColumns, binVars, conditions []string
var values []interface{}
Expand Down Expand Up @@ -120,21 +136,23 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1
return db.Exec(sql, values...).Error
}

// Delete delete relationship in join table for sources
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
var (
scope = db.NewScope(nil)
conditions []string
values []interface{}
)

for key, value := range s.GetSearchMap(db, sources...) {
for key, value := range s.getSearchMap(db, sources...) {
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
values = append(values, value)
}

return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
}

// JoinWith query with `Join` conditions
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
var (
scope = db.NewScope(source)
Expand Down
16 changes: 9 additions & 7 deletions logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,25 @@ import (
"unicode"
)

var (
defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
)

type logger interface {
Print(v ...interface{})
}

type LogWriter interface {
type logWriter interface {
Println(v ...interface{})
}

// Logger default logger
type Logger struct {
LogWriter
logWriter
}

var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}

// Format log
var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)

// Print format & print log
func (logger Logger) Print(values ...interface{}) {
if len(values) > 1 {
level := values[0]
Expand Down
Loading

0 comments on commit ec11065

Please sign in to comment.