Skip to content

Commit 702e525

Browse files
committed
fix: build
1 parent cbbe1e9 commit 702e525

File tree

3 files changed

+28
-19
lines changed

3 files changed

+28
-19
lines changed

db.go

+18-9
Original file line numberDiff line numberDiff line change
@@ -40,29 +40,38 @@ func WithReadOnlyReplica(replica *sql.DB) DBOption {
4040
}
4141

4242
type DB struct {
43+
// Must be a pointer so we copy the state, not the state fields.
44+
*noCopyState
45+
46+
queryHooks []QueryHook
47+
48+
fmter schema.Formatter
49+
stats DBStats
50+
}
51+
52+
// noCopyState contains DB fields that must not be copied on clone(),
53+
// for example, it is forbidden to copy atomic.Pointer.
54+
type noCopyState struct {
4355
*sql.DB
56+
dialect schema.Dialect
4457

4558
replicas []*sql.DB
4659
healthyReplicas atomic.Pointer[[]*sql.DB]
4760
nextReplica atomic.Int64
4861

49-
dialect schema.Dialect
50-
queryHooks []QueryHook
51-
52-
fmter schema.Formatter
5362
flags internal.Flag
5463
closed atomic.Bool
55-
56-
stats DBStats
5764
}
5865

5966
func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB {
6067
dialect.Init(sqldb)
6168

6269
db := &DB{
63-
DB: sqldb,
64-
dialect: dialect,
65-
fmter: schema.NewFormatter(dialect),
70+
noCopyState: &noCopyState{
71+
DB: sqldb,
72+
dialect: dialect,
73+
},
74+
fmter: schema.NewFormatter(dialect),
6675
}
6776

6877
for _, opt := range opts {

query_base.go

+2-4
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,8 @@ func (q *baseQuery) GetTableName() string {
147147
}
148148

149149
for _, wq := range q.with {
150-
if v, ok := wq.query.(Query); ok {
151-
if model := v.GetModel(); model != nil {
152-
return v.GetTableName()
153-
}
150+
if model := wq.query.GetModel(); model != nil {
151+
return wq.query.GetTableName()
154152
}
155153
}
156154

query_select.go

+8-6
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) {
748748
query := internal.String(queryBytes)
749749

750750
ctx, event := q.db.beforeQuery(ctx, q, query, nil, query, q.model)
751-
rows, err := q.conn.QueryContext(ctx, query)
751+
rows, err := q.resolveConn(q).QueryContext(ctx, query)
752752
q.db.afterQuery(ctx, event, nil, err)
753753
return rows, err
754754
}
@@ -876,7 +876,7 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) {
876876
ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model)
877877

878878
var num int
879-
err = q.conn.QueryRowContext(ctx, query).Scan(&num)
879+
err = q.resolveConn(q).QueryRowContext(ctx, query).Scan(&num)
880880

881881
q.db.afterQuery(ctx, event, nil, err)
882882

@@ -894,13 +894,15 @@ func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (in
894894
return int(n), nil
895895
}
896896
}
897-
if _, ok := q.conn.(*DB); ok {
898-
return q.scanAndCountConc(ctx, dest...)
897+
if q.conn == nil {
898+
return q.scanAndCountConcurrently(ctx, dest...)
899899
}
900900
return q.scanAndCountSeq(ctx, dest...)
901901
}
902902

903-
func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{}) (int, error) {
903+
func (q *SelectQuery) scanAndCountConcurrently(
904+
ctx context.Context, dest ...interface{},
905+
) (int, error) {
904906
var count int
905907
var wg sync.WaitGroup
906908
var mu sync.Mutex
@@ -978,7 +980,7 @@ func (q *SelectQuery) selectExists(ctx context.Context) (bool, error) {
978980
ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model)
979981

980982
var exists bool
981-
err = q.conn.QueryRowContext(ctx, query).Scan(&exists)
983+
err = q.resolveConn(q).QueryRowContext(ctx, query).Scan(&exists)
982984

983985
q.db.afterQuery(ctx, event, nil, err)
984986

0 commit comments

Comments
 (0)