Skip to content

Commit 389c014

Browse files
committed
first test working against both sqlite and mysql.
1 parent 4f43c19 commit 389c014

File tree

1 file changed

+73
-9
lines changed

1 file changed

+73
-9
lines changed

src/sqltest/sql_test.go

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,33 @@ import (
1212
_ "github.com/ziutek/mymysql/godrv"
1313
)
1414

15-
func testSQLite(t *testing.T, fn func(*testing.T, *sql.DB)) {
15+
type Tester interface {
16+
RunTest(*testing.T, func(params))
17+
}
18+
19+
type mysqlDB int
20+
type sqliteDB int
21+
22+
var (
23+
mysql = mysqlDB(1)
24+
sqlite = sqliteDB(1)
25+
)
26+
27+
type params struct {
28+
dbType Tester
29+
*testing.T
30+
*sql.DB
31+
}
32+
33+
func (t params) mustExec(sql string, args ...interface{}) sql.Result {
34+
res, err := t.DB.Exec(sql, args...)
35+
if err != nil {
36+
t.Fatalf("Error running %q: %v", sql, err)
37+
}
38+
return res
39+
}
40+
41+
func (sqliteDB) RunTest(t *testing.T, fn func(params)) {
1642
tempDir, err := ioutil.TempDir("", "")
1743
if err != nil {
1844
t.Fatal(err)
@@ -22,30 +48,68 @@ func testSQLite(t *testing.T, fn func(*testing.T, *sql.DB)) {
2248
if err != nil {
2349
t.Fatalf("foo.db open fail: %v", err)
2450
}
25-
fn(t, db)
51+
fn(params{sqlite, t, db})
52+
}
53+
54+
func (mysqlDB) RunTest(t *testing.T, fn func(params)) {
55+
user := os.Getenv("GOSQLTEST_MYSQL_USER")
56+
if user == "" {
57+
user = "root"
58+
}
59+
pass, err := os.Getenverror("GOSQLTEST_MYSQL_PASS")
60+
if err != nil {
61+
pass = "root"
62+
}
63+
dbName := "gosqltest"
64+
db, err := sql.Open("mymysql", fmt.Sprintf("%s/%s/%s", dbName, user, pass))
65+
if err != nil {
66+
t.Fatalf("error connecting: %v", err)
67+
}
68+
69+
params := params{mysql, t, db}
70+
71+
// Drop all tables in the test database.
72+
rows, err := db.Query("SHOW TABLES")
73+
if err != nil {
74+
t.Fatalf("failed to enumerate tables: %v", err)
75+
}
76+
for rows.Next() {
77+
var table string
78+
if rows.Scan(&table) == nil {
79+
params.mustExec("DROP TABLE " + table)
80+
}
81+
}
82+
83+
fn(params)
2684
}
2785

28-
func TestBlobs_SQLite(t *testing.T) {
29-
testSQLite(t, testBlobs)
86+
func sqlBlobParam(t params, size int) string {
87+
if t.dbType == sqlite {
88+
return fmt.Sprintf("blob[%d]", size)
89+
}
90+
return fmt.Sprintf("VARBINARY(%d)", size)
3091
}
3192

32-
func testBlobs(t *testing.T, db *sql.DB) {
93+
func TestBlobs_SQLite(t *testing.T) { sqlite.RunTest(t, testBlobs) }
94+
func TestBlobs_MySQL(t *testing.T) { mysql.RunTest(t, testBlobs) }
95+
96+
func testBlobs(t params) {
3397
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
34-
db.Exec("create table foo (id integer primary key, bar blob[16])")
35-
db.Exec("insert or replace into foo (id, bar) values(?,?)", 0, blob)
98+
t.mustExec("create table foo (id integer primary key, bar " + sqlBlobParam(t, 16) + ")")
99+
t.mustExec("replace into foo (id, bar) values(?,?)", 0, blob)
36100

37101
want := fmt.Sprintf("%x", blob)
38102

39103
b := make([]byte, 16)
40-
err := db.QueryRow("select bar from foo where id = ?", 0).Scan(&b)
104+
err := t.QueryRow("select bar from foo where id = ?", 0).Scan(&b)
41105
got := fmt.Sprintf("%x", b)
42106
if err != nil {
43107
t.Errorf("[]byte scan: %v", err)
44108
} else if got != want {
45109
t.Errorf("for []byte, got %q; want %q", got, want)
46110
}
47111

48-
err = db.QueryRow("select bar from foo where id = ?", 0).Scan(&got)
112+
err = t.QueryRow("select bar from foo where id = ?", 0).Scan(&got)
49113
want = string(blob)
50114
if err != nil {
51115
t.Errorf("string scan: %v", err)

0 commit comments

Comments
 (0)