Skip to content

Commit adbb38f

Browse files
committed
Do not allow protocol messages larger than ~1GB
The PostgreSQL server will reject messages greater than ~1 GB anyway. However, worse than that is that a message that is larger than 4 GB could wrap the 32-bit integer message size and be interpreted by the server as multiple messages. This could allow a malicious client to inject arbitrary protocol messages. GHSA-mrww-27vc-gghv
1 parent c1b0a01 commit adbb38f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+472
-390
lines changed

pgconn/pgconn.go

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1674,25 +1674,55 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
16741674
// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip.
16751675
type Batch struct {
16761676
buf []byte
1677+
err error
16771678
}
16781679

16791680
// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions.
16801681
func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
1681-
batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
1682+
if batch.err != nil {
1683+
return
1684+
}
1685+
1686+
batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
1687+
if batch.err != nil {
1688+
return
1689+
}
16821690
batch.ExecPrepared("", paramValues, paramFormats, resultFormats)
16831691
}
16841692

16851693
// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions.
16861694
func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
1687-
batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
1688-
batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
1689-
batch.buf = (&pgproto3.Execute{}).Encode(batch.buf)
1695+
if batch.err != nil {
1696+
return
1697+
}
1698+
1699+
batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
1700+
if batch.err != nil {
1701+
return
1702+
}
1703+
1704+
batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
1705+
if batch.err != nil {
1706+
return
1707+
}
1708+
1709+
batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf)
1710+
if batch.err != nil {
1711+
return
1712+
}
16901713
}
16911714

16921715
// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a
16931716
// transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing
16941717
// multiple queries in a single round trip than using pipeline mode.
16951718
func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
1719+
if batch.err != nil {
1720+
return &MultiResultReader{
1721+
closed: true,
1722+
err: batch.err,
1723+
}
1724+
}
1725+
16961726
if err := pgConn.lock(); err != nil {
16971727
return &MultiResultReader{
16981728
closed: true,
@@ -1718,7 +1748,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
17181748
pgConn.contextWatcher.Watch(ctx)
17191749
}
17201750

1721-
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
1751+
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
1752+
if batch.err != nil {
1753+
multiResult.closed = true
1754+
multiResult.err = batch.err
1755+
pgConn.unlock()
1756+
return multiResult
1757+
}
17221758

17231759
pgConn.enterPotentialWriteReadDeadlock()
17241760
defer pgConn.exitPotentialWriteReadDeadlock()

pgconn/pgconn_test.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3363,9 +3363,9 @@ func TestSNISupport(t *testing.T) {
33633363
return
33643364
}
33653365

3366-
srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil))
3367-
srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil))
3368-
srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil))
3366+
srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)))
3367+
srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)))
3368+
srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)))
33693369

33703370
serverSNINameChan <- sniHost
33713371
}()
@@ -3472,3 +3472,10 @@ func TestFatalErrorReceivedInPipelineMode(t *testing.T) {
34723472
err = pipeline.Close()
34733473
require.Error(t, err)
34743474
}
3475+
3476+
func mustEncode(buf []byte, err error) []byte {
3477+
if err != nil {
3478+
panic(err)
3479+
}
3480+
return buf
3481+
}

pgproto3/authentication_cleartext_password.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
3535
}
3636

3737
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
38-
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
39-
dst = append(dst, 'R')
40-
dst = pgio.AppendInt32(dst, 8)
38+
func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) {
39+
dst, sp := beginMessage(dst, 'R')
4140
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
42-
return dst
41+
return finishMessage(dst, sp)
4342
}
4443

4544
// MarshalJSON implements encoding/json.Marshaler.

pgproto3/authentication_gss.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error {
2727
return nil
2828
}
2929

30-
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
31-
dst = append(dst, 'R')
32-
dst = pgio.AppendInt32(dst, 4)
30+
func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) {
31+
dst, sp := beginMessage(dst, 'R')
3332
dst = pgio.AppendUint32(dst, AuthTypeGSS)
34-
return dst
33+
return finishMessage(dst, sp)
3534
}
3635

