Skip to content

Commit 9fa33e2

Browse files
committed
implement ConnPrepareContext/StmtQueryContext/StmtExecContext interfaces
1 parent 2140507 commit 9fa33e2

File tree

4 files changed

+281
-2
lines changed

4 files changed

+281
-2
lines changed

conn.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,6 +1360,10 @@ func (st *stmt) Close() (err error) {
13601360
}
13611361

13621362
func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
1363+
return st.query(v)
1364+
}
1365+
1366+
func (st *stmt) query(v []driver.Value) (r *rows, err error) {
13631367
if st.cn.getBad() {
13641368
return nil, driver.ErrBadConn
13651369
}

conn_go18.go

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ import (
1111
"time"
1212
)
1313

14+
const (
15+
watchCancelDialContextTimeout = time.Second * 10
16+
)
17+
1418
// Implement the "QueryerContext" interface
1519
func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
1620
list := make([]driver.Value, len(args))
@@ -43,6 +47,14 @@ func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.Nam
4347
return cn.Exec(query, list)
4448
}
4549

50+
// Implement the "ConnPrepareContext" interface
51+
func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
52+
if finish := cn.watchCancel(ctx); finish != nil {
53+
defer finish()
54+
}
55+
return cn.Prepare(query)
56+
}
57+
4658
// Implement the "ConnBeginTx" interface
4759
func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
4860
var mode string
@@ -109,7 +121,7 @@ func (cn *conn) watchCancel(ctx context.Context) func() {
109121
// so it must not be used for the additional network
110122
// request to cancel the query.
111123
// Create a new context to pass into the dial.
112-
ctxCancel, cancel := context.WithTimeout(context.Background(), time.Second*10)
124+
ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
113125
defer cancel()
114126

115127
_ = cn.cancel(ctxCancel)
@@ -172,3 +184,68 @@ func (cn *conn) cancel(ctx context.Context) error {
172184
return err
173185
}
174186
}
187+
188+
// Implement the "StmtQueryContext" interface
189+
func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
190+
list := make([]driver.Value, len(args))
191+
for i, nv := range args {
192+
list[i] = nv.Value
193+
}
194+
finish := st.watchCancel(ctx)
195+
r, err := st.query(list)
196+
if err != nil {
197+
if finish != nil {
198+
finish()
199+
}
200+
return nil, err
201+
}
202+
r.finish = finish
203+
return r, nil
204+
}
205+
206+
// Implement the "StmtExecContext" interface
207+
func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
208+
list := make([]driver.Value, len(args))
209+
for i, nv := range args {
210+
list[i] = nv.Value
211+
}
212+
213+
if finish := st.watchCancel(ctx); finish != nil {
214+
defer finish()
215+
}
216+
217+
return st.Exec(list)
218+
}
219+
220+
// watchCancel is implemented on stmt in order to not mark the parent conn as bad
221+
func (st *stmt) watchCancel(ctx context.Context) func() {
222+
if done := ctx.Done(); done != nil {
223+
finished := make(chan struct{})
224+
go func() {
225+
select {
226+
case <-done:
227+
// At this point the function level context is canceled,
228+
// so it must not be used for the additional network
229+
// request to cancel the query.
230+
// Create a new context to pass into the dial.
231+
ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
232+
defer cancel()
233+
234+
_ = st.cancel(ctxCancel)
235+
finished <- struct{}{}
236+
case <-finished:
237+
}
238+
}()
239+
return func() {
240+
select {
241+
case <-finished:
242+
case finished <- struct{}{}:
243+
}
244+
}
245+
}
246+
return nil
247+
}
248+
249+
func (st *stmt) cancel(ctx context.Context) error {
250+
return st.cn.cancel(ctx)
251+
}

