Skip to content

Commit 19bce79

Browse files
authored
fix(interceptor): implement optional interfaces for Conn (#18)
1 parent a3c1271 commit 19bce79

File tree

2 files changed

+112
-7
lines changed

2 files changed

+112
-7
lines changed

interceptor.go

Lines changed: 111 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ var (
2020
// db, err := sql.Open("interceptor", "dsn")
2121
type Interceptor struct {
2222
// Driver is a database driver.
23-
// It must implement [driver.ExecerContext] and [driver.QueryerContext] (most drivers do).
23+
// It must implement [driver.Pinger], [driver.ExecerContext], [driver.QueryerContext],
24+
// [driver.ConnPrepareContext], and [driver.ConnBeginTx] (most drivers do).
2425
// Required.
2526
Driver driver.Driver
2627

@@ -60,16 +61,27 @@ func (i Interceptor) OpenConnector(name string) (driver.Connector, error) {
6061

6162
var (
6263
_ driver.Conn = wrappedConn{}
64+
_ driver.Pinger = wrappedConn{}
6365
_ driver.ExecerContext = wrappedConn{}
6466
_ driver.QueryerContext = wrappedConn{}
6567
_ driver.ConnPrepareContext = wrappedConn{}
68+
_ driver.ConnBeginTx = wrappedConn{}
6669
)
6770

6871
type wrappedConn struct {
6972
driver.Conn
7073
interceptor Interceptor
7174
}
7275

76+
// Ping implements [driver.Pinger].
77+
func (c wrappedConn) Ping(ctx context.Context) error {
78+
pinger, ok := c.Conn.(driver.Pinger)
79+
if !ok {
80+
panic("queries: driver does not implement driver.Pinger")
81+
}
82+
return pinger.Ping(ctx)
83+
}
84+
7385
// ExecContext implements [driver.ExecerContext].
7486
func (c wrappedConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
7587
execer, ok := c.Conn.(driver.ExecerContext)
@@ -94,8 +106,6 @@ func (c wrappedConn) QueryContext(ctx context.Context, query string, args []driv
94106
return queryer.QueryContext(ctx, query, args)
95107
}
96108

97-
var _ driver.Connector = wrappedConnector{}
98-
99109
// PrepareContext implements [driver.ConnPrepareContext].
100110
func (c wrappedConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
101111
preparer, ok := c.Conn.(driver.ConnPrepareContext)
@@ -108,6 +118,44 @@ func (c wrappedConn) PrepareContext(ctx context.Context, query string) (driver.S
108118
return preparer.PrepareContext(ctx, query)
109119
}
110120

121+
// BeginTx implements [driver.ConnBeginTx].
122+
func (c wrappedConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
123+
beginner, ok := c.Conn.(driver.ConnBeginTx)
124+
if !ok {
125+
panic("queries: driver does not implement driver.ConnBeginTx")
126+
}
127+
return beginner.BeginTx(ctx, opts)
128+
}
129+
130+
var _ driver.SessionResetter = wrappedConnSessionResetter{}
131+
132+
type wrappedConnSessionResetter struct{ wrappedConn }
133+
134+
// ResetSession implements [driver.SessionResetter].
135+
func (c wrappedConnSessionResetter) ResetSession(ctx context.Context) error {
136+
return c.Conn.(driver.SessionResetter).ResetSession(ctx)
137+
}
138+
139+
var _ driver.Validator = wrappedConnValidator{}
140+
141+
type wrappedConnValidator struct{ wrappedConn }
142+
143+
// IsValid implements [driver.Validator].
144+
func (c wrappedConnValidator) IsValid() bool {
145+
return c.Conn.(driver.Validator).IsValid()
146+
}
147+
148+
var _ driver.NamedValueChecker = wrappedConnNamedValueChecker{}
149+
150+
type wrappedConnNamedValueChecker struct{ wrappedConn }
151+
152+
// CheckNamedValue implements [driver.NamedValueChecker].
153+
func (c wrappedConnNamedValueChecker) CheckNamedValue(nv *driver.NamedValue) error {
154+
return c.Conn.(driver.NamedValueChecker).CheckNamedValue(nv)
155+
}
156+
157+
var _ driver.Connector = wrappedConnector{}
158+
111159
type wrappedConnector struct {
112160
driver.Connector
113161
interceptor Interceptor
@@ -119,7 +167,64 @@ func (c wrappedConnector) Connect(ctx context.Context) (driver.Conn, error) {
119167
if err != nil {
120168
return nil, err
121169
}
122-
return wrappedConn{conn, c.interceptor}, nil
170+
171+
wconn := wrappedConn{conn, c.interceptor}
172+
_, isSessionResetter := conn.(driver.SessionResetter)
173+
_, isValidator := conn.(driver.Validator)
174+
_, isNamedValueChecker := conn.(driver.NamedValueChecker)
175+
176+
switch {
177+
case isSessionResetter && isValidator && isNamedValueChecker:
178+
return struct {
179+
wrappedConn
180+
wrappedConnSessionResetter
181+
wrappedConnValidator
182+
wrappedConnNamedValueChecker
183+
}{
184+
wconn,
185+
wrappedConnSessionResetter{wconn},
186+
wrappedConnValidator{wconn},
187+
wrappedConnNamedValueChecker{wconn},
188+
}, nil
189+
case isSessionResetter && isValidator:
190+
return struct {
191+
wrappedConn
192+
wrappedConnSessionResetter
193+
wrappedConnValidator
194+
}{
195+
wconn,
196+
wrappedConnSessionResetter{wconn},
197+
wrappedConnValidator{wconn},
198+
}, nil
199+
case isSessionResetter && isNamedValueChecker:
200+
return struct {
201+
wrappedConn
202+
wrappedConnSessionResetter
203+
wrappedConnNamedValueChecker
204+
}{
205+
wconn,
206+
wrappedConnSessionResetter{wconn},
207+
wrappedConnNamedValueChecker{wconn},
208+
}, nil
209+
case isValidator && isNamedValueChecker:
210+
return struct {
211+
wrappedConn
212+
wrappedConnValidator
213+
wrappedConnNamedValueChecker
214+
}{
215+
wconn,
216+
wrappedConnValidator{wconn},
217+
wrappedConnNamedValueChecker{wconn},
218+
}, nil
219+
case isSessionResetter:
220+
return wrappedConnSessionResetter{wconn}, nil
221+
case isValidator:
222+
return wrappedConnValidator{wconn}, nil
223+
case isNamedValueChecker:
224+
return wrappedConnNamedValueChecker{wconn}, nil
225+
default:
226+
return wconn, nil
227+
}
123228
}
124229

125230
// copied from https://go.dev/src/database/sql/sql.go
@@ -128,5 +233,5 @@ type dsnConnector struct {
128233
driver driver.Driver
129234
}
130235

131-
func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) { return t.driver.Open(t.dsn) }
132-
func (t dsnConnector) Driver() driver.Driver { return t.driver }
236+
func (t dsnConnector) Connect(context.Context) (driver.Conn, error) { return t.driver.Open(t.dsn) }
237+
func (t dsnConnector) Driver() driver.Driver { return t.driver }

tests/integration_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func TestIntegration(t *testing.T) {
8282

8383
assert.NoErr[F](t, migrate(ctx, db))
8484

85-
tx, err := db.BeginTx(ctx, nil)
85+
tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
8686
assert.NoErr[F](t, err)
8787
defer tx.Rollback()
8888

0 commit comments

Comments
 (0)