Skip to content

Commit

Permalink
API for search
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Jan 28, 2014
1 parent 15583e6 commit 6f1dd5f
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 81 deletions.
2 changes: 1 addition & 1 deletion callback_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func Delete(scope *Scope) {
defer scope.Trace(time.Now())

if !scope.HasError() {
if !scope.Search.unscope && scope.HasColumn("DeletedAt") {
if !scope.Search.Unscope && scope.HasColumn("DeletedAt") {
scope.Raw(
fmt.Sprintf("UPDATE %v SET deleted_at=%v %v",
scope.TableName(),
Expand Down
8 changes: 4 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
if c.First(out, where...).Error == RecordNotFound {
c.NewScope(out).inlineCondition(where...).initialize()
} else {
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(s.search.assignAttrs), false)
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(s.search.AssignAttrs), false)
}
return c
}
Expand All @@ -161,8 +161,8 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
c := s.clone()
if c.First(out, where...).Error == RecordNotFound {
c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates)
} else if len(s.search.assignAttrs) > 0 {
c.NewScope(out).Set("gorm:update_interface", s.search.assignAttrs).callCallbacks(s.parent.callback.updates)
} else if len(s.search.AssignAttrs) > 0 {
c.NewScope(out).Set("gorm:update_interface", s.search.AssignAttrs).callCallbacks(s.parent.callback.updates)
}
return c
}
Expand Down Expand Up @@ -203,7 +203,7 @@ func (s *DB) Delete(value interface{}) *DB {
}

func (s *DB) Raw(sql string, values ...interface{}) *DB {
return s.clone().search.setraw(true).where(sql, values...).db
return s.clone().search.raw(true).where(sql, values...).db
}

func (s *DB) Exec(sql string, values ...interface{}) *DB {
Expand Down
10 changes: 5 additions & 5 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ func (scope *Scope) AddToVars(value interface{}) string {
}

func (scope *Scope) TableName() string {
if scope.Search != nil && len(scope.Search.tableName) > 0 {
return scope.Search.tableName
if scope.Search != nil && len(scope.Search.TableName) > 0 {
return scope.Search.TableName
} else {
if scope.Value == nil {
scope.Err(errors.New("can't get table name"))
Expand Down Expand Up @@ -414,11 +414,11 @@ func (scope *Scope) rows() (*sql.Rows, error) {
}

func (scope *Scope) initialize() *Scope {
for _, clause := range scope.Search.whereClause {
for _, clause := range scope.Search.WhereConditions {
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false)
}
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false)
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false)
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.InitAttrs), false)
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.AssignAttrs), false)
return scope
}

