Skip to content

Commit

Permalink
Refactor Search API
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Mar 12, 2015
1 parent 9f29599 commit 6e5d46b
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 136 deletions.
2 changes: 1 addition & 1 deletion callback_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ func BeforeDelete(scope *Scope) {

func Delete(scope *Scope) {
if !scope.HasError() {
if !scope.Search.Unscope && scope.HasColumn("DeletedAt") {
if !scope.Search.unscoped && scope.HasColumn("DeletedAt") {
scope.Raw(
fmt.Sprintf("UPDATE %v SET deleted_at=%v %v",
scope.QuotedTableName(),
Expand Down
2 changes: 1 addition & 1 deletion callback_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func Query(scope *Scope) {

if orderBy, ok := scope.InstanceGet("gorm:order_by_primary_key"); ok {
if primaryKey := scope.PrimaryKey(); primaryKey != "" {
scope.Search = scope.Search.clone().order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), primaryKey, orderBy))
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), primaryKey, orderBy))
}
}

Expand Down
44 changes: 22 additions & 22 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (s *DB) New() *DB {
func (db *DB) NewScope(value interface{}) *Scope {
dbClone := db.clone()
dbClone.Value = value
return &Scope{db: dbClone, Search: dbClone.search, Value: value}
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
}

// CommonDB Return the underlying sql.DB or sql.Tx instance.
Expand Down Expand Up @@ -128,43 +128,43 @@ func (s *DB) SingularTable(enable bool) {
}

func (s *DB) Where(query interface{}, args ...interface{}) *DB {
return s.clone().search.where(query, args...).db
return s.clone().search.Where(query, args...).db
}

func (s *DB) Or(query interface{}, args ...interface{}) *DB {
return s.clone().search.or(query, args...).db
return s.clone().search.Or(query, args...).db
}

func (s *DB) Not(query interface{}, args ...interface{}) *DB {
return s.clone().search.not(query, args...).db
return s.clone().search.Not(query, args...).db
}

func (s *DB) Limit(value interface{}) *DB {
return s.clone().search.limit(value).db
return s.clone().search.Limit(value).db
}

func (s *DB) Offset(value interface{}) *DB {
return s.clone().search.offset(value).db
return s.clone().search.Offset(value).db
}

func (s *DB) Order(value string, reorder ...bool) *DB {
return s.clone().search.order(value, reorder...).db
return s.clone().search.Order(value, reorder...).db
}

func (s *DB) Select(query interface{}, args ...interface{}) *DB {
return s.clone().search.selects(query, args...).db
return s.clone().search.Selects(query, args...).db
}

func (s *DB) Group(query string) *DB {
return s.clone().search.group(query).db
return s.clone().search.Group(query).db
}

func (s *DB) Having(query string, values ...interface{}) *DB {
return s.clone().search.having(query, values...).db
return s.clone().search.Having(query, values...).db
}

func (s *DB) Joins(query string) *DB {
return s.clone().search.joins(query).db
return s.clone().search.Joins(query).db
}

func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
Expand All @@ -175,27 +175,27 @@ func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
}

func (s *DB) Unscoped() *DB {
return s.clone().search.unscoped().db
return s.clone().search.Unscoped().db
}

func (s *DB) Attrs(attrs ...interface{}) *DB {
return s.clone().search.attrs(attrs...).db
return s.clone().search.Attrs(attrs...).db
}

func (s *DB) Assign(attrs ...interface{}) *DB {
return s.clone().search.assign(attrs...).db
return s.clone().search.Assign(attrs...).db
}

func (s *DB) First(out interface{}, where ...interface{}) *DB {
newScope := s.clone().NewScope(out)
newScope.Search = newScope.Search.clone().limit(1)
newScope.Search.Limit(1)
return newScope.InstanceSet("gorm:order_by_primary_key", "ASC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
}

func (s *DB) Last(out interface{}, where ...interface{}) *DB {
newScope := s.clone().NewScope(out)
newScope.Search = newScope.Search.clone().limit(1)
newScope.Search.Limit(1)
return newScope.InstanceSet("gorm:order_by_primary_key", "DESC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
}
Expand Down Expand Up @@ -226,7 +226,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
}
c.NewScope(out).inlineCondition(where...).initialize()
} else {
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(s.search.AssignAttrs), false)
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(s.search.assignAttrs), false)
}
return c
}
Expand All @@ -238,8 +238,8 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
return result
}
c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates)
} else if len(c.search.AssignAttrs) > 0 {
c.NewScope(out).InstanceSet("gorm:update_interface", s.search.AssignAttrs).callCallbacks(s.parent.callback.updates)
} else if len(c.search.assignAttrs) > 0 {
c.NewScope(out).InstanceSet("gorm:update_interface", s.search.assignAttrs).callCallbacks(s.parent.callback.updates)
}
return c
}
Expand Down Expand Up @@ -284,7 +284,7 @@ func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
}

