Skip to content

Commit 1875c91

Browse files
authored
Merge pull request aws#102 from mattsb42-aws/dev-24
Fix handling of partial reads
2 parents 7013f40 + 74355b0 commit 1875c91

12 files changed

+262
-42
lines changed

CHANGELOG.rst

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@ Changelog
88
Minor
99
-----
1010

11-
* Add support to remove clients from :ref:`KMSMasterKeyProvider` client cache if they fail to connect to endpoint.
11+
* Add support to remove clients from :class:`KMSMasterKeyProvider` client cache if they fail to connect to endpoint.
1212
`#86 <https://github.com/aws/aws-encryption-sdk-python/pull/86>`_
1313
* Add support for SHA384 and SHA512 for use with RSA OAEP wrapping algorithms.
1414
`#56 <https://github.com/aws/aws-encryption-sdk-python/issues/56>`_
15+
* Fix ``streaming_client`` classes to properly interpret short reads in source streams.
16+
`#24 <https://github.com/aws/aws-encryption-sdk-python/issues/24>`_
1517

1618
1.3.7 -- 2018-09-20
1719
===================

setup.cfg

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ branch = True
1111
show_missing = True
1212

1313
[tool:pytest]
14+
log_level = DEBUG
1415
markers =
1516
local: superset of unit and functional (does not require network access)
1617
unit: mark test as a unit test (does not require network access)

src/aws_encryption_sdk/internal/utils/__init__.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from aws_encryption_sdk.internal.str_ops import to_bytes
2424
from aws_encryption_sdk.structures import EncryptedDataKey
2525

26+
from .streams import InsistentReaderBytesIO
27+
2628
_LOGGER = logging.getLogger(__name__)
2729

2830

@@ -132,12 +134,14 @@ def prep_stream_data(data):
132134
133135
:param data: Input data
134136
:returns: Prepared stream
135-
:rtype: io.BytesIO
137+
:rtype: InsistentReaderBytesIO
136138
"""
137139
if isinstance(data, (six.string_types, six.binary_type)):
138-
return io.BytesIO(to_bytes(data))
140+
stream = io.BytesIO(to_bytes(data))
141+
else:
142+
stream = data
139143

140-
return data
144+
return InsistentReaderBytesIO(stream)
141145

142146

143147
def source_data_key_length_check(source_data_key, algorithm):

src/aws_encryption_sdk/internal/utils/streams.py

+41
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Helper stream utility objects for AWS Encryption SDK."""
14+
import io
15+
1416
from wrapt import ObjectProxy
1517

1618
from aws_encryption_sdk.exceptions import ActionNotAllowedError
19+
from aws_encryption_sdk.internal.str_ops import to_bytes
1720

1821

1922
class ROStream(ObjectProxy):
@@ -56,3 +59,41 @@ def read(self, b=None):
5659
data = self.__wrapped__.read(b)
5760
self.__tee.write(data)
5861
return data
62+
63+
64+
class InsistentReaderBytesIO(ObjectProxy):
    """Wrapper around a readable stream that insists on reading exactly the requested
    number of bytes. It will keep trying to read bytes from the wrapped stream until
    either the requested number of bytes are available or the wrapped stream has
    nothing more to return.

    A negative size or ``None`` requests a full read: the wrapped stream is read
    repeatedly until it returns nothing more.

    :param wrapped: File-like object
    """

    def read(self, b=-1):
        """Keep reading from source stream until either the source stream is done
        or the requested number of bytes have been obtained.

        :param int b: number of bytes to read; a negative value or ``None`` reads
            until the wrapped stream is exhausted
        :return: All bytes read from wrapped stream
        :rtype: bytes
        :raises ValueError: if the wrapped stream raises ``ValueError`` while it
            is not closed
        """
        # ``None`` is a valid "read everything" size for file-like objects; treating
        # it like a negative size avoids a TypeError in the arithmetic below.
        read_all = b is None or b < 0
        remaining_bytes = -1 if read_all else b

        chunks = []
        while True:
            try:
                chunk = to_bytes(self.__wrapped__.read(remaining_bytes))
            except ValueError:
                # A ValueError from a closed stream is treated as end of data;
                # any other ValueError is a real problem and is propagated.
                if self.__wrapped__.closed:
                    break
                raise

            if not chunk:
                # The wrapped stream has nothing more to return.
                break

            chunks.append(chunk)

            if read_all:
                # A full read may also come back short, so only an empty chunk
                # (or a closed stream) ends it.
                continue

            remaining_bytes -= len(chunk)
            if remaining_bytes <= 0:
                break

        return b"".join(chunks)