3736
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {

pgproto3/authentication_gss_continue.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error {
3131
return nil
3232
}
3333

34-
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
35-
dst = append(dst, 'R')
36-
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
34+
func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) {
35+
dst, sp := beginMessage(dst, 'R')
3736
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
3837
dst = append(dst, a.Data...)
39-
return dst
38+
return finishMessage(dst, sp)
4039
}
4140

4241
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {

pgproto3/authentication_md5_password.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error {
3838
}
3939

4040
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
41-
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
42-
dst = append(dst, 'R')
43-
dst = pgio.AppendInt32(dst, 12)
41+
func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) {
42+
dst, sp := beginMessage(dst, 'R')
4443
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
4544
dst = append(dst, src.Salt[:]...)
46-
return dst
45+
return finishMessage(dst, sp)
4746
}
4847

4948
// MarshalJSON implements encoding/json.Marshaler.

pgproto3/authentication_ok.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error {
3535
}
3636

3737
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
38-
func (src *AuthenticationOk) Encode(dst []byte) []byte {
39-
dst = append(dst, 'R')
40-
dst = pgio.AppendInt32(dst, 8)
38+
func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) {
39+
dst, sp := beginMessage(dst, 'R')
4140
dst = pgio.AppendUint32(dst, AuthTypeOk)
42-
return dst
41+
return finishMessage(dst, sp)
4342
}
4443

4544
// MarshalJSON implements encoding/json.Marshaler.

pgproto3/authentication_sasl.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
4747
}
4848

4949
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
50-
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
51-
dst = append(dst, 'R')
52-
sp := len(dst)
53-
dst = pgio.AppendInt32(dst, -1)
50+
func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) {
51+
dst, sp := beginMessage(dst, 'R')
5452
dst = pgio.AppendUint32(dst, AuthTypeSASL)
5553

5654
for _, s := range src.AuthMechanisms {
@@ -59,9 +57,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte {
5957
}
6058
dst = append(dst, 0)
6159

62-
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
63-
64-
return dst
60+
return finishMessage(dst, sp)
6561
}
6662

6763
// MarshalJSON implements encoding/json.Marshaler.

pgproto3/authentication_sasl_continue.go

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
3838
}
3939

4040
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
41-
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
42-
dst = append(dst, 'R')
43-
sp := len(dst)
44-
dst = pgio.AppendInt32(dst, -1)
41+
func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) {
42+
dst, sp := beginMessage(dst, 'R')
4543
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
46-
4744
dst = append(dst, src.Data...)
48-
49-
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
50-
51-
return dst
45+
return finishMessage(dst, sp)
5246
}
5347

5448
// MarshalJSON implements encoding/json.Marshaler.

pgproto3/authentication_sasl_final.go

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
3838
}
3939

4040
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
41-
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
42-
dst = append(dst, 'R')
43-
sp := len(dst)
44-
dst = pgio.AppendInt32(dst, -1)
41+
func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) {
42+
dst, sp := beginMessage(dst, 'R')
4543
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
46-
4744
dst = append(dst, src.Data...)
48-
49-
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
50-
51-
return dst
45+
return finishMessage(dst, sp)
5246
}
5347

5448
// MarshalJSON implements encoding/json.Unmarshaler.