Expand Down
36 changes: 18 additions & 18 deletions scope_condition.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,23 +137,23 @@ func (scope *Scope) where(where ...interface{}) {
func (scope *Scope) whereSql() (sql string) {
var primary_condiations, and_conditions, or_conditions []string

if !scope.Search.unscope && scope.HasColumn("DeletedAt") {
if !scope.Search.Unscope && scope.HasColumn("DeletedAt") {
primary_condiations = append(primary_condiations, "(deleted_at IS NULL OR deleted_at <= '0001-01-02')")
}

if !scope.PrimaryKeyZero() {
primary_condiations = append(primary_condiations, scope.primaryCondiation(scope.AddToVars(scope.PrimaryKeyValue())))
}

for _, clause := range scope.Search.whereClause {
for _, clause := range scope.Search.WhereConditions {
and_conditions = append(and_conditions, scope.buildWhereCondition(clause))
}

for _, clause := range scope.Search.orClause {
for _, clause := range scope.Search.OrConditions {
or_conditions = append(or_conditions, scope.buildWhereCondition(clause))
}

for _, clause := range scope.Search.notClause {
for _, clause := range scope.Search.NotConditions {
and_conditions = append(and_conditions, scope.buildNotCondition(clause))
}

Expand All @@ -179,59 +179,59 @@ func (scope *Scope) whereSql() (sql string) {
}

func (s *Scope) selectSql() string {
if len(s.Search.selectStr) == 0 {
if len(s.Search.Select) == 0 {
return "*"
} else {
return s.Search.selectStr
return s.Search.Select
}
}

func (s *Scope) orderSql() string {
if len(s.Search.orders) == 0 {
if len(s.Search.Orders) == 0 {
return ""
} else {
return " ORDER BY " + strings.Join(s.Search.orders, ",")
return " ORDER BY " + strings.Join(s.Search.Orders, ",")
}
}

func (s *Scope) limitSql() string {
if len(s.Search.limitStr) == 0 {
if len(s.Search.Limit) == 0 {
return ""
} else {
return " LIMIT " + s.Search.limitStr
return " LIMIT " + s.Search.Limit
}
}

func (s *Scope) offsetSql() string {
if len(s.Search.offsetStr) == 0 {
if len(s.Search.Offset) == 0 {
return ""
} else {
return " OFFSET " + s.Search.offsetStr
return " OFFSET " + s.Search.Offset
}
}

func (s *Scope) groupSql() string {
if len(s.Search.groupStr) == 0 {
if len(s.Search.Group) == 0 {
return ""
} else {
return " GROUP BY " + s.Search.groupStr
return " GROUP BY " + s.Search.Group
}
}

func (s *Scope) havingSql() string {
if s.Search.havingClause == nil {
if s.Search.HavingCondition == nil {
return ""
} else {
return " HAVING " + s.buildWhereCondition(s.Search.havingClause)
return " HAVING " + s.buildWhereCondition(s.Search.HavingCondition)
}
}

func (s *Scope) joinsSql() string {
return s.Search.joinsStr + " "
return s.Search.Joins + " "
}

func (scope *Scope) prepareQuerySql() {
if scope.Search.raw {
if scope.Search.Raw {
scope.Raw(strings.TrimLeft(scope.CombinedConditionSql(), "WHERE "))
} else {
scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSql(), scope.TableName(), scope.CombinedConditionSql()))
Expand Down
96 changes: 48 additions & 48 deletions search.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,100 +6,100 @@ import (
)

type search struct {
db *DB
whereClause []map[string]interface{}
orClause []map[string]interface{}
notClause []map[string]interface{}
initAttrs []interface{}
assignAttrs []interface{}
havingClause map[string]interface{}
orders []string
joinsStr string
selectStr string
offsetStr string
limitStr string
groupStr string
tableName string
unscope bool
raw bool
db *DB
WhereConditions []map[string]interface{}
OrConditions []map[string]interface{}
NotConditions []map[string]interface{}
InitAttrs []interface{}
AssignAttrs []interface{}
HavingCondition map[string]interface{}
Orders []string
Joins string
Select string
Offset string
Limit string
Group string
TableName string
Unscope bool
Raw bool
}

func (s *search) clone() *search {
return &search{
whereClause: s.whereClause,
orClause: s.orClause,
notClause: s.notClause,
initAttrs: s.initAttrs,
assignAttrs: s.assignAttrs,
havingClause: s.havingClause,
orders: s.orders,
selectStr: s.selectStr,
offsetStr: s.offsetStr,
limitStr: s.limitStr,
unscope: s.unscope,
groupStr: s.groupStr,
joinsStr: s.joinsStr,
tableName: s.tableName,
raw: s.raw,
WhereConditions: s.WhereConditions,
OrConditions: s.OrConditions,
NotConditions: s.NotConditions,
InitAttrs: s.InitAttrs,
AssignAttrs: s.AssignAttrs,
HavingCondition: s.HavingCondition,
Orders: s.Orders,
Select: s.Select,
Offset: s.Offset,
Limit: s.Limit,
Unscope: s.Unscope,
Group: s.Group,
Joins: s.Joins,
TableName: s.TableName,
Raw: s.Raw,
}
}

func (s *search) where(query interface{}, values ...interface{}) *search {
s.whereClause = append(s.whereClause, map[string]interface{}{"query": query, "args": values})
s.WhereConditions = append(s.WhereConditions, map[string]interface{}{"query": query, "args": values})
return s
}

func (s *search) not(query interface{}, values ...interface{}) *search {
s.notClause = append(s.notClause, map[string]interface{}{"query": query, "args": values})
s.NotConditions = append(s.NotConditions, map[string]interface{}{"query": query, "args": values})
return s
}

func (s *search) or(query interface{}, values ...interface{}) *search {
s.orClause = append(s.orClause, map[string]interface{}{"query": query, "args": values})
s.OrConditions = append(s.OrConditions, map[string]interface{}{"query": query, "args": values})
return s
}

func (s *search) attrs(attrs ...interface{}) *search {
s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...))
s.InitAttrs = append(s.InitAttrs, toSearchableMap(attrs...))
return s
}

func (s *search) assign(attrs ...interface{}) *search {
s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...))
s.AssignAttrs = append(s.AssignAttrs, toSearchableMap(attrs...))
return s
}

func (s *search) order(value string, reorder ...bool) *search {
if len(reorder) > 0 && reorder[0] {
s.orders = []string{value}
s.Orders = []string{value}
} else {
s.orders = append(s.orders, value)
s.Orders = append(s.Orders, value)
}
return s
}

func (s *search) selects(value interface{}) *search {
s.selectStr = s.getInterfaceAsSql(value)
s.Select = s.getInterfaceAsSql(value)
return s
}

func (s *search) limit(value interface{}) *search {
s.limitStr = s.getInterfaceAsSql(value)
s.Limit = s.getInterfaceAsSql(value)
return s
}

func (s *search) offset(value interface{}) *search {
s.offsetStr = s.getInterfaceAsSql(value)
s.Offset = s.getInterfaceAsSql(value)
return s
}

func (s *search) group(query string) *search {
s.groupStr = s.getInterfaceAsSql(query)
s.Group = s.getInterfaceAsSql(query)
return s
}

func (s *search) having(query string, values ...interface{}) *search {
s.havingClause = map[string]interface{}{"query": query, "args": values}
s.HavingCondition = map[string]interface{}{"query": query, "args": values}
return s
}

Expand All @@ -108,22 +108,22 @@ func (s *search) includes(value interface{}) *search {
}

func (s *search) joins(query string) *search {
s.joinsStr = query
s.Joins = query
return s
}

func (s *search) setraw(b bool) *search {
s.raw = b
func (s *search) raw(b bool) *search {
s.Raw = b
return s
}

func (s *search) unscoped() *search {
s.unscope = true
s.Unscope = true
return s
}

func (s *search) table(name string) *search {
s.tableName = name
s.TableName = name
return s
}

Expand Down
10 changes: 5 additions & 5 deletions search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@ func TestCloneSearch(t *testing.T) {
s1 := s.clone()
s1.where("age = ?", 20).order("age").attrs("email", "a@e.org").selects("email")

if reflect.DeepEqual(s.whereClause, s1.whereClause) {
if reflect.DeepEqual(s.WhereConditions, s1.WhereConditions) {
t.Errorf("Where should be copied")
}

if reflect.DeepEqual(s.orders, s1.orders) {
if reflect.DeepEqual(s.Orders, s1.Orders) {
t.Errorf("Order should be copied")
}

if reflect.DeepEqual(s.initAttrs, s1.initAttrs) {
t.Errorf("initAttrs should be copied")
if reflect.DeepEqual(s.InitAttrs, s1.InitAttrs) {
t.Errorf("InitAttrs should be copied")
}

if reflect.DeepEqual(s.selectStr, s1.selectStr) {
if reflect.DeepEqual(s.Select, s1.Select) {
t.Errorf("selectStr should be copied")
}
}

0 comments on commit 6f1dd5f

Please sign in to comment.