-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
grpc: fix receiving empty messages when compression is enabled and maxReceiveMessageSize is MaxInt #7753 #7918
base: master
Are you sure you want to change the base?
Changes from 1 commit
b08845a
a02e58c
54da70d
4efd6fe
851d13a
26bf731
3cf9054
27a68a0
3daef9b
8815cbd
5541486
a21e192
09132bf
5181e7b
46faf72
bd4989a
985965a
8a0db4c
cf12ace
0e927fb
ee88fc0
4ef6aab
7ef57fd
20d4dc6
0907f9d
84e8a4f
77ea230
6ff9321
82124a4
4e9c665
da82d30
33962c2
23baa53
7664e75
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -809,8 +809,8 @@ func (p *payloadInfo) free() { | |
} | ||
} | ||
|
||
// Global error for message size exceeding the limit | ||
var ErrMaxMessageSizeExceeded = errors.New("max message size exceeded") | ||
// errMaxMessageSizeExceeded represents an error due to exceeding the maximum message size limit. | ||
var errMaxMessageSizeExceeded = errors.New("max message size exceeded") | ||
|
||
// recvAndDecompress reads a message from the stream, decompressing it if necessary. | ||
// | ||
|
@@ -844,13 +844,13 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM | |
out = mem.BufferSlice{mem.SliceBuffer(uncompressedBuf)} | ||
} | ||
} else { | ||
out, _, err = decompress(compressor, compressed, maxReceiveMessageSize, p.bufferPool) | ||
out, err = decompress(compressor, compressed, maxReceiveMessageSize, p.bufferPool) | ||
} | ||
if err != nil { | ||
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err) | ||
} | ||
|
||
if err == ErrMaxMessageSizeExceeded { | ||
if err == errMaxMessageSizeExceeded { | ||
arjan-bal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
out.Free() | ||
// TODO: Revisit the error code. Currently keep it consistent with java | ||
// implementation. | ||
|
@@ -870,41 +870,43 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM | |
} | ||
|
||
// Using compressor, decompress d, returning data and size. | ||
// If the decompressed data exceeds maxReceiveMessageSize, it returns nil, 0, and an error. | ||
func decompress(compressor encoding.Compressor, d mem.BufferSlice, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, int, error) { | ||
// If the decompressed data exceeds maxReceiveMessageSize, it returns errMaxMessageSizeExceeded. | ||
func decompress(compressor encoding.Compressor, d mem.BufferSlice, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, error) { | ||
dcReader, err := compressor.Decompress(d.Reader()) | ||
if err != nil { | ||
return nil, 0, err | ||
return nil, err | ||
} | ||
|
||
out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)), pool) | ||
if err != nil { | ||
out.Free() | ||
return nil, 0, ErrMaxMessageSizeExceeded | ||
return nil, err | ||
} | ||
if err = checkReceiveMessageOverflow(int64(out.Len()), int64(maxReceiveMessageSize), dcReader); err != nil { | ||
return nil, out.Len() + 1, err | ||
|
||
if doesReceiveMessageOverflow(int64(out.Len()), int64(maxReceiveMessageSize), dcReader) { | ||
arjan-bal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return nil, errMaxMessageSizeExceeded | ||
dfawley marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this just return the (We'd want to make sure this method always returns an error that is appropriate to just return directly from |
||
} | ||
return out, out.Len(), nil | ||
return out, nil | ||
} | ||
|
||
// checkReceiveMessageOverflow checks if the number of bytes read from the stream exceeds | ||
// the maximum receive message size allowed by the client. If the `readBytes` equals | ||
// `maxReceiveMessageSize`, the function attempts to read one more byte from the `dcReader` | ||
// to detect if there's an overflow. | ||
// doesReceiveMessageOverflow checks if the number of bytes read from the stream | ||
// exceeds the maximum receive message size allowed by the client. If the `readBytes` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit:
|
||
// is greater than or equal to `maxReceiveMessageSize`, the function attempts to read | ||
// one more byte from the `dcReader` to detect if there's an overflow. | ||
// | ||
// If additional data is read, or an error other than `io.EOF` is encountered, the function | ||
// returns an error indicating that the message size has exceeded the permissible limit. | ||
func checkReceiveMessageOverflow(readBytes, maxReceiveMessageSize int64, dcReader io.Reader) error { | ||
// returns `true` to indicate that the message size has exceeded the permissible limit. | ||
// Otherwise, it returns `false` indicating no overflow. | ||
func doesReceiveMessageOverflow(readBytes, maxReceiveMessageSize int64, dcReader io.Reader) bool { | ||
if readBytes < maxReceiveMessageSize { | ||
return nil | ||
return false | ||
} | ||
|
||
b := make([]byte, 1) | ||
if n, err := dcReader.Read(b); n > 0 || err != io.EOF { | ||
return fmt.Errorf("overflow: received message size is larger than the allowed maxReceiveMessageSize (%d bytes)", maxReceiveMessageSize) | ||
return true | ||
} | ||
return nil | ||
return false | ||
} | ||
|
||
type recvCompressor interface { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ package grpc | |
import ( | ||
"bytes" | ||
"compress/gzip" | ||
"errors" | ||
"io" | ||
"math" | ||
"reflect" | ||
|
@@ -326,52 +327,53 @@ func TestDecompress(t *testing.T) { | |
input []byte | ||
maxReceiveMessageSize int | ||
want []byte | ||
error error | ||
wantErr error | ||
}{ | ||
{ | ||
name: "Decompresses successfully with sufficient buffer size", | ||
compressor: c, | ||
input: []byte("decompressed data"), | ||
arjan-bal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
maxReceiveMessageSize: 50, | ||
want: []byte("decompressed data"), | ||
error: nil, | ||
wantErr: nil, | ||
}, | ||
{ | ||
name: "failure, empty receive message", | ||
name: "Fails due to exceeding maxReceiveMessageSize", | ||
compressor: c, | ||
input: []byte{}, | ||
maxReceiveMessageSize: 10, | ||
input: []byte("small message that is too large"), | ||
arjan-bal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
maxReceiveMessageSize: 5, | ||
arjan-bal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
want: nil, | ||
error: nil, | ||
wantErr: errMaxMessageSizeExceeded, | ||
}, | ||
{ | ||
name: "overflow failure, receive message exceeds maxReceiveMessageSize", | ||
name: "Decompresses to exactly maxReceiveMessageSize", | ||
compressor: c, | ||
input: []byte("small message"), | ||
maxReceiveMessageSize: 5, | ||
want: nil, | ||
error: ErrMaxMessageSizeExceeded, | ||
input: []byte("exact size message"), | ||
maxReceiveMessageSize: len("exact size message"), | ||
want: []byte("exact size message"), | ||
wantErr: nil, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
arjan-bal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
t.Run(tt.name, func(t *testing.T) { | ||
compressedMsg := compressInput(tt.input) | ||
arjan-bal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
output, numSliceInBuf, err := decompress(tt.compressor, compressedMsg, tt.maxReceiveMessageSize, mem.DefaultBufferPool()) | ||
var wantMsg mem.BufferSlice | ||
if tt.want != nil { | ||
wantMsg = mem.BufferSlice{mem.NewBuffer(&tt.want, nil)} | ||
} | ||
if tt.error != nil && err == nil { | ||
t.Fatalf("decompress() error, got err=%v, want err=%v", err, tt.error) | ||
output, err := decompress(tt.compressor, compressedMsg, tt.maxReceiveMessageSize, mem.DefaultBufferPool()) | ||
|
||
if tt.wantErr != nil { | ||
if !errors.Is(err, tt.wantErr) { | ||
arjan-bal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
t.Fatalf("decompress() error = %v, wantErr = %v", err, tt.wantErr) | ||
} | ||
return | ||
} | ||
if tt.error == nil && numSliceInBuf != wantMsg.Len() { | ||
t.Fatalf("decompress() number of slices mismatch, got = %d, want = %d", numSliceInBuf, wantMsg.Len()) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nix new line |
||
if err != nil { | ||
t.Fatalf("decompress() unexpected error = %v", err) | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nix new line |
||
if diff := cmp.Diff(tt.want, output.Materialize()); diff != "" { | ||
t.Fatalf("decompress() mismatch (-want +got):\n%s", diff) | ||
} | ||
|
||
}) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment is unnecessary.