src/aws_encryption_sdk/streaming_client.py

+33-14
Original file line numberDiff line numberDiff line change
@@ -202,24 +202,28 @@ def readable(self):
202202
# Open streams are currently always readable.
203203
return not self.closed
204204

205-
def read(self, b=None):
205+
def read(self, b=-1):
206206
"""Returns either the requested number of bytes or the entire stream.
207207
208208
:param int b: Number of bytes to read
209209
:returns: Processed (encrypted or decrypted) bytes from source stream
210210
:rtype: bytes
211211
"""
212212
# Any negative value for b is interpreted as a full read
213-
if b is not None and b < 0:
214-
b = None
213+
# None is also accepted for legacy compatibility
214+
if b is None or b < 0:
215+
b = -1
215216

216217
_LOGGER.debug("Stream read called, requesting %s bytes", b)
217218
output = io.BytesIO()
219+
218220
if not self._message_prepped:
219221
self._prep_message()
222+
220223
if self.closed:
221224
raise ValueError("I/O operation on closed file")
222-
if b:
225+
226+
if b >= 0:
223227
self._read_bytes(b)
224228
output.write(self.output_buffer[:b])
225229
self.output_buffer = self.output_buffer[b:]
@@ -228,6 +232,7 @@ def read(self, b=None):
228232
self._read_bytes(LINE_LENGTH)
229233
output.write(self.output_buffer)
230234
self.output_buffer = b""
235+
231236
self.bytes_read += output.tell()
232237
_LOGGER.debug("Returning %s bytes of %s bytes requested", output.tell(), b)
233238
return output.getvalue()
@@ -511,14 +516,18 @@ def _read_bytes_to_non_framed_body(self, b):
511516
_LOGGER.debug("Closing encryptor after receiving only %s bytes of %s bytes requested", plaintext, b)
512517
self.source_stream.close()
513518
closing = self.encryptor.finalize()
519+
514520
if self.signer is not None:
515521
self.signer.update(closing)
522+
516523
closing += aws_encryption_sdk.internal.formatting.serialize.serialize_non_framed_close(
517524
tag=self.encryptor.tag, signer=self.signer
518525
)
526+
519527
if self.signer is not None:
520528
closing += aws_encryption_sdk.internal.formatting.serialize.serialize_footer(self.signer)
521529
return ciphertext + closing
530+
522531
return ciphertext
523532

524533
def _read_bytes_to_framed_body(self, b):
@@ -530,14 +539,22 @@ def _read_bytes_to_framed_body(self, b):
530539
"""
531540
_LOGGER.debug("collecting %s bytes", b)
532541
_b = b
533-
b = int(math.ceil(b / float(self.config.frame_length)) * self.config.frame_length)
534-
_LOGGER.debug("%s bytes requested; reading %s bytes after normalizing to frame length", _b, b)
542+
543+
if b > 0:
544+
_frames_to_read = math.ceil(b / float(self.config.frame_length))
545+
b = int(_frames_to_read * self.config.frame_length)
546+
_LOGGER.debug("%d bytes requested; reading %d bytes after normalizing to frame length", _b, b)
547+
535548
plaintext = self.source_stream.read(b)
536-
_LOGGER.debug("%s bytes read from source", len(plaintext))
549+
plaintext_length = len(plaintext)
550+
_LOGGER.debug("%d bytes read from source", plaintext_length)
551+
537552
finalize = False
538-
if len(plaintext) < b:
553+
554+
if b < 0 or plaintext_length < b:
539555
_LOGGER.debug("Final plaintext read from source")
540556
finalize = True
557+
541558
output = b""
542559
final_frame_written = False
543560

@@ -583,8 +600,8 @@ def _read_bytes(self, b):
583600
:param int b: Number of bytes to read
584601
:raises NotSupportedError: if content type is not supported
585602
"""
586-
_LOGGER.debug("%s bytes requested from stream with content type: %s", b, self.content_type)
587-
if b <= len(self.output_buffer) or self.source_stream.closed:
603+
_LOGGER.debug("%d bytes requested from stream with content type: %s", b, self.content_type)
604+
if 0 <= b <= len(self.output_buffer) or self.source_stream.closed:
588605
_LOGGER.debug("No need to read from source stream or source stream closed")
589606
return
590607