func (s *DB) Raw(sql string, values ...interface{}) *DB {
return s.clone().search.raw(true).where(sql, values...).db
return s.clone().search.Raw(true).Where(sql, values...).db
}

func (s *DB) Exec(sql string, values ...interface{}) *DB {
Expand Down Expand Up @@ -315,7 +315,7 @@ func (s *DB) Count(value interface{}) *DB {

func (s *DB) Table(name string) *DB {
clone := s.clone()
clone.search.table(name)
clone.search.Table(name)
clone.Value = nil
return clone
}
Expand Down Expand Up @@ -447,7 +447,7 @@ func (s *DB) Association(column string) *Association {
}

func (s *DB) Preload(column string, conditions ...interface{}) *DB {
return s.clone().search.preload(column, conditions...).db
return s.clone().search.Preload(column, conditions...).db
}

// Set set value by name
Expand Down
4 changes: 2 additions & 2 deletions preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ func Preload(scope *Scope) {
fields := scope.Fields()
isSlice := scope.IndirectValue().Kind() == reflect.Slice

if scope.Search.Preload != nil {
for key, conditions := range scope.Search.Preload {
if scope.Search.preload != nil {
for key, conditions := range scope.Search.preload {
for _, field := range fields {
if field.Name == key && field.Relationship != nil {
results := makeSlice(field.Struct.Type)
Expand Down
16 changes: 8 additions & 8 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ import (
)

type Scope struct {
Value interface{}
indirectValue *reflect.Value
Search *search
Value interface{}
Sql string
SqlVars []interface{}
db *DB
skipLeft bool
primaryKeyField *Field
indirectValue *reflect.Value
instanceId string
primaryKeyField *Field
skipLeft bool
fields map[string]*Field
}

Expand Down Expand Up @@ -225,15 +225,15 @@ func (scope *Scope) AddToVars(value interface{}) string {

// TableName get table name
func (scope *Scope) TableName() string {
if scope.Search != nil && len(scope.Search.TableName) > 0 {
return scope.Search.TableName
if scope.Search != nil && len(scope.Search.tableName) > 0 {
return scope.Search.tableName
}
return scope.GetModelStruct().TableName
}

func (scope *Scope) QuotedTableName() (name string) {
if scope.Search != nil && len(scope.Search.TableName) > 0 {
return scope.Quote(scope.Search.TableName)
if scope.Search != nil && len(scope.Search.tableName) > 0 {
return scope.Quote(scope.Search.tableName)
} else {
return scope.Quote(scope.TableName())
}
Expand Down
60 changes: 30 additions & 30 deletions scope_private.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
func (scope *Scope) whereSql() (sql string) {
var primaryConditions, andConditions, orConditions []string

if !scope.Search.Unscope && scope.Fields()["deleted_at"] != nil {
if !scope.Search.unscoped && scope.Fields()["deleted_at"] != nil {
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
primaryConditions = append(primaryConditions, sql)
}
Expand All @@ -168,19 +168,19 @@ func (scope *Scope) whereSql() (sql string) {
primaryConditions = append(primaryConditions, scope.primaryCondition(scope.AddToVars(scope.PrimaryKeyValue())))
}

for _, clause := range scope.Search.WhereConditions {
for _, clause := range scope.Search.whereConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
andConditions = append(andConditions, sql)
}
}

for _, clause := range scope.Search.OrConditions {
for _, clause := range scope.Search.orConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
orConditions = append(orConditions, sql)
}
}

for _, clause := range scope.Search.NotConditions {
for _, clause := range scope.Search.notConditions {
if sql := scope.buildNotCondition(clause); sql != "" {
andConditions = append(andConditions, sql)
}
Expand Down Expand Up @@ -208,76 +208,76 @@ func (scope *Scope) whereSql() (sql string) {
}

func (scope *Scope) selectSql() string {
if len(scope.Search.Selects) == 0 {
if len(scope.Search.selects) == 0 {
return "*"
}
return scope.buildSelectQuery(scope.Search.Selects)
return scope.buildSelectQuery(scope.Search.selects)
}

func (scope *Scope) orderSql() string {
if len(scope.Search.Orders) == 0 {
if len(scope.Search.orders) == 0 {
return ""
}
return " ORDER BY " + strings.Join(scope.Search.Orders, ",")
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
}

func (scope *Scope) limitSql() string {
if !scope.Dialect().HasTop() {
if len(scope.Search.Limit) == 0 {
if len(scope.Search.limit) == 0 {
return ""
}
return " LIMIT " + scope.Search.Limit
return " LIMIT " + scope.Search.limit
}

return ""
}

func (scope *Scope) topSql() string {
if scope.Dialect().HasTop() && len(scope.Search.Offset) == 0 {
if len(scope.Search.Limit) == 0 {
if scope.Dialect().HasTop() && len(scope.Search.offset) == 0 {
if len(scope.Search.limit) == 0 {
return ""
}
return " TOP(" + scope.Search.Limit + ")"
return " TOP(" + scope.Search.limit + ")"
}

return ""
}

func (scope *Scope) offsetSql() string {
if len(scope.Search.Offset) == 0 {
if len(scope.Search.offset) == 0 {
return ""
}

if scope.Dialect().HasTop() {
sql := " OFFSET " + scope.Search.Offset + " ROW "
if len(scope.Search.Limit) > 0 {
sql += "FETCH NEXT " + scope.Search.Limit + " ROWS ONLY"
sql := " OFFSET " + scope.Search.offset + " ROW "
if len(scope.Search.limit) > 0 {
sql += "FETCH NEXT " + scope.Search.limit + " ROWS ONLY"
}
return sql
}
return " OFFSET " + scope.Search.Offset
return " OFFSET " + scope.Search.offset
}

func (scope *Scope) groupSql() string {
if len(scope.Search.Group) == 0 {
if len(scope.Search.group) == 0 {
return ""
}
return " GROUP BY " + scope.Search.Group
return " GROUP BY " + scope.Search.group
}

func (scope *Scope) havingSql() string {
if scope.Search.HavingCondition == nil {
if scope.Search.havingCondition == nil {
return ""
}
return " HAVING " + scope.buildWhereCondition(scope.Search.HavingCondition)
return " HAVING " + scope.buildWhereCondition(scope.Search.havingCondition)
}

func (scope *Scope) joinsSql() string {
return scope.Search.Joins + " "
return scope.Search.joins + " "
}

func (scope *Scope) prepareQuerySql() {
if scope.Search.Raw {
if scope.Search.raw {
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
} else {
scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql()))
Expand All @@ -287,7 +287,7 @@ func (scope *Scope) prepareQuerySql() {

func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
if len(values) > 0 {
scope.Search = scope.Search.clone().where(values[0], values[1:]...)
scope.Search.Where(values[0], values[1:]...)
}
return scope
}
Expand Down Expand Up @@ -348,17 +348,17 @@ func (scope *Scope) rows() (*sql.Rows, error) {
}

func (scope *Scope) initialize() *Scope {
for _, clause := range scope.Search.WhereConditions {
for _, clause := range scope.Search.whereConditions {
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false)
}
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.InitAttrs), false)
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.AssignAttrs), false)
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false)
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false)
return scope
}

func (scope *Scope) pluck(column string, value interface{}) *Scope {
dest := reflect.Indirect(reflect.ValueOf(value))
scope.Search = scope.Search.clone().selects(column)
scope.Search.Selects(column)
if dest.Kind() != reflect.Slice {
scope.Err(errors.New("results should be a slice"))
return scope
Expand All @@ -377,7 +377,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope {
}

func (scope *Scope) count(value interface{}) *Scope {
scope.Search = scope.Search.clone().selects("count(*)")
scope.Search.Selects("count(*)")
scope.Err(scope.row().Scan(value))
return scope
}
Expand Down
Loading

0 comments on commit 6e5d46b

Please sign in to comment.