@@ -3,8 +3,11 @@ package driver
3
3
import (
4
4
"context"
5
5
"database/sql"
6
+ sqlDriver "database/sql/driver"
6
7
"fmt"
8
+ "math"
7
9
"net"
10
+ "reflect"
8
11
"strings"
9
12
"testing"
10
13
"time"
@@ -94,16 +97,85 @@ func TestDriverOptions_writeTimeout(t *testing.T) {
94
97
srv := CreateMockServer (t )
95
98
defer srv .Stop ()
96
99
100
+ // use a writeTimeout that will fail parsing by ParseDuration resulting
101
+ // in a conn open error. The Open() won't fail until the ExecContext()
102
+ // call though, because that's when golang database/sql package will open
103
+ // the actual connection.
97
104
conn , err := sql .Open ("mysql" , "root@127.0.0.1:3307/test?writeTimeout=10" )
98
105
require .NoError (t , err )
106
+ require .NotNil (t , conn )
99
107
100
- result , err := conn .ExecContext (context .TODO (), "insert into slow(a,b) values(1,2);" )
108
+ // here we should fail because of the missing time unit in the duration.
109
+ result , err := conn .ExecContext (context .TODO (), "select 1;" )
110
+ require .Contains (t , err .Error (), "missing unit in duration" )
111
+ require .Error (t , err )
101
112
require .Nil (t , result )
113
+
114
+ // use an almost zero (1ns) writeTimeout to ensure the insert statement
115
+ // can't write before the timeout. Just want to make sure ExecContext()
116
+ // will throw an error.
117
+ conn , err = sql .Open ("mysql" , "root@127.0.0.1:3307/test?writeTimeout=1ns" )
118
+ require .NoError (t , err )
119
+
120
+ // ExecContext() should fail due to the write timeout of 1ns
121
+ result , err = conn .ExecContext (context .TODO (), "insert into slow(a,b) values(1,2);" )
102
122
require .Error (t , err )
123
+ require .Contains (t , err .Error (), "i/o timeout" )
124
+ require .Nil (t , result )
103
125
104
126
conn .Close ()
105
127
}
106
128
129
+ func TestDriverOptions_namedValueChecker (t * testing.T ) {
130
+ AddNamedValueChecker (func (nv * sqlDriver.NamedValue ) error {
131
+ rv := reflect .ValueOf (nv .Value )
132
+ if rv .Kind () != reflect .Uint64 {
133
+ // fallback to the default value converter when the value is not a uint64
134
+ return sqlDriver .ErrSkip
135
+ }
136
+
137
+ return nil
138
+ })
139
+
140
+ log .SetLevel (log .LevelDebug )
141
+ srv := CreateMockServer (t )
142
+ defer srv .Stop ()
143
+ conn , err := sql .Open ("mysql" , "root@127.0.0.1:3307/test?writeTimeout=1s" )
144
+ require .NoError (t , err )
145
+ defer conn .Close ()
146
+
147
+ // the NamedValueChecker will return ErrSkip for types that are NOT uint64, so make
148
+ // sure those make it to the server ok first.
149
+ int32Stmt , err := conn .Prepare ("select ?" )
150
+ require .NoError (t , err )
151
+ defer int32Stmt .Close ()
152
+ r1 , err := int32Stmt .Query (math .MaxInt32 )
153
+ require .NoError (t , err )
154
+ require .NotNil (t , r1 )
155
+
156
+ var i32 int32
157
+ require .True (t , r1 .Next ())
158
+ require .NoError (t , r1 .Scan (& i32 ))
159
+ require .True (t , math .MaxInt32 == i32 )
160
+
161
+ // Now make sure that the uint64 makes it to the server as well, this case will be handled
162
+ // by the NamedValueChecker (i.e. it will not return ErrSkip)
163
+ stmt , err := conn .Prepare ("select a, b from fast where uint64 = ?" )
164
+ require .NoError (t , err )
165
+ defer stmt .Close ()
166
+
167
+ var val uint64 = math .MaxUint64
168
+ result , err := stmt .Query (val )
169
+ require .NoError (t , err )
170
+ require .NotNil (t , result )
171
+
172
+ var a uint64
173
+ var b string
174
+ require .True (t , result .Next ())
175
+ require .NoError (t , result .Scan (& a , & b ))
176
+ require .True (t , math .MaxUint64 == a )
177
+ }
178
+
107
179
func CreateMockServer (t * testing.T ) * testServer {
108
180
inMemProvider := server .NewInMemoryProvider ()
109
181
inMemProvider .AddUser (* testUser , * testPassword )
@@ -151,7 +223,7 @@ func (h *mockHandler) UseDB(dbName string) error {
151
223
return nil
152
224
}
153
225
154
- func (h * mockHandler ) handleQuery (query string , binary bool ) (* mysql.Result , error ) {
226
+ func (h * mockHandler ) handleQuery (query string , binary bool , args [] interface {} ) (* mysql.Result , error ) {
155
227
ss := strings .Split (query , " " )
156
228
switch strings .ToLower (ss [0 ]) {
157
229
case "select" :
@@ -163,13 +235,24 @@ func (h *mockHandler) handleQuery(query string, binary bool) (*mysql.Result, err
163
235
{mysql .MaxPayloadLen },
164
236
}, binary )
165
237
} else {
166
- if strings .Contains (query , "slow" ) {
167
- time .Sleep (time .Second * 5 )
168
- }
238
+ if ss [1 ] == "?" {
239
+ r , err = mysql .BuildSimpleResultset ([]string {"a" }, [][]interface {}{
240
+ {args [0 ].(int64 )},
241
+ }, binary )
242
+ } else {
243
+ if strings .Contains (query , "slow" ) {
244
+ time .Sleep (time .Second * 5 )
245
+ }
169
246
170
- r , err = mysql .BuildSimpleResultset ([]string {"a" , "b" }, [][]interface {}{
171
- {1 , "hello world" },
172
- }, binary )
247
+ var aValue uint64 = 1
248
+ if strings .Contains (query , "uint64" ) {
249
+ aValue = math .MaxUint64
250
+ }
251
+
252
+ r , err = mysql .BuildSimpleResultset ([]string {"a" , "b" }, [][]interface {}{
253
+ {aValue , "hello world" },
254
+ }, binary )
255
+ }
173
256
}
174
257
175
258
if err != nil {
@@ -197,7 +280,7 @@ func (h *mockHandler) handleQuery(query string, binary bool) (*mysql.Result, err
197
280
}
198
281
199
282
func (h * mockHandler ) HandleQuery (query string ) (* mysql.Result , error ) {
200
- return h .handleQuery (query , false )
283
+ return h .handleQuery (query , false , nil )
201
284
}
202
285
203
286
func (h * mockHandler ) HandleFieldList (table string , fieldWildcard string ) ([]* mysql.Field , error ) {
@@ -206,13 +289,13 @@ func (h *mockHandler) HandleFieldList(table string, fieldWildcard string) ([]*my
206
289
207
290
func (h * mockHandler ) HandleStmtPrepare (query string ) (params int , columns int , context interface {}, err error ) {
208
291
params = 1
209
- columns = 0
292
+ columns = 2
210
293
return params , columns , nil , nil
211
294
}
212
295
213
296
func (h * mockHandler ) HandleStmtExecute (context interface {}, query string , args []interface {}) (* mysql.Result , error ) {
214
297
if strings .HasPrefix (strings .ToLower (query ), "select" ) {
215
- return h .HandleQuery (query )
298
+ return h .handleQuery (query , true , args )
216
299
}
217
300
218
301
return & mysql.Result {
0 commit comments