@@ -776,10 +793,13 @@ def _read_bytes_from_non_framed_body(self, b):
776793
bytes_to_read = self.body_end - self.source_stream.tell()
777794
_LOGGER.debug("%s bytes requested; reading %s bytes", b, bytes_to_read)
778795
ciphertext = self.source_stream.read(bytes_to_read)
796+
779797
if len(self.output_buffer) + len(ciphertext) < self.body_length:
780798
raise SerializationError("Total message body contents less than specified in body description")
799+
781800
if self.verifier is not None:
782801
self.verifier.update(ciphertext)
802+
783803
plaintext = self.decryptor.update(ciphertext)
784804
plaintext += self.decryptor.finalize()
785805
aws_encryption_sdk.internal.formatting.deserialize.update_verifier_with_tag(
@@ -844,10 +864,9 @@ def _read_bytes(self, b):
844864
_LOGGER.debug("Source stream closed")
845865
return
846866

847-
if b <= len(self.output_buffer):
848-
_LOGGER.debug(
849-
"%s bytes requested less than or equal to current output buffer size %s", b, len(self.output_buffer)
850-
)
867+
buffer_length = len(self.output_buffer)
868+
if 0 <= b <= buffer_length:
869+
_LOGGER.debug("%d bytes requested less than or equal to current output buffer size %d", b, buffer_length)
851870
return
852871

853872
if self._header.content_type == ContentType.FRAMED_DATA:

test/functional/test_f_aws_encryption_sdk_client.py

+71
Original file line numberDiff line numberDiff line change
@@ -598,3 +598,74 @@ def test_stream_decryptor_readable():
598598
assert handler.readable()
599599
handler.read()
600600
assert not handler.readable()
601+
602+
603+
def exact_length_plaintext(length):
    """Build a plaintext of exactly ``length`` bytes by repeating ``VALUES["plaintext_128"]``.

    :param int length: Desired plaintext length in bytes
    :returns: The repeated seed value, truncated to ``length`` bytes
    :rtype: bytes
    """
    seed = VALUES["plaintext_128"]
    pieces = []
    collected = 0
    # Append whole copies of the seed until at least ``length`` bytes are collected.
    while collected < length:
        pieces.append(seed)
        collected += len(seed)
    return b"".join(pieces)[:length]
608+
609+
610+
class SometimesIncompleteReaderIO(io.BytesIO):
    """In-memory byte stream that deliberately returns short reads.

    Every second call to :meth:`read` that asks for more than one byte is served
    with only half of the requested size, simulating a source stream that does
    not always return as many bytes as were requested.
    """

    def __init__(self, *args, **kwargs):
        # Number of ``read`` calls made so far; every even-numbered call is shortened.
        self.__read_counter = 0
        super(SometimesIncompleteReaderIO, self).__init__(*args, **kwargs)

    def read(self, size=-1):
        """Every other read request, return fewer than the requested number of bytes if more than one byte requested.

        :param int size: Number of bytes requested; ``None`` or a negative value
            requests the rest of the stream and is never shortened
        :returns: Bytes read from the underlying buffer
        :rtype: bytes
        """
        self.__read_counter += 1
        # ``None`` is a valid size for ``io.BytesIO.read`` but cannot be compared
        # to an int on Python 3, so it is excluded before the size check.
        if size is not None and size > 1 and self.__read_counter % 2 == 0:
            size //= 2
        return super(SometimesIncompleteReaderIO, self).read(size)
621+
622+
623+
def _read_all_in_chunks(stream, chunk_size, max_cycles, operation):
    """Read ``stream`` to exhaustion using fixed-size reads, guarding against infinite loops.

    :param stream: Object exposing ``read(int)``
    :param int chunk_size: Number of bytes to request on each read
    :param int max_cycles: Number of non-empty reads tolerated before giving up
    :param str operation: Verb used in the error message (for example ``"encrypting"``)
    :returns: Concatenation of every chunk read
    :rtype: bytes
    :raises AWSEncryptionSDKClientError: if more than ``max_cycles`` non-empty chunks are read
    """
    chunks = []
    # One extra iteration so the guard trips on the first non-empty read past ``max_cycles``.
    for _ in range(max_cycles + 1):
        chunk = stream.read(chunk_size)
        if not chunk:
            return b"".join(chunks)
        chunks.append(chunk)
    raise aws_encryption_sdk.exceptions.AWSEncryptionSDKClientError(
        "Unexpected error {} message: infinite loop detected.".format(operation)
    )


@pytest.mark.parametrize(
    "frame_length",
    (
        0,  # 0: unframed
        128,  # 128: framed with exact final frame size match
        256,  # 256: framed with inexact final frame size match
    ),
)
def test_incomplete_read_stream_cycle(frame_length):
    """Round-trip a message through streams whose sources return short reads.

    The plaintext is encrypted and then decrypted in small chunks, with both
    source streams wrapped in ``SometimesIncompleteReaderIO``; the decrypted
    output must equal the original plaintext.
    """
    chunk_size = 21  # Will never be an exact match for the frame size
    max_cycles = len(VALUES["plaintext_128"])
    key_provider = fake_kms_key_provider()

    plaintext = exact_length_plaintext(384)

    with aws_encryption_sdk.stream(
        mode="encrypt",
        source=SometimesIncompleteReaderIO(plaintext),
        key_provider=key_provider,
        frame_length=frame_length,
    ) as encryptor:
        ciphertext = _read_all_in_chunks(encryptor, chunk_size, max_cycles, "encrypting")

    with aws_encryption_sdk.stream(
        mode="decrypt", source=SometimesIncompleteReaderIO(ciphertext), key_provider=key_provider
    ) as decryptor:
        decrypted = _read_all_in_chunks(decryptor, chunk_size, max_cycles, "decrypting")

    assert ciphertext != decrypted == plaintext

test/unit/test_streaming_client_encryption_stream.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import aws_encryption_sdk.exceptions
2222
from aws_encryption_sdk.internal.defaults import LINE_LENGTH
23+
from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO
2324
from aws_encryption_sdk.key_providers.base import MasterKeyProvider
2425
from aws_encryption_sdk.streaming_client import _ClientConfig, _EncryptionStream
2526

@@ -107,17 +108,19 @@ def test_new_with_params(self):
107108
line_length=io.DEFAULT_BUFFER_SIZE,
108109
source_length=mock_int_sentinel,
109110
)
110-
assert mock_stream.config == MockClientConfig(
111-
source=self.mock_source_stream,
112-
key_provider=self.mock_key_provider,
113-
mock_read_bytes=sentinel.read_bytes,
114-
line_length=io.DEFAULT_BUFFER_SIZE,
115-
source_length=mock_int_sentinel,
116-
)
111+
112+
assert mock_stream.config.source == self.mock_source_stream
113+
assert isinstance(mock_stream.config.source, InsistentReaderBytesIO)
114+
assert mock_stream.config.key_provider is self.mock_key_provider
115+
assert mock_stream.config.mock_read_bytes is sentinel.read_bytes
116+
assert mock_stream.config.line_length == io.DEFAULT_BUFFER_SIZE
117+
assert mock_stream.config.source_length is mock_int_sentinel
118+
117119
assert mock_stream.bytes_read == 0
118120
assert mock_stream.output_buffer == b""
119121
assert not mock_stream._message_prepped
120-
assert mock_stream.source_stream is self.mock_source_stream
122+
assert mock_stream.source_stream == self.mock_source_stream
123+
assert isinstance(mock_stream.source_stream, InsistentReaderBytesIO)
121124
assert mock_stream._stream_length is mock_int_sentinel
122125
assert mock_stream.line_length == io.DEFAULT_BUFFER_SIZE
123126

test/unit/test_streaming_client_stream_decryptor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def test_prep_non_framed(self):
261261
test_decryptor._prep_non_framed()
262262

263263
self.mock_deserialize_non_framed_values.assert_called_once_with(
264-
stream=self.mock_input_stream, header=self.mock_header, verifier=sentinel.verifier
264+
stream=test_decryptor.source_stream, header=self.mock_header, verifier=sentinel.verifier
265265
)
266266
assert test_decryptor.body_length == len(VALUES["data_128"])
267267
self.mock_get_aad_content_string.assert_called_once_with(

0 commit comments

Comments
 (0)