20
20
// db, err := sql.Open("interceptor", "dsn")
21
21
type Interceptor struct {
22
22
// Driver is a database driver.
23
- // It must implement [driver.ExecerContext] and [driver.QueryerContext] (most drivers do).
23
+ // It must implement [driver.Pinger], [driver.ExecerContext], [driver.QueryerContext],
24
+ // [driver.ConnPrepareContext], and [driver.ConnBeginTx] (most drivers do).
24
25
// Required.
25
26
Driver driver.Driver
26
27
@@ -60,16 +61,27 @@ func (i Interceptor) OpenConnector(name string) (driver.Connector, error) {
60
61
61
62
var (
62
63
_ driver.Conn = wrappedConn {}
64
+ _ driver.Pinger = wrappedConn {}
63
65
_ driver.ExecerContext = wrappedConn {}
64
66
_ driver.QueryerContext = wrappedConn {}
65
67
_ driver.ConnPrepareContext = wrappedConn {}
68
+ _ driver.ConnBeginTx = wrappedConn {}
66
69
)
67
70
68
71
type wrappedConn struct {
69
72
driver.Conn
70
73
interceptor Interceptor
71
74
}
72
75
76
+ // Ping implements [driver.Pinger].
77
+ func (c wrappedConn ) Ping (ctx context.Context ) error {
78
+ pinger , ok := c .Conn .(driver.Pinger )
79
+ if ! ok {
80
+ panic ("queries: driver does not implement driver.Pinger" )
81
+ }
82
+ return pinger .Ping (ctx )
83
+ }
84
+
73
85
// ExecContext implements [driver.ExecerContext].
74
86
func (c wrappedConn ) ExecContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
75
87
execer , ok := c .Conn .(driver.ExecerContext )
@@ -94,8 +106,6 @@ func (c wrappedConn) QueryContext(ctx context.Context, query string, args []driv
94
106
return queryer .QueryContext (ctx , query , args )
95
107
}
96
108
97
- var _ driver.Connector = wrappedConnector {}
98
-
99
109
// PrepareContext implements [driver.ConnPrepareContext].
100
110
func (c wrappedConn ) PrepareContext (ctx context.Context , query string ) (driver.Stmt , error ) {
101
111
preparer , ok := c .Conn .(driver.ConnPrepareContext )
@@ -108,6 +118,44 @@ func (c wrappedConn) PrepareContext(ctx context.Context, query string) (driver.S
108
118
return preparer .PrepareContext (ctx , query )
109
119
}
110
120
121
+ // BeginTx implements [driver.ConnBeginTx].
122
+ func (c wrappedConn ) BeginTx (ctx context.Context , opts driver.TxOptions ) (driver.Tx , error ) {
123
+ beginner , ok := c .Conn .(driver.ConnBeginTx )
124
+ if ! ok {
125
+ panic ("queries: driver does not implement driver.ConnBeginTx" )
126
+ }
127
+ return beginner .BeginTx (ctx , opts )
128
+ }
129
+
130
+ var _ driver.SessionResetter = wrappedConnSessionResetter {}
131
+
132
+ type wrappedConnSessionResetter struct { wrappedConn }
133
+
134
+ // ResetSession implements [driver.SessionResetter].
135
+ func (c wrappedConnSessionResetter ) ResetSession (ctx context.Context ) error {
136
+ return c .Conn .(driver.SessionResetter ).ResetSession (ctx )
137
+ }
138
+
139
+ var _ driver.Validator = wrappedConnValidator {}
140
+
141
+ type wrappedConnValidator struct { wrappedConn }
142
+
143
+ // IsValid implements [driver.Validator].
144
+ func (c wrappedConnValidator ) IsValid () bool {
145
+ return c .Conn .(driver.Validator ).IsValid ()
146
+ }
147
+
148
+ var _ driver.NamedValueChecker = wrappedConnNamedValueChecker {}
149
+
150
+ type wrappedConnNamedValueChecker struct { wrappedConn }
151
+
152
+ // CheckNamedValue implements [driver.NamedValueChecker].
153
+ func (c wrappedConnNamedValueChecker ) CheckNamedValue (nv * driver.NamedValue ) error {
154
+ return c .Conn .(driver.NamedValueChecker ).CheckNamedValue (nv )
155
+ }
156
+
157
+ var _ driver.Connector = wrappedConnector {}
158
+
111
159
type wrappedConnector struct {
112
160
driver.Connector
113
161
interceptor Interceptor
@@ -119,7 +167,64 @@ func (c wrappedConnector) Connect(ctx context.Context) (driver.Conn, error) {
119
167
if err != nil {
120
168
return nil , err
121
169
}
122
- return wrappedConn {conn , c .interceptor }, nil
170
+
171
+ wconn := wrappedConn {conn , c .interceptor }
172
+ _ , isSessionResetter := conn .(driver.SessionResetter )
173
+ _ , isValidator := conn .(driver.Validator )
174
+ _ , isNamedValueChecker := conn .(driver.NamedValueChecker )
175
+
176
+ switch {
177
+ case isSessionResetter && isValidator && isNamedValueChecker :
178
+ return struct {
179
+ wrappedConn
180
+ wrappedConnSessionResetter
181
+ wrappedConnValidator
182
+ wrappedConnNamedValueChecker
183
+ }{
184
+ wconn ,
185
+ wrappedConnSessionResetter {wconn },
186
+ wrappedConnValidator {wconn },
187
+ wrappedConnNamedValueChecker {wconn },
188
+ }, nil
189
+ case isSessionResetter && isValidator :
190
+ return struct {
191
+ wrappedConn
192
+ wrappedConnSessionResetter
193
+ wrappedConnValidator
194
+ }{
195
+ wconn ,
196
+ wrappedConnSessionResetter {wconn },
197
+ wrappedConnValidator {wconn },
198
+ }, nil
199
+ case isSessionResetter && isNamedValueChecker :
200
+ return struct {
201
+ wrappedConn
202
+ wrappedConnSessionResetter
203
+ wrappedConnNamedValueChecker
204
+ }{
205
+ wconn ,
206
+ wrappedConnSessionResetter {wconn },
207
+ wrappedConnNamedValueChecker {wconn },
208
+ }, nil
209
+ case isValidator && isNamedValueChecker :
210
+ return struct {
211
+ wrappedConn
212
+ wrappedConnValidator
213
+ wrappedConnNamedValueChecker
214
+ }{
215
+ wconn ,
216
+ wrappedConnValidator {wconn },
217
+ wrappedConnNamedValueChecker {wconn },
218
+ }, nil
219
+ case isSessionResetter :
220
+ return wrappedConnSessionResetter {wconn }, nil
221
+ case isValidator :
222
+ return wrappedConnValidator {wconn }, nil
223
+ case isNamedValueChecker :
224
+ return wrappedConnNamedValueChecker {wconn }, nil
225
+ default :
226
+ return wconn , nil
227
+ }
123
228
}
124
229
125
230
// copied from https://go.dev/src/database/sql/sql.go
@@ -128,5 +233,5 @@ type dsnConnector struct {
128
233
driver driver.Driver
129
234
}
130
235
131
- func (t dsnConnector ) Connect (_ context.Context ) (driver.Conn , error ) { return t .driver .Open (t .dsn ) }
132
- func (t dsnConnector ) Driver () driver.Driver { return t .driver }
236
+ func (t dsnConnector ) Connect (context.Context ) (driver.Conn , error ) { return t .driver .Open (t .dsn ) }
237
+ func (t dsnConnector ) Driver () driver.Driver { return t .driver }
0 commit comments