Skip to content

Commit

Permalink
Fix overriding of view name in Querier.
Browse files Browse the repository at this point in the history
  • Loading branch information
AlekSi committed Dec 12, 2018
1 parent d194063 commit 27faa85
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 20 deletions.
25 changes: 13 additions & 12 deletions querier.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import (

// Querier performs queries and commands.
type Querier struct {
dbtx DBTX
tag string
viewName string
dbtx DBTX
tag string
qualifiedViewName string
Dialect
Logger Logger
}
Expand Down Expand Up @@ -51,24 +51,25 @@ func (q *Querier) WithTag(format string, args ...interface{}) *Querier {
} else {
newQ.tag = fmt.Sprintf(format, args...)
}
newQ.qualifiedViewName = q.qualifiedViewName
return newQ
}

// WithView returns a copy of Querier with appointed view name.
func (q *Querier) WithView(viewName string) *Querier {
// WithQualifiedViewName returns a copy of Querier with set qualified view name.
// Returned Querier is tied to the same DB or TX.
// TODO Support INSERT/UPDATE/DELETE. More test.
func (q *Querier) WithQualifiedViewName(qualifiedViewName string) *Querier {
newQ := newQuerier(q.dbtx, q.Dialect, q.Logger)
newQ.viewName = viewName
newQ.tag = q.tag
newQ.qualifiedViewName = qualifiedViewName
return newQ
}

// QualifiedView returns quoted qualified view name.
// QualifiedView returns quoted qualified view name of given view.
func (q *Querier) QualifiedView(view View) string {
v := q.QuoteIdentifier(view.Name())
if q.viewName != "" {
v = q.QuoteIdentifier(q.viewName)
}
if view.Schema() != "" {
v = q.QuoteIdentifier(view.Schema()) + "." + v
if s := view.Schema(); s != "" {
v = q.QuoteIdentifier(s) + "." + v
}
return v
}
Expand Down
14 changes: 8 additions & 6 deletions querier_examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,18 @@ func ExampleQuerier_WithTag() {
// Name: `Vicious Baron` (string), ID: `baron` (string), Start: 2014-06-01 00:00:00 +0000 UTC (time.Time), End: 2016-02-21 00:00:00 +0000 UTC (*time.Time)
}

func ExampleQuerier_WithView() {
id := 1
view := fmt.Sprintf("people_%d", id%3)
person, err := DB.WithView(view).FindByPrimaryKeyFrom(PersonTable, id)
func ExampleQuerier_WithQualifiedViewName() {
_, err := DB.WithQualifiedViewName("people_0").FindByPrimaryKeyFrom(PersonTable, 1)
if err != reform.ErrNoRows {
log.Fatal(err)
}
person, err := DB.WithQualifiedViewName("people_1").FindByPrimaryKeyFrom(PersonTable, 1)
if err != nil {
log.Fatal(err)
}
fmt.Printf("%s: %s", view, person)
fmt.Println(person)
// Output:
// people_1: ID: 1 (int32), GroupID: 65534 (*int32), Name: `Denis Mills` (string), Email: <nil> (*string), CreatedAt: 2009-11-10 23:00:00 +0000 UTC (time.Time), UpdatedAt: <nil> (*time.Time)
// ID: 1 (int32), GroupID: 65534 (*int32), Name: `Denis Mills` (string), Email: <nil> (*string), CreatedAt: 2009-11-10 23:00:00 +0000 UTC (time.Time), UpdatedAt: <nil> (*time.Time)
}

func ExampleQuerier_SelectRows() {
Expand Down
8 changes: 6 additions & 2 deletions querier_selects.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,12 @@ func (q *Querier) selectQuery(view View, tail string, limit1 bool) string {
query += " TOP 1"
}

return fmt.Sprintf("%s %s FROM %s %s",
query, strings.Join(q.QualifiedColumns(view), ", "), q.QualifiedView(view), tail)
from := q.QualifiedView(view)
if q.qualifiedViewName != "" {
from = q.qualifiedViewName + " AS " + from
}

return fmt.Sprintf("%s %s FROM %s %s", query, strings.Join(q.QualifiedColumns(view), ", "), from, tail)
}

// SelectOneTo queries str's View with tail and args and scans first result to str.
Expand Down
35 changes: 35 additions & 0 deletions querier_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package reform_test

import (
"gopkg.in/reform.v1/dialects/mssql"
"gopkg.in/reform.v1/dialects/mysql"
"gopkg.in/reform.v1/dialects/postgresql"
"gopkg.in/reform.v1/dialects/sqlite3"
"gopkg.in/reform.v1/dialects/sqlserver"
. "gopkg.in/reform.v1/internal/test/models"
)

func (s *ReformSuite) TestQualifiedView() {
switch s.q.Dialect {
case postgresql.Dialect:
s.Equal(`"people"`, s.q.QualifiedView(PersonTable))
s.Equal(`"people"`, s.q.WithQualifiedViewName("ignored").QualifiedView(PersonTable))
s.Equal(`"legacy"."people"`, s.q.QualifiedView(LegacyPersonTable))
s.Equal(`"legacy"."people"`, s.q.WithQualifiedViewName("ignored").QualifiedView(LegacyPersonTable))

case mysql.Dialect:
s.Equal("`people`", s.q.QualifiedView(PersonTable))
s.Equal("`people`", s.q.WithQualifiedViewName("ignored").QualifiedView(PersonTable))

case sqlite3.Dialect:
s.Equal(`"people"`, s.q.QualifiedView(PersonTable))
s.Equal(`"people"`, s.q.WithQualifiedViewName("ignored").QualifiedView(PersonTable))

case mssql.Dialect, sqlserver.Dialect:
s.Equal(`[people]`, s.q.QualifiedView(PersonTable))
s.Equal(`[people]`, s.q.WithQualifiedViewName("ignored").QualifiedView(PersonTable))

default:
s.Fail("Unhandled dialect", s.q.Dialect.String())
}
}

0 comments on commit 27faa85

Please sign in to comment.