@@ -12,7 +12,33 @@ import (
12
12
_ "github.com/ziutek/mymysql/godrv"
13
13
)
14
14
15
- func testSQLite (t * testing.T , fn func (* testing.T , * sql.DB )) {
15
+ type Tester interface {
16
+ RunTest (* testing.T , func (params ))
17
+ }
18
+
19
+ type mysqlDB int
20
+ type sqliteDB int
21
+
22
+ var (
23
+ mysql = mysqlDB (1 )
24
+ sqlite = sqliteDB (1 )
25
+ )
26
+
27
+ type params struct {
28
+ dbType Tester
29
+ * testing.T
30
+ * sql.DB
31
+ }
32
+
33
+ func (t params ) mustExec (sql string , args ... interface {}) sql.Result {
34
+ res , err := t .DB .Exec (sql , args ... )
35
+ if err != nil {
36
+ t .Fatalf ("Error running %q: %v" , sql , err )
37
+ }
38
+ return res
39
+ }
40
+
41
+ func (sqliteDB ) RunTest (t * testing.T , fn func (params )) {
16
42
tempDir , err := ioutil .TempDir ("" , "" )
17
43
if err != nil {
18
44
t .Fatal (err )
@@ -22,30 +48,68 @@ func testSQLite(t *testing.T, fn func(*testing.T, *sql.DB)) {
22
48
if err != nil {
23
49
t .Fatalf ("foo.db open fail: %v" , err )
24
50
}
25
- fn (t , db )
51
+ fn (params {sqlite , t , db })
52
+ }
53
+
54
+ func (mysqlDB ) RunTest (t * testing.T , fn func (params )) {
55
+ user := os .Getenv ("GOSQLTEST_MYSQL_USER" )
56
+ if user == "" {
57
+ user = "root"
58
+ }
59
+ pass , err := os .Getenverror ("GOSQLTEST_MYSQL_PASS" )
60
+ if err != nil {
61
+ pass = "root"
62
+ }
63
+ dbName := "gosqltest"
64
+ db , err := sql .Open ("mymysql" , fmt .Sprintf ("%s/%s/%s" , dbName , user , pass ))
65
+ if err != nil {
66
+ t .Fatalf ("error connecting: %v" , err )
67
+ }
68
+
69
+ params := params {mysql , t , db }
70
+
71
+ // Drop all tables in the test database.
72
+ rows , err := db .Query ("SHOW TABLES" )
73
+ if err != nil {
74
+ t .Fatalf ("failed to enumerate tables: %v" , err )
75
+ }
76
+ for rows .Next () {
77
+ var table string
78
+ if rows .Scan (& table ) == nil {
79
+ params .mustExec ("DROP TABLE " + table )
80
+ }
81
+ }
82
+
83
+ fn (params )
26
84
}
27
85
28
- func TestBlobs_SQLite (t * testing.T ) {
29
- testSQLite (t , testBlobs )
86
+ func sqlBlobParam (t params , size int ) string {
87
+ if t .dbType == sqlite {
88
+ return fmt .Sprintf ("blob[%d]" , size )
89
+ }
90
+ return fmt .Sprintf ("VARBINARY(%d)" , size )
30
91
}
31
92
32
- func testBlobs (t * testing.T , db * sql.DB ) {
93
+ func TestBlobs_SQLite (t * testing.T ) { sqlite .RunTest (t , testBlobs ) }
94
+ func TestBlobs_MySQL (t * testing.T ) { mysql .RunTest (t , testBlobs ) }
95
+
96
+ func testBlobs (t params ) {
33
97
var blob = []byte {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 }
34
- db . Exec ("create table foo (id integer primary key, bar blob[16] )" )
35
- db . Exec ( "insert or replace into foo (id, bar) values(?,?)" , 0 , blob )
98
+ t . mustExec ("create table foo (id integer primary key, bar " + sqlBlobParam ( t , 16 ) + " )" )
99
+ t . mustExec ( " replace into foo (id, bar) values(?,?)" , 0 , blob )
36
100
37
101
want := fmt .Sprintf ("%x" , blob )
38
102
39
103
b := make ([]byte , 16 )
40
- err := db .QueryRow ("select bar from foo where id = ?" , 0 ).Scan (& b )
104
+ err := t .QueryRow ("select bar from foo where id = ?" , 0 ).Scan (& b )
41
105
got := fmt .Sprintf ("%x" , b )
42
106
if err != nil {
43
107
t .Errorf ("[]byte scan: %v" , err )
44
108
} else if got != want {
45
109
t .Errorf ("for []byte, got %q; want %q" , got , want )
46
110
}
47
111
48
- err = db .QueryRow ("select bar from foo where id = ?" , 0 ).Scan (& got )
112
+ err = t .QueryRow ("select bar from foo where id = ?" , 0 ).Scan (& got )
49
113
want = string (blob )
50
114
if err != nil {
51
115
t .Errorf ("string scan: %v" , err )
0 commit comments