Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: Datasource 抽象 #170

Merged
merged 21 commits into from
Apr 5, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
- [eorm: 分库分表: Merger抽象与批量查询实现](https://github.com/ecodeclub/eorm/pull/160)
- [eorm: 增强的 ShardingAlgorithm 设计与实现](https://github.com/ecodeclub/eorm/pull/161)
- [eorm: 分库分表: Merger排序实现](https://github.com/ecodeclub/eorm/pull/166)
- [eorm: Datasource 抽象](https://github.com/ecodeclub/eorm/pull/167)
- [eorm: BasicTypeValue重命名](https://github.com/ecodeclub/eorm/pull/177)

## v0.0.1:
Expand Down
17 changes: 10 additions & 7 deletions aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ import (
)

func TestAggregate(t *testing.T) {
db := memoryDB()
db, err := Open("sqlite3", memoryDB())
if err != nil {
t.Fatal(err)
}
testCases := []CommonTestCase{
{
name: "avg",
Expand Down Expand Up @@ -81,42 +84,42 @@ func TestAggregate(t *testing.T) {
}

func ExampleAggregate_As() {
db := memoryDB()
db, _ := Open("sqlite3", memoryDB())
query, _ := NewSelector[TestModel](db).Select(Avg("Age").As("avg_age")).Build()
fmt.Println(query.SQL)
// Output: SELECT AVG(`age`) AS `avg_age` FROM `test_model`;
}

func ExampleAvg() {
db := memoryDB()
db, _ := Open("sqlite3", memoryDB())
query, _ := NewSelector[TestModel](db).Select(Avg("Age").As("avg_age")).Build()
fmt.Println(query.SQL)
// Output: SELECT AVG(`age`) AS `avg_age` FROM `test_model`;
}

func ExampleCount() {
db := memoryDB()
db, _ := Open("sqlite3", memoryDB())
query, _ := NewSelector[TestModel](db).Select(Count("Age")).Build()
fmt.Println(query.SQL)
// Output: SELECT COUNT(`age`) FROM `test_model`;
}

func ExampleMax() {
db := memoryDB()
db, _ := Open("sqlite3", memoryDB())
query, _ := NewSelector[TestModel](db).Select(Max("Age")).Build()
fmt.Println(query.SQL)
// Output: SELECT MAX(`age`) FROM `test_model`;
}

func ExampleMin() {
db := memoryDB()
db, _ := Open("sqlite3", memoryDB())
query, _ := NewSelector[TestModel](db).Select(Min("Age")).Build()
fmt.Println(query.SQL)
// Output: SELECT MIN(`age`) FROM `test_model`;
}

func ExampleSum() {
db := memoryDB()
db, _ := Open("sqlite3", memoryDB())
query, _ := NewSelector[TestModel](db).Select(Sum("Age")).Build()
fmt.Println(query.SQL)
// Output: SELECT SUM(`age`) FROM `test_model`;
Expand Down
2 changes: 1 addition & 1 deletion assignment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ package eorm
import "fmt"

func ExampleAssign() {
db := memoryDB()
db, _ := Open("sqlite3", memoryDB())
tm := &TestModel{}
examples := []struct {
assign Assignment
Expand Down
24 changes: 13 additions & 11 deletions builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,22 @@ import (
"context"
"database/sql"

"github.com/ecodeclub/eorm/internal/datasource"

"github.com/ecodeclub/eorm/internal/errs"
"github.com/ecodeclub/eorm/internal/model"
"github.com/ecodeclub/eorm/internal/query"
"github.com/valyala/bytebufferpool"
)

var _ Executor = &Inserter[any]{}
var _ Executor = &Updater[any]{}
var _ Executor = &Deleter[any]{}

var EmptyQuery = Query{}

// Query 代表一个查询
type Query struct {
SQL string
Args []any
}
type Query query.Query

// Querier 查询器,代表最基本的查询
type Querier[T any] struct {
Expand All @@ -48,7 +50,7 @@ func RawQuery[T any](sess Session, sql string, args ...any) Querier[T] {
core: sess.getCore(),
Session: sess,
qc: &QueryContext{
q: &Query{
q: Query{
SQL: sql,
Args: args,
},
Expand All @@ -57,7 +59,7 @@ func RawQuery[T any](sess Session, sql string, args ...any) Querier[T] {
}
}

func newQuerier[T any](sess Session, q *Query, meta *model.TableMeta, typ string) Querier[T] {
func newQuerier[T any](sess Session, q Query, meta *model.TableMeta, typ string) Querier[T] {
return Querier[T]{
core: sess.getCore(),
Session: sess,
Expand All @@ -72,7 +74,7 @@ func newQuerier[T any](sess Session, q *Query, meta *model.TableMeta, typ string
// Exec 执行 SQL
func (q Querier[T]) Exec(ctx context.Context) Result {
var handler HandleFunc = func(ctx context.Context, qc *QueryContext) *QueryResult {
res, err := q.Session.execContext(ctx, qc.q.SQL, qc.q.Args...)
res, err := q.Session.execContext(ctx, datasource.Query(qc.q))
Stone-afk marked this conversation as resolved.
Show resolved Hide resolved
return &QueryResult{Result: res, Err: err}
}

Expand Down Expand Up @@ -326,16 +328,16 @@ func (b *builder) buildColumn(c Column) error {
// buildSubquery 構建子查詢 SQL,
// useAlias 決定是否顯示別名,即使有別名
func (b *builder) buildSubquery(sub Subquery, useAlias bool) error {
query, err := sub.q.Build()
q, err := sub.q.Build()
if err != nil {
return err
}
b.writeByte('(')
// 拿掉最後 ';'
b.writeString(query.SQL[:len(query.SQL)-1])
b.writeString(q.SQL[:len(q.SQL)-1])
// 因為有 build() ,所以理應 args 也需要跟 SQL 一起處理
if len(query.Args) > 0 {
b.addArgs(query.Args...)
if len(q.Args) > 0 {
b.addArgs(q.Args...)
}
b.writeByte(')')
if useAlias {
Expand Down
10 changes: 6 additions & 4 deletions builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@ import (
"fmt"
"testing"

"github.com/ecodeclub/eorm/internal/datasource/single"

"github.com/DATA-DOG/go-sqlmock"
"github.com/ecodeclub/eorm/internal/errs"
"github.com/ecodeclub/eorm/internal/valuer"
"github.com/stretchr/testify/assert"
)

func ExampleRawQuery() {
orm := memoryDB()
orm, _ := Open("sqlite3", memoryDB())
q := RawQuery[any](orm, `SELECT * FROM user_tab WHERE id = ?;`, 1)
fmt.Printf(`
SQL: %s
Expand All @@ -40,7 +42,7 @@ Args: %v
}

func ExampleQuerier_Exec() {
orm := memoryDB()
orm, _ := Open("sqlite3", memoryDB())
// 在 Exec 的时候,泛型参数可以是任意的
q := RawQuery[any](orm, `CREATE TABLE IF NOT EXISTS groups (
group_id INTEGER PRIMARY KEY,
Expand Down Expand Up @@ -75,7 +77,7 @@ func testQuerierGet(t *testing.T, creator valuer.PrimitiveCreator) {
}
defer func() { _ = db.Close() }()

orm, err := openDB("mysql", db)
orm, err := Open("mysql", single.NewDB(db))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -165,7 +167,7 @@ func testQuerier_GetMulti(t *testing.T, creator valuer.PrimitiveCreator) {
defer func() {
_ = db.Close()
}()
orm, err := openDB("mysql", db)
orm, err := Open("mysql", single.NewDB(db))
if err != nil {
t.Fatal(err)
}
Expand Down
24 changes: 12 additions & 12 deletions column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ package eorm
import "fmt"

func ExampleC() {
db := memoryDB()
query, _ := NewSelector[TestModel](db).Select(C("Id")).Where(C("Id").EQ(18)).Build()
orm, _ := Open("sqlite3", memoryDB())
query, _ := NewSelector[TestModel](orm).Select(C("Id")).Where(C("Id").EQ(18)).Build()
fmt.Printf(`
SQL: %s
Args: %v
Expand All @@ -29,7 +29,7 @@ Args: %v
}

func ExampleColumn_EQ() {
db := memoryDB()
db, _ := Open("sqlite3", memoryDB())
query, _ := NewSelector[TestModel](db).Select(C("Id")).Where(C("Id").EQ(18)).Build()
fmt.Printf(`
SQL: %s
Expand All @@ -41,7 +41,7 @@ Args: %v
}

func ExampleColumn_Add() {
db := memoryDB()
db, _ := Open("sqlite3", memoryDB())
tm := &TestModel{}
query, _ := NewUpdater[TestModel](db).Update(tm).Set(Assign("Age", C("Age").Add(1))).Build()
fmt.Printf(`
Expand All @@ -54,7 +54,7 @@ Args: %v
}

func ExampleColumn_As() {
db := memoryDB()
db, _ := Open("sqlite3", memoryDB())
query, _ := NewSelector[TestModel](db).Select(C("Id").As("my_id")).Build()
fmt.Printf(`
SQL: %s
Expand All @@ -66,7 +66,7 @@ Args: %v
}

func ExampleColumn_GT() {
db := memoryDB()
db, _ := Open("sqlite3", memoryDB())
query, _ := NewSelector[TestModel](db).Select(C("Id")).Where(C("Id").GT(18)).Build()
fmt.Printf(`
SQL: %s
Expand All @@ -78,7 +78,7 @@ Args: %v
}

func ExampleColumn_GTEQ() {
db := memoryDB()
db, _ := Open("sqlite3", memoryDB())
query, _ := NewSelector[TestModel](db).Select(C("Id")).Where(C("Id").GTEQ(18)).Build()
fmt.Printf(`
SQL: %s
Expand All @@ -90,7 +90,7 @@ Args: %v
}

func ExampleColumn_LT() {
db := memoryDB()
db, _ := Open("sqlite3", memoryDB())
query, _ := NewSelector[TestModel](db).Select(C("Id")).Where(C("Id").LT(18)).Build()
fmt.Printf(`
SQL: %s
Expand All @@ -102,7 +102,7 @@ Args: %v
}

func ExampleColumn_LTEQ() {
db := memoryDB()
db, _ := Open("sqlite3", memoryDB())
query, _ := NewSelector[TestModel](db).Select(C("Id")).Where(C("Id").LTEQ(18)).Build()
fmt.Printf(`
SQL: %s
Expand All @@ -114,7 +114,7 @@ Args: %v
}

func ExampleColumn_Multi() {
db := memoryDB()
db, _ := Open("sqlite3", memoryDB())
tm := &TestModel{}
query, _ := NewUpdater[TestModel](db).Update(tm).Set(Assign("Age", C("Age").Multi(2))).Build()
fmt.Printf(`
Expand All @@ -127,7 +127,7 @@ Args: %v
}

func ExampleColumn_NEQ() {
db := memoryDB()
db, _ := Open("sqlite3", memoryDB())
query, _ := NewSelector[TestModel](db).Select(C("Id")).Where(C("Id").NEQ(18)).Build()
fmt.Printf(`
SQL: %s
Expand All @@ -139,7 +139,7 @@ Args: %v
}

func ExampleColumns() {
db := memoryDB()
db, _ := Open("sqlite3", memoryDB())
query, _ := NewSelector[TestModel](db).Select(Columns("Id", "Age")).Build()
fmt.Printf(`
SQL: %s
Expand Down
6 changes: 4 additions & 2 deletions core.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"context"
"reflect"

"github.com/ecodeclub/eorm/internal/datasource"

"github.com/ecodeclub/eorm/internal/dialect"
"github.com/ecodeclub/eorm/internal/errs"
"github.com/ecodeclub/eorm/internal/model"
Expand All @@ -32,7 +34,7 @@ type core struct {
}

func getHandler[T any](ctx context.Context, sess Session, c core, qc *QueryContext) *QueryResult {
rows, err := sess.queryContext(ctx, qc.q.SQL, qc.q.Args...)
rows, err := sess.queryContext(ctx, datasource.Query(qc.q))
if err != nil {
return &QueryResult{Err: err}
}
Expand Down Expand Up @@ -68,7 +70,7 @@ func get[T any](ctx context.Context, sess Session, core core, qc *QueryContext)
}

func getMultiHandler[T any](ctx context.Context, sess Session, c core, qc *QueryContext) *QueryResult {
rows, err := sess.queryContext(ctx, qc.q.SQL, qc.q.Args...)
rows, err := sess.queryContext(ctx, datasource.Query(qc.q))
if err != nil {
return &QueryResult{Err: err}
}
Expand Down
Loading