pgproto3/backend.go

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ type Backend struct {
1616
// before it is actually transmitted (i.e. before Flush).
1717
tracer *tracer
1818

19-
wbuf []byte
19+
wbuf []byte
20+
encodeError error
2021

2122
// Frontend message flyweights
2223
bind Bind
@@ -55,18 +56,34 @@ func NewBackend(r io.Reader, w io.Writer) *Backend {
5556
return &Backend{cr: cr, w: w}
5657
}
5758

58-
// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
59-
// called.
59+
// Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error
60+
// encountered will be returned from Flush.
6061
func (b *Backend) Send(msg BackendMessage) {
62+
if b.encodeError != nil {
63+
return
64+
}
65+
6166
prevLen := len(b.wbuf)
62-
b.wbuf = msg.Encode(b.wbuf)
67+
newBuf, err := msg.Encode(b.wbuf)
68+
if err != nil {
69+
b.encodeError = err
70+
return
71+
}
72+
b.wbuf = newBuf
73+
6374
if b.tracer != nil {
6475
b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
6576
}
6677
}
6778

6879
// Flush writes any pending messages to the frontend (i.e. the client).
6980
func (b *Backend) Flush() error {
81+
if err := b.encodeError; err != nil {
82+
b.encodeError = nil
83+
b.wbuf = b.wbuf[:0]
84+
return &writeError{err: err, safeToRetry: true}
85+
}
86+
7087
n, err := b.w.Write(b.wbuf)
7188

7289
const maxLen = 1024

pgproto3/backend_key_data.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error {
2929
}
3030

3131
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
32-
func (src *BackendKeyData) Encode(dst []byte) []byte {
33-
dst = append(dst, 'K')
34-
dst = pgio.AppendUint32(dst, 12)
32+
func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
33+
dst, sp := beginMessage(dst, 'K')
3534
dst = pgio.AppendUint32(dst, src.ProcessID)
3635
dst = pgio.AppendUint32(dst, src.SecretKey)
37-
return dst
36+
return finishMessage(dst, sp)
3837
}
3938

4039
// MarshalJSON implements encoding/json.Marshaler.

pgproto3/backend_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ func TestStartupMessage(t *testing.T) {
7171
"username": "tester",
7272
},
7373
}
74-
dst := []byte{}
75-
dst = want.Encode(dst)
74+
dst, err := want.Encode([]byte{})
75+
require.NoError(t, err)
7676

7777
server := &interruptReader{}
7878
server.push(dst)

pgproto3/bind.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,8 @@ func (dst *Bind) Decode(src []byte) error {
108108
}
109109

110110
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
111-
func (src *Bind) Encode(dst []byte) []byte {
112-
dst = append(dst, 'B')
113-
sp := len(dst)
114-
dst = pgio.AppendInt32(dst, -1)
111+
func (src *Bind) Encode(dst []byte) ([]byte, error) {
112+
dst, sp := beginMessage(dst, 'B')
115113

116114
dst = append(dst, src.DestinationPortal...)
117115
dst = append(dst, 0)
@@ -139,9 +137,7 @@ func (src *Bind) Encode(dst []byte) []byte {
139137
dst = pgio.AppendInt16(dst, fc)
140138
}
141139

142-
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
143-
144-
return dst
140+
return finishMessage(dst, sp)
145141
}
146142

147143
// MarshalJSON implements encoding/json.Marshaler.

pgproto3/bind_complete.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error {
2020
}
2121

2222
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
23-
func (src *BindComplete) Encode(dst []byte) []byte {
24-
return append(dst, '2', 0, 0, 0, 4)
23+
func (src *BindComplete) Encode(dst []byte) ([]byte, error) {
24+
return append(dst, '2', 0, 0, 0, 4), nil
2525
}
2626

2727
// MarshalJSON implements encoding/json.Marshaler.

pgproto3/bind_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package pgproto3_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/jackc/pgx/v5/pgproto3"
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestBindBiggerThanMaxMessageBodyLen(t *testing.T) {
11+
t.Parallel()
12+
13+
// Maximum allowed size.
14+
_, err := (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-16)}}).Encode(nil)
15+
require.NoError(t, err)
16+
17+
// 1 byte too big
18+
_, err = (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-15)}}).Encode(nil)
19+
require.Error(t, err)
20+
}

pgproto3/cancel_request.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error {
3636
}
3737

3838
// Encode encodes src into dst. dst will include the 4 byte message length.
39-
func (src *CancelRequest) Encode(dst []byte) []byte {
39+
func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
4040
dst = pgio.AppendInt32(dst, 16)
4141
dst = pgio.AppendInt32(dst, cancelRequestCode)
4242
dst = pgio.AppendUint32(dst, src.ProcessID)
4343
dst = pgio.AppendUint32(dst, src.SecretKey)
44-
return dst
44+
return dst, nil
4545
}
4646

4747
// MarshalJSON implements encoding/json.Marshaler.

0 commit comments

Comments
 (0)