conn_test.go

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1806,3 +1806,167 @@ func TestCopyInStmtAffectedRows(t *testing.T) {
18061806
res.RowsAffected()
18071807
res.LastInsertId()
18081808
}
1809+
1810+
func TestConnPrepareContext(t *testing.T) {
1811+
db := openTestConn(t)
1812+
defer db.Close()
1813+
1814+
tests := []struct {
1815+
name string
1816+
ctx func() (context.Context, context.CancelFunc)
1817+
sql string
1818+
err error
1819+
}{
1820+
{
1821+
name: "context.Background",
1822+
ctx: func() (context.Context, context.CancelFunc) {
1823+
return context.Background(), nil
1824+
},
1825+
sql: "SELECT 1",
1826+
err: nil,
1827+
},
1828+
{
1829+
name: "context.WithTimeout exceeded",
1830+
ctx: func() (context.Context, context.CancelFunc) {
1831+
return context.WithTimeout(context.Background(), time.Microsecond)
1832+
},
1833+
sql: "SELECT 1",
1834+
err: context.DeadlineExceeded,
1835+
},
1836+
{
1837+
name: "context.WithTimeout",
1838+
ctx: func() (context.Context, context.CancelFunc) {
1839+
return context.WithTimeout(context.Background(), time.Minute)
1840+
},
1841+
sql: "SELECT 1",
1842+
err: nil,
1843+
},
1844+
}
1845+
for _, tt := range tests {
1846+
t.Run(tt.name, func(t *testing.T) {
1847+
ctx, cancel := tt.ctx()
1848+
if cancel != nil {
1849+
defer cancel()
1850+
}
1851+
_, err := db.PrepareContext(ctx, tt.sql)
1852+
switch {
1853+
case (err != nil) != (tt.err != nil):
1854+
t.Fatalf("conn.PrepareContext() unexpected nil err got = %v, expected = %v", err, tt.err)
1855+
case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()):
1856+
t.Errorf("conn.PrepareContext() got = %v, expected = %v", err.Error(), tt.err.Error())
1857+
}
1858+
})
1859+
}
1860+
}
1861+
1862+
func TestStmtQueryContext(t *testing.T) {
1863+
db := openTestConn(t)
1864+
defer db.Close()
1865+
1866+
tests := []struct {
1867+
name string
1868+
ctx func() (context.Context, context.CancelFunc)
1869+
sql string
1870+
err error
1871+
}{
1872+
{
1873+
name: "context.Background",
1874+
ctx: func() (context.Context, context.CancelFunc) {
1875+
return context.Background(), nil
1876+
},
1877+
sql: "SELECT pg_sleep(1);",
1878+
err: nil,
1879+
},
1880+
{
1881+
name: "context.WithTimeout exceeded",
1882+
ctx: func() (context.Context, context.CancelFunc) {
1883+
return context.WithTimeout(context.Background(), 1*time.Second)
1884+
},
1885+
sql: "SELECT pg_sleep(10);",
1886+
err: &Error{Message: "canceling statement due to user request"},
1887+
},
1888+
{
1889+
name: "context.WithTimeout",
1890+
ctx: func() (context.Context, context.CancelFunc) {
1891+
return context.WithTimeout(context.Background(), time.Minute)
1892+
},
1893+
sql: "SELECT pg_sleep(1);",
1894+
err: nil,
1895+
},
1896+
}
1897+
for _, tt := range tests {
1898+
t.Run(tt.name, func(t *testing.T) {
1899+
ctx, cancel := tt.ctx()
1900+
if cancel != nil {
1901+
defer cancel()
1902+
}
1903+
stmt, err := db.PrepareContext(ctx, tt.sql)
1904+
if err != nil {
1905+
t.Fatal(err)
1906+
}
1907+
_, err = stmt.QueryContext(ctx)
1908+
switch {
1909+
case (err != nil) != (tt.err != nil):
1910+
t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, expected = %v", err, tt.err)
1911+
case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()):
1912+
t.Errorf("stmt.QueryContext() got = %v, expected = %v", err.Error(), tt.err.Error())
1913+
}
1914+
})
1915+
}
1916+
}
1917+
1918+
func TestStmtExecContext(t *testing.T) {
1919+
db := openTestConn(t)
1920+
defer db.Close()
1921+
1922+
tests := []struct {
1923+
name string
1924+
ctx func() (context.Context, context.CancelFunc)
1925+
sql string
1926+
err error
1927+
}{
1928+
{
1929+
name: "context.Background",
1930+
ctx: func() (context.Context, context.CancelFunc) {
1931+
return context.Background(), nil
1932+
},
1933+
sql: "SELECT pg_sleep(1);",
1934+
err: nil,
1935+
},
1936+
{
1937+
name: "context.WithTimeout exceeded",
1938+
ctx: func() (context.Context, context.CancelFunc) {
1939+
return context.WithTimeout(context.Background(), 1*time.Second)
1940+
},
1941+
sql: "SELECT pg_sleep(10);",
1942+
err: &Error{Message: "canceling statement due to user request"},
1943+
},
1944+
{
1945+
name: "context.WithTimeout",
1946+
ctx: func() (context.Context, context.CancelFunc) {
1947+
return context.WithTimeout(context.Background(), time.Minute)
1948+
},
1949+
sql: "SELECT pg_sleep(1);",
1950+
err: nil,
1951+
},
1952+
}
1953+
for _, tt := range tests {
1954+
t.Run(tt.name, func(t *testing.T) {
1955+
ctx, cancel := tt.ctx()
1956+
if cancel != nil {
1957+
defer cancel()
1958+
}
1959+
stmt, err := db.PrepareContext(ctx, tt.sql)
1960+
if err != nil {
1961+
t.Fatal(err)
1962+
}
1963+
_, err = stmt.ExecContext(ctx)
1964+
switch {
1965+
case (err != nil) != (tt.err != nil):
1966+
t.Fatalf("stmt.ExecContext() unexpected nil err got = %v, expected = %v", err, tt.err)
1967+
case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()):
1968+
t.Errorf("stmt.ExecContext() got = %v, expected = %v", err.Error(), tt.err.Error())
1969+
}
1970+
})
1971+
}
1972+
}

issues_test.go

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package pq
22

3-
import "testing"
3+
import (
4+
"context"
5+
"testing"
6+
"time"
7+
)
48

59
func TestIssue494(t *testing.T) {
610
db := openTestConn(t)
@@ -24,3 +28,33 @@ func TestIssue494(t *testing.T) {
2428
t.Fatal("expected error")
2529
}
2630
}
31+
32+
func TestIssue1046(t *testing.T) {
33+
ctxTimeout := time.Second * 2
34+
35+
db := openTestConn(t)
36+
defer db.Close()
37+
38+
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout)
39+
defer cancel()
40+
41+
stmt, err := db.PrepareContext(ctx, `SELECT pg_sleep(10) AS id`)
42+
if err != nil {
43+
t.Fatal(err)
44+
}
45+
46+
var d []uint8
47+
err = stmt.QueryRowContext(ctx).Scan(&d)
48+
dl, _ := ctx.Deadline()
49+
since := time.Since(dl)
50+
if since > ctxTimeout {
51+
t.Logf("FAIL %s: query returned after context deadline: %v\n", t.Name(), since)
52+
t.Fail()
53+
}
54+
expectedErr := &Error{Message: "canceling statement due to user request"}
55+
if err == nil || err.Error() != expectedErr.Error() {
56+
t.Logf("ctx.Err(): [%T]%+v\n", ctx.Err(), ctx.Err())
57+
t.Logf("got err: [%T] %+v expected err: [%T] %+v", err, err, expectedErr, expectedErr)
58+
t.Fail()
59+
}
60+
}

0 commit comments

Comments
 (0)