Skip to content

Commit c607c3c

Browse files
authored
add support for driver.NamedValueChecker on driver connection (go-mysql-org#887)
* add support for driver.NamedValueChecker on driver connection * added more tests
1 parent b13191f commit c607c3c

File tree

3 files changed

+166
-11
lines changed

3 files changed

+166
-11
lines changed

README.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,41 @@ func main() {
448448
}
449449
```
450450

451+
### Custom NamedValueChecker
452+
453+
Golang allows for custom handling of query arguments before they are passed to the driver
454+
with the implementation of a [NamedValueChecker](https://pkg.go.dev/database/sql/driver#NamedValueChecker). By doing a full import of the driver (not by side-effects only),
455+
a custom NamedValueChecker can be implemented.
456+
457+
```golang
458+
import (
459+
"database/sql"
460+
461+
"github.com/go-mysql-org/go-mysql/driver"
462+
)
463+
464+
func main() {
465+
driver.AddNamedValueChecker(func(nv *sqlDriver.NamedValue) error {
466+
rv := reflect.ValueOf(nv.Value)
467+
if rv.Kind() != reflect.Uint64 {
468+
// fallback to the default value converter when the value is not a uint64
469+
return sqlDriver.ErrSkip
470+
}
471+
472+
return nil
473+
})
474+
475+
conn, err := sql.Open("mysql", "root@127.0.0.1:3306/test")
476+
defer conn.Close()
477+
478+
stmt, err := conn.Prepare("select * from table where id = ?")
479+
defer stmt.Close()
480+
var val uint64 = math.MaxUint64
481+
// without the NamedValueChecker this query would fail
482+
result, err := stmt.Query(val)
483+
}
484+
```
485+
451486

452487
We pass all tests in https://github.com/bradfitz/go-sql-test using go-mysql driver. :-)
453488

driver/driver.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"crypto/tls"
77
"database/sql"
88
sqldriver "database/sql/driver"
9+
goErrors "errors"
910
"fmt"
1011
"io"
1112
"net/url"
@@ -25,6 +26,10 @@ var customTLSMutex sync.Mutex
2526
var (
2627
customTLSConfigMap = make(map[string]*tls.Config)
2728
options = make(map[string]DriverOption)
29+
30+
// can be provided by clients to allow more control in handling Go and database
31+
// types beyond the default Value types allowed
32+
namedValueCheckers []CheckNamedValueFunc
2833
)
2934

3035
type driver struct {
@@ -154,10 +159,32 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
154159
return &conn{c}, nil
155160
}
156161

162+
type CheckNamedValueFunc func(*sqldriver.NamedValue) error
163+
164+
var _ sqldriver.NamedValueChecker = &conn{}
165+
157166
type conn struct {
158167
*client.Conn
159168
}
160169

170+
func (c *conn) CheckNamedValue(nv *sqldriver.NamedValue) error {
171+
for _, nvChecker := range namedValueCheckers {
172+
err := nvChecker(nv)
173+
if err == nil {
174+
// we've found a CheckNamedValueFunc that handled this named value
175+
// no need to keep looking
176+
return nil
177+
} else {
178+
// we've found an error, if the error is driver.ErrSkip then
179+
// keep looking otherwise return the unknown error
180+
if !goErrors.Is(sqldriver.ErrSkip, err) {
181+
return err
182+
}
183+
}
184+
}
185+
return sqldriver.ErrSkip
186+
}
187+
161188
func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
162189
st, err := c.Conn.Prepare(query)
163190
if err != nil {
@@ -375,3 +402,13 @@ func SetDSNOptions(customOptions map[string]DriverOption) {
375402
options[o] = f
376403
}
377404
}
405+
406+
// AddNamedValueChecker sets a custom NamedValueChecker for the driver connection which
407+
// allows for more control in handling Go and database types beyond the default Value types.
408+
// See https://pkg.go.dev/database/sql/driver#NamedValueChecker
409+
// Usage requires a full import of the driver (not by side-effects only). Also note that
410+
// this function is not concurrent-safe, and should only be executed while setting up the driver
411+
// before establishing any connections via `sql.Open()`.
412+
func AddNamedValueChecker(nvCheckFunc ...CheckNamedValueFunc) {
413+
namedValueCheckers = append(namedValueCheckers, nvCheckFunc...)
414+
}

driver/driver_options_test.go

Lines changed: 94 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@ package driver
33
import (
44
"context"
55
"database/sql"
6+
sqlDriver "database/sql/driver"
67
"fmt"
8+
"math"
79
"net"
10+
"reflect"
811
"strings"
912
"testing"
1013
"time"
@@ -94,16 +97,85 @@ func TestDriverOptions_writeTimeout(t *testing.T) {
9497
srv := CreateMockServer(t)
9598
defer srv.Stop()
9699

100+
// use a writeTimeout that will fail parsing by ParseDuration resulting
101+
// in a conn open error. The Open() won't fail until the ExecContext()
102+
// call though, because that's when golang database/sql package will open
103+
// the actual connection.
97104
conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?writeTimeout=10")
98105
require.NoError(t, err)
106+
require.NotNil(t, conn)
99107

100-
result, err := conn.ExecContext(context.TODO(), "insert into slow(a,b) values(1,2);")
108+
// here we should fail because of the missing time unit in the duration.
109+
result, err := conn.ExecContext(context.TODO(), "select 1;")
110+
require.Contains(t, err.Error(), "missing unit in duration")
111+
require.Error(t, err)
101112
require.Nil(t, result)
113+
114+
// use an almost zero (1ns) writeTimeout to ensure the insert statement
115+
// can't write before the timeout. Just want to make sure ExecContext()
116+
// will throw an error.
117+
conn, err = sql.Open("mysql", "root@127.0.0.1:3307/test?writeTimeout=1ns")
118+
require.NoError(t, err)
119+
120+
// ExecContext() should fail due to the write timeout of 1ns
121+
result, err = conn.ExecContext(context.TODO(), "insert into slow(a,b) values(1,2);")
102122
require.Error(t, err)
123+
require.Contains(t, err.Error(), "i/o timeout")
124+
require.Nil(t, result)
103125

104126
conn.Close()
105127
}
106128

129+
func TestDriverOptions_namedValueChecker(t *testing.T) {
130+
AddNamedValueChecker(func(nv *sqlDriver.NamedValue) error {
131+
rv := reflect.ValueOf(nv.Value)
132+
if rv.Kind() != reflect.Uint64 {
133+
// fallback to the default value converter when the value is not a uint64
134+
return sqlDriver.ErrSkip
135+
}
136+
137+
return nil
138+
})
139+
140+
log.SetLevel(log.LevelDebug)
141+
srv := CreateMockServer(t)
142+
defer srv.Stop()
143+
conn, err := sql.Open("mysql", "root@127.0.0.1:3307/test?writeTimeout=1s")
144+
require.NoError(t, err)
145+
defer conn.Close()
146+
147+
// the NamedValueChecker will return ErrSkip for types that are NOT uint64, so make
148+
// sure those make it to the server ok first.
149+
int32Stmt, err := conn.Prepare("select ?")
150+
require.NoError(t, err)
151+
defer int32Stmt.Close()
152+
r1, err := int32Stmt.Query(math.MaxInt32)
153+
require.NoError(t, err)
154+
require.NotNil(t, r1)
155+
156+
var i32 int32
157+
require.True(t, r1.Next())
158+
require.NoError(t, r1.Scan(&i32))
159+
require.True(t, math.MaxInt32 == i32)
160+
161+
// Now make sure that the uint64 makes it to the server as well, this case will be handled
162+
// by the NamedValueChecker (i.e. it will not return ErrSkip)
163+
stmt, err := conn.Prepare("select a, b from fast where uint64 = ?")
164+
require.NoError(t, err)
165+
defer stmt.Close()
166+
167+
var val uint64 = math.MaxUint64
168+
result, err := stmt.Query(val)
169+
require.NoError(t, err)
170+
require.NotNil(t, result)
171+
172+
var a uint64
173+
var b string
174+
require.True(t, result.Next())
175+
require.NoError(t, result.Scan(&a, &b))
176+
require.True(t, math.MaxUint64 == a)
177+
}
178+
107179
func CreateMockServer(t *testing.T) *testServer {
108180
inMemProvider := server.NewInMemoryProvider()
109181
inMemProvider.AddUser(*testUser, *testPassword)
@@ -151,7 +223,7 @@ func (h *mockHandler) UseDB(dbName string) error {
151223
return nil
152224
}
153225

154-
func (h *mockHandler) handleQuery(query string, binary bool) (*mysql.Result, error) {
226+
func (h *mockHandler) handleQuery(query string, binary bool, args []interface{}) (*mysql.Result, error) {
155227
ss := strings.Split(query, " ")
156228
switch strings.ToLower(ss[0]) {
157229
case "select":
@@ -163,13 +235,24 @@ func (h *mockHandler) handleQuery(query string, binary bool) (*mysql.Result, err
163235
{mysql.MaxPayloadLen},
164236
}, binary)
165237
} else {
166-
if strings.Contains(query, "slow") {
167-
time.Sleep(time.Second * 5)
168-
}
238+
if ss[1] == "?" {
239+
r, err = mysql.BuildSimpleResultset([]string{"a"}, [][]interface{}{
240+
{args[0].(int64)},
241+
}, binary)
242+
} else {
243+
if strings.Contains(query, "slow") {
244+
time.Sleep(time.Second * 5)
245+
}
169246

170-
r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{
171-
{1, "hello world"},
172-
}, binary)
247+
var aValue uint64 = 1
248+
if strings.Contains(query, "uint64") {
249+
aValue = math.MaxUint64
250+
}
251+
252+
r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{
253+
{aValue, "hello world"},
254+
}, binary)
255+
}
173256
}
174257

175258
if err != nil {
@@ -197,7 +280,7 @@ func (h *mockHandler) handleQuery(query string, binary bool) (*mysql.Result, err
197280
}
198281

199282
func (h *mockHandler) HandleQuery(query string) (*mysql.Result, error) {
200-
return h.handleQuery(query, false)
283+
return h.handleQuery(query, false, nil)
201284
}
202285

203286
func (h *mockHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) {
@@ -206,13 +289,13 @@ func (h *mockHandler) HandleFieldList(table string, fieldWildcard string) ([]*my
206289

207290
func (h *mockHandler) HandleStmtPrepare(query string) (params int, columns int, context interface{}, err error) {
208291
params = 1
209-
columns = 0
292+
columns = 2
210293
return params, columns, nil, nil
211294
}
212295

213296
func (h *mockHandler) HandleStmtExecute(context interface{}, query string, args []interface{}) (*mysql.Result, error) {
214297
if strings.HasPrefix(strings.ToLower(query), "select") {
215-
return h.HandleQuery(query)
298+
return h.handleQuery(query, true, args)
216299
}
217300

218301
return &mysql.Result{

0 commit comments

Comments
 (0)