From e5432b14d2f28ad759d8d6262c1f8a167d517f73 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 2 Aug 2017 07:41:43 +0800 Subject: [PATCH] Add QueryExpr, thanks @ManReinsp for PR #1548 --- create_test.go | 4 ++-- main.go | 11 ++++++++++- main_test.go | 47 ++++++++++++++++++++++++++++++++++++++++++++++- scope.go | 20 +++++++++++++++----- search.go | 8 ++++++-- 5 files changed, 79 insertions(+), 11 deletions(-) diff --git a/create_test.go b/create_test.go index 38d75af82..36472914b 100644 --- a/create_test.go +++ b/create_test.go @@ -68,11 +68,11 @@ func TestCreateWithExistingTimestamp(t *testing.T) { user.UpdatedAt = timeA DB.Save(&user) - if user.CreatedAt.Format(time.RFC3339) != timeA.Format(time.RFC3339) { + if user.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { t.Errorf("CreatedAt should not be changed") } - if user.UpdatedAt.Format(time.RFC3339) != timeA.Format(time.RFC3339) { + if user.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { t.Errorf("UpdatedAt should not be changed") } diff --git a/main.go b/main.go index 0f2fd1f52..6dc192b95 100644 --- a/main.go +++ b/main.go @@ -168,6 +168,15 @@ func (s *DB) NewScope(value interface{}) *Scope { return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} } +// QueryExpr returns the query as expr object +func (s *DB) QueryExpr() *expr { + scope := s.NewScope(s.Value) + scope.InstanceSet("skip_bindvar", true) + scope.prepareQuerySQL() + + return Expr("("+scope.SQL+")", scope.SQLVars...) +} + // Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query func (s *DB) Where(query interface{}, args ...interface{}) *DB { return s.clone().search.Where(query, args...).db @@ -218,7 +227,7 @@ func (s *DB) Group(query string) *DB { } // Having specify HAVING conditions for GROUP BY -func (s *DB) Having(query string, values ...interface{}) *DB { +func (s *DB) Having(query interface{}, values ...interface{}) *DB { return s.clone().search.Having(query, values...).db } diff --git a/main_test.go b/main_test.go index 3b1433cf5..62ad1c472 100644 --- a/main_test.go +++ b/main_test.go @@ -607,9 +607,54 @@ func TestHaving(t *testing.T) { } } +func TestQueryBuilderSubselectInWhere(t *testing.T) { + user := User{Name: "ruser1", Email: "root@user1.com", Age: 32} + DB.Save(&user) + user = User{Name: "ruser2", Email: "nobody@user2.com", Age: 16} + DB.Save(&user) + user = User{Name: "ruser3", Email: "root@user3.com", Age: 64} + DB.Save(&user) + user = User{Name: "ruser4", Email: "somebody@user3.com", Age: 128} + DB.Save(&user) + + var users []User + DB.Select("*").Where("name IN (?)", DB. + Select("name").Table("users").Where("email LIKE ?", "root@%").SubqueryExpr()).Find(&users) + + if len(users) != 2 { + t.Errorf("Two users should be found, instead found %d", len(users)) + } + + DB.Select("*").Where("email LIKE ?", "root%").Where("age >= (?)", DB. + Select("AVG(age)").Table("users").SubqueryExpr()).Find(&users) + + if len(users) != 2 { + t.Errorf("Two users should be found, instead found %d", len(users)) + } +} + +func TestQueryBuilderSubselectInHaving(t *testing.T) { + user := User{Name: "ruser1", Email: "root@user1.com", Age: 64} + DB.Save(&user) + user = User{Name: "ruser2", Email: "root@user2.com", Age: 128} + DB.Save(&user) + user = User{Name: "ruser3", Email: "root@user1.com", Age: 64} + DB.Save(&user) + user = User{Name: "ruser4", Email: "root@user2.com", Age: 128} + DB.Save(&user) + + var users []User + DB.Select("AVG(age) as avgage").Where("email LIKE ?", "root%").Group("email").Having("AVG(age) > (?)", DB. + Select("AVG(age)").Where("email LIKE ?", "root%").Table("users").SubqueryExpr()).Find(&users) + + if len(users) != 1 { + t.Errorf("One user group should be found, instead found %d", len(users)) + } +} + func DialectHasTzSupport() bool { // NB: mssql and FoundationDB do not support time zones. - if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" { + if dialect := os.Getenv("GORM_DIALECT"); dialect == "foundation" { return false } return true diff --git a/scope.go b/scope.go index 4fcb84c11..fda7f6539 100644 --- a/scope.go +++ b/scope.go @@ -253,15 +253,25 @@ func (scope *Scope) CallMethod(methodName string) { // AddToVars add value as sql's vars, used to prevent SQL injection func (scope *Scope) AddToVars(value interface{}) string { + _, skipBindVar := scope.InstanceGet("skip_bindvar") + if expr, ok := value.(*expr); ok { exp := expr.expr for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + if skipBindVar { + scope.AddToVars(arg) + } else { + exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + } } return exp } scope.SQLVars = append(scope.SQLVars, value) + + if skipBindVar { + return "?" + } return scope.Dialect().BindVar(len(scope.SQLVars)) } @@ -329,12 +339,12 @@ func (scope *Scope) QuotedTableName() (name string) { // CombinedConditionSql return combined condition sql func (scope *Scope) CombinedConditionSql() string { - joinSql := scope.joinsSQL() - whereSql := scope.whereSQL() + joinSQL := scope.joinsSQL() + whereSQL := scope.whereSQL() if scope.Search.raw { - whereSql = strings.TrimSuffix(strings.TrimPrefix(whereSql, "WHERE ("), ")") + whereSQL = strings.TrimSuffix(strings.TrimPrefix(whereSQL, "WHERE ("), ")") } - return joinSql + whereSql + scope.groupSQL() + + return joinSQL + whereSQL + scope.groupSQL() + scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL() } diff --git a/search.go b/search.go index 23dac2c38..2e2735849 100644 --- a/search.go +++ b/search.go @@ -104,8 +104,12 @@ func (s *search) Group(query string) *search { return s } -func (s *search) Having(query string, values ...interface{}) *search { - s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) +func (s *search) Having(query interface{}, values ...interface{}) *search { + if val, ok := query.(*expr); ok { + s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args}) + } else { + s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) + } return s }