Skip to content

Commit a46ee34

Browse files
author
Fabian Holler
committed
allow retry on netErrors in safe situations
In some situations the sql package does not retry a pq operation when it should. One of the situations is #870. When a postgresql-server is restarted and after the restart is finished an operation is triggered on the already established connection, it failed with an broken pipe error in some circumstances. The sql package does not retry the operation and instead fail because the pq driver does not return driver.ErrBadConn for network errors. The driver must not return ErrBadConn when the server might have already executed the operation. This would cause that sql package is retrying it and the operation would be run multiple times by the postgresql server. In some situations it's safe to return ErrBadConn on network errors. This is the case when it's ensured that the server did not receive the message that triggers the operation. This commit introduces a netErrorNoWrite error. This error should be used when network operations panic when it's safe to retry the operation. When errRecover() receives this error it returns ErrBadConn() and marks the connection as bad. A mustSendRetryable() function is introduced that wraps a netOpError in an netErrorNoWrite when panicing. mustSendRetryable() is called in situations when the send that triggers the operation failed.
1 parent 2ff3cb3 commit a46ee34

File tree

2 files changed

+42
-12
lines changed

2 files changed

+42
-12
lines changed

conn.go

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ func (cn *conn) gname() string {
600600
func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
601601
b := cn.writeBuf('Q')
602602
b.string(q)
603-
cn.send(b)
603+
cn.mustSendRetryable(b)
604604

605605
for {
606606
t, r := cn.recv1()
@@ -632,7 +632,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
632632

633633
b := cn.writeBuf('Q')
634634
b.string(q)
635-
cn.send(b)
635+
cn.mustSendRetryable(b)
636636

637637
for {
638638
t, r := cn.recv1()
@@ -765,7 +765,7 @@ func (cn *conn) prepareTo(q, stmtName string) *stmt {
765765
b.string(st.name)
766766

767767
b.next('S')
768-
cn.send(b)
768+
cn.mustSendRetryable(b)
769769

770770
cn.readParseResponse()
771771
st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
@@ -882,13 +882,28 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err
882882
return r, err
883883
}
884884

885-
func (cn *conn) send(m *writeBuf) {
885+
func (cn *conn) send(m *writeBuf) error {
886886
_, err := cn.c.Write(m.wrap())
887+
return err
888+
}
889+
890+
func (cn *conn) mustSend(m *writeBuf) {
891+
err := cn.send(m)
887892
if err != nil {
888893
panic(err)
889894
}
890895
}
891896

897+
func (cn *conn) mustSendRetryable(m *writeBuf) {
898+
err := cn.send(m)
899+
if err != nil {
900+
if _, ok := err.(*net.OpError); ok {
901+
err = &netErrorNoWrite{err}
902+
}
903+
panic(err)
904+
}
905+
}
906+
892907
func (cn *conn) sendStartupPacket(m *writeBuf) error {
893908
_, err := cn.c.Write((m.wrap())[1:])
894909
return err
@@ -1109,7 +1124,7 @@ func (cn *conn) auth(r *readBuf, o values) {
11091124
case 3:
11101125
w := cn.writeBuf('p')
11111126
w.string(o["password"])
1112-
cn.send(w)
1127+
cn.mustSend(w)
11131128

11141129
t, r := cn.recv()
11151130
if t != 'R' {
@@ -1123,7 +1138,7 @@ func (cn *conn) auth(r *readBuf, o values) {
11231138
s := string(r.next(4))
11241139
w := cn.writeBuf('p')
11251140
w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
1126-
cn.send(w)
1141+
cn.mustSend(w)
11271142

11281143
t, r := cn.recv()
11291144
if t != 'R' {
@@ -1145,7 +1160,7 @@ func (cn *conn) auth(r *readBuf, o values) {
11451160
w.string("SCRAM-SHA-256")
11461161
w.int32(len(scOut))
11471162
w.bytes(scOut)
1148-
cn.send(w)
1163+
cn.mustSend(w)
11491164

11501165
t, r := cn.recv()
11511166
if t != 'R' {
@@ -1165,7 +1180,7 @@ func (cn *conn) auth(r *readBuf, o values) {
11651180
scOut = sc.Out()
11661181
w = cn.writeBuf('p')
11671182
w.bytes(scOut)
1168-
cn.send(w)
1183+
cn.mustSend(w)
11691184

11701185
t, r = cn.recv()
11711186
if t != 'R' {
@@ -1219,9 +1234,9 @@ func (st *stmt) Close() (err error) {
12191234
w := st.cn.writeBuf('C')
12201235
w.byte('S')
12211236
w.string(st.name)
1222-
st.cn.send(w)
1237+
st.cn.mustSend(w)
12231238

1224-
st.cn.send(st.cn.writeBuf('S'))
1239+
st.cn.mustSend(st.cn.writeBuf('S'))
12251240

12261241
t, _ := st.cn.recv1()
12271242
if t != '3' {
@@ -1299,7 +1314,7 @@ func (st *stmt) exec(v []driver.Value) {
12991314
w.int32(0)
13001315

13011316
w.next('S')
1302-
cn.send(w)
1317+
cn.mustSend(w)
13031318

13041319
cn.readBindResponse()
13051320
cn.postExecuteWorkaround()
@@ -1601,7 +1616,7 @@ func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
16011616
b.int32(0)
16021617

16031618
b.next('S')
1604-
cn.send(b)
1619+
cn.mustSendRetryable(b)
16051620
}
16061621

16071622
func (cn *conn) processParameterStatus(r *readBuf) {

error.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,18 @@ func errorf(s string, args ...interface{}) {
460460
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
461461
}
462462

463+
// NetErrorNoWrite is a network error that occured before a message that
464+
// indicates the operation to execute was transfered to the server.
465+
// These operations are safe to retry and driver.ErrBadConn should be passed
466+
// to the caller.
467+
type netErrorNoWrite struct {
468+
Err error
469+
}
470+
471+
func (e *netErrorNoWrite) Error() string {
472+
return "netErrorNoWrite: " + e.Err.Error()
473+
}
474+
463475
// TODO(ainar-g) Rename to errorf after removing panics.
464476
func fmterrorf(s string, args ...interface{}) error {
465477
return fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))
@@ -492,6 +504,9 @@ func (c *conn) errRecover(err *error) {
492504
} else {
493505
*err = v
494506
}
507+
case *netErrorNoWrite:
508+
c.bad = true
509+
*err = driver.ErrBadConn
495510
case *net.OpError:
496511
c.bad = true
497512
*err = v

0 commit comments

Comments
 (0)