diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 231dec02..cb0d124a 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -1,5 +1,8 @@ name: Test on: [push, pull_request] +concurrency: # https://stackoverflow.com/questions/66335225#comment133398800_72408109 + group: ${{ github.workflow }}-${{ github.ref || github.run_id }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} jobs: linters: runs-on: ubuntu-latest @@ -12,10 +15,10 @@ jobs: python-version: "3.11" - name: Update pip - run: python -m pip install -U pip + run: pip install -U pip - name: Install dependencies - run: python -m pip install flake8 + run: pip install flake8 - name: Run flake8 linter (source) run: flake8 --show-source smart_open @@ -26,15 +29,17 @@ jobs: strategy: matrix: include: - - {python: '3.8', os: ubuntu-20.04} - - {python: '3.9', os: ubuntu-20.04} - - {python: '3.10', os: ubuntu-20.04} - - {python: '3.11', os: ubuntu-20.04} - - - {python: '3.8', os: windows-2019} - - {python: '3.9', os: windows-2019} - - {python: '3.10', os: windows-2019} - - {python: '3.11', os: windows-2019} + - {python-version: '3.8', os: ubuntu-20.04} + - {python-version: '3.9', os: ubuntu-20.04} + - {python-version: '3.10', os: ubuntu-20.04} + - {python-version: '3.11', os: ubuntu-20.04} + - {python-version: '3.12', os: ubuntu-20.04} + + - {python-version: '3.8', os: windows-2019} + - {python-version: '3.9', os: windows-2019} + - {python-version: '3.10', os: windows-2019} + - {python-version: '3.11', os: windows-2019} + - {python-version: '3.12', os: windows-2019} steps: - uses: actions/checkout@v2 @@ -43,13 +48,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Update pip - run: python -m pip install -U pip - - # - # https://askubuntu.com/questions/1428181/module-lib-has-no-attribute-x509-v-flag-cb-issuer-check - # - - name: Upgrade PyOpenSSL - run: python -m pip install pyOpenSSL --upgrade + run: pip 
install -U pip - name: Install smart_open without dependencies run: pip install -e . @@ -69,19 +68,20 @@ jobs: strategy: matrix: include: - - {python: '3.8', os: ubuntu-20.04} - - {python: '3.9', os: ubuntu-20.04} - - {python: '3.10', os: ubuntu-20.04} - - {python: '3.11', os: ubuntu-20.04} + - {python-version: '3.8', os: ubuntu-20.04} + - {python-version: '3.9', os: ubuntu-20.04} + - {python-version: '3.10', os: ubuntu-20.04} + - {python-version: '3.11', os: ubuntu-20.04} + - {python-version: '3.12', os: ubuntu-20.04} # # Some of the doctests don't pass on Windows because of Windows-specific # character encoding issues. # - # - {python: '3.7', os: windows-2019} - # - {python: '3.8', os: windows-2019} - # - {python: '3.9', os: windows-2019} - # - {python: '3.10', os: windows-2019} + # - {python-version: '3.7', os: windows-2019} + # - {python-version: '3.8', os: windows-2019} + # - {python-version: '3.9', os: windows-2019} + # - {python-version: '3.10', os: windows-2019} steps: - uses: actions/checkout@v2 @@ -91,10 +91,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Update pip - run: python -m pip install -U pip - - - name: Upgrade PyOpenSSL - run: python -m pip install pyOpenSSL --upgrade + run: pip install -U pip - name: Install smart_open and its dependencies run: pip install -e .[test] @@ -111,17 +108,18 @@ jobs: strategy: matrix: include: - - {python: '3.8', os: ubuntu-20.04} - - {python: '3.9', os: ubuntu-20.04} - - {python: '3.10', os: ubuntu-20.04} - - {python: '3.11', os: ubuntu-20.04} + - {python-version: '3.8', os: ubuntu-20.04} + - {python-version: '3.9', os: ubuntu-20.04} + - {python-version: '3.10', os: ubuntu-20.04} + - {python-version: '3.11', os: ubuntu-20.04} + - {python-version: '3.12', os: ubuntu-20.04} # Not sure why we exclude these, perhaps for historical reasons? 
# - # - {python: '3.7', os: windows-2019} - # - {python: '3.8', os: windows-2019} - # - {python: '3.9', os: windows-2019} - # - {python: '3.10', os: windows-2019} + # - {python-version: '3.7', os: windows-2019} + # - {python-version: '3.8', os: windows-2019} + # - {python-version: '3.9', os: windows-2019} + # - {python-version: '3.10', os: windows-2019} steps: - uses: actions/checkout@v2 @@ -131,20 +129,17 @@ jobs: python-version: ${{ matrix.python-version }} - name: Update pip - run: python -m pip install -U pip - - - name: Upgrade PyOpenSSL - run: python -m pip install pyOpenSSL --upgrade - - - run: python -m pip install numpy + run: pip install -U pip - name: Install smart_open and its dependencies run: pip install -e .[test] - run: bash ci_helpers/helpers.sh enable_moto_server if: ${{ matrix.moto_server }} - - - run: | + + - name: Start vsftpd + timeout-minutes: 2 + run: | sudo apt-get install vsftpd sudo bash ci_helpers/helpers.sh create_ftp_ftps_servers @@ -156,7 +151,7 @@ jobs: - run: bash ci_helpers/helpers.sh disable_moto_server if: ${{ matrix.moto_server }} - + - run: sudo bash ci_helpers/helpers.sh delete_ftp_ftps_servers benchmarks: @@ -165,15 +160,16 @@ jobs: strategy: matrix: include: - - {python: '3.8', os: ubuntu-20.04} - - {python: '3.9', os: ubuntu-20.04} - - {python: '3.10', os: ubuntu-20.04} - - {python: '3.11', os: ubuntu-20.04} + - {python-version: '3.8', os: ubuntu-20.04} + - {python-version: '3.9', os: ubuntu-20.04} + - {python-version: '3.10', os: ubuntu-20.04} + - {python-version: '3.11', os: ubuntu-20.04} + - {python-version: '3.12', os: ubuntu-20.04} - # - {python: '3.7', os: windows-2019} - # - {python: '3.8', os: windows-2019} - # - {python: '3.9', os: windows-2019} - # - {python: '3.10', os: windows-2019} + # - {python-version: '3.7', os: windows-2019} + # - {python-version: '3.8', os: windows-2019} + # - {python-version: '3.9', os: windows-2019} + # - {python-version: '3.10', os: windows-2019} steps: - uses: actions/checkout@v2 @@ 
-183,16 +179,11 @@ jobs: python-version: ${{ matrix.python-version }} - name: Update pip - run: python -m pip install -U pip - - - name: Upgrade PyOpenSSL - run: python -m pip install pyOpenSSL --upgrade + run: pip install -U pip - name: Install smart_open and its dependencies run: pip install -e .[test] - - run: pip install awscli pytest_benchmark - - name: Run benchmarks run: python ci_helpers/run_benchmarks.py env: diff --git a/CHANGELOG.md b/CHANGELOG.md index 85fc7d31..53bef7c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,13 @@ +# 7.0.5, 2024-10-04 + +- Fix zstd compression in ab mode (PR [#833](https://github.com/piskvorky/smart_open/pull/833), [@ddelange](https://github.com/ddelange)) +- Fix close function not being able to upload a compressed S3 (PR [#838](https://github.com/piskvorky/smart_open/pull/838), [@jbarragan-bridge](https://github.com/jbarragan-bridge)) +- Fix test_http.request_callback (PR [#828](https://github.com/piskvorky/smart_open/pull/828), [@ddelange](https://github.com/ddelange)) +- Update readline logic for azure to match s3 (PR [#826](https://github.com/piskvorky/smart_open/pull/826), [@quantumfusion](https://github.com/quantumfusion)) +- Make http handler take an optional requests.Session (PR [#825](https://github.com/piskvorky/smart_open/pull/825), [@arondaniel](https://github.com/arondaniel)) +- Ensure no side effects on SinglepartWriter exception (PR [#820](https://github.com/piskvorky/smart_open/pull/820), [@donsokolone](https://github.com/donsokolone)) +- Add support for `get_blob_kwargs` to GCS blob read operations (PR [#817](https://github.com/piskvorky/smart_open/pull/817), [@thejcannon](https://github.com/thejcannon)) + # 7.0.4, 2024-03-26 * Fix wb mode with zstd compression (PR [#815](https://github.com/piskvorky/smart_open/pull/815), [@djudd](https://github.com/djudd)) @@ -482,3 +492,4 @@ The old `smart_open.smart_open` function is deprecated, but continues to work as - support for multistream bzip files (PR #9, 
@pombredanne) - introduce this CHANGELOG + diff --git a/README.rst b/README.rst index 2b1775ce..c7060131 100644 --- a/README.rst +++ b/README.rst @@ -411,6 +411,8 @@ GCS Advanced Usage Additional keyword arguments can be propagated to the GCS open method (`docs `__), which is used by ``smart_open`` under the hood, using the ``blob_open_kwargs`` transport parameter. +Additional keyword arguments can also be propagated to the GCS ``get_blob`` method (`docs `__) when opening in read mode, using the ``get_blob_kwargs`` transport parameter. + Additional blob properties (`docs `__) can be set before an upload, as long as they are not read-only, using the ``blob_properties`` transport parameter. .. code-block:: python @@ -507,4 +509,3 @@ issues or pull requests there. Suggestions, pull requests and improvements welco ``smart_open`` is open source software released under the `MIT license `_. Copyright (c) 2015-now `Radim Řehůřek `_. - diff --git a/ci_helpers/helpers.sh b/ci_helpers/helpers.sh index 4533b337..e0accc7a 100644 --- a/ci_helpers/helpers.sh +++ b/ci_helpers/helpers.sh @@ -20,6 +20,9 @@ create_ftp_ftps_servers(){ mkdir $home_dir useradd -p $(echo $pass | openssl passwd -1 -stdin) -d $home_dir $user chown $user:$user $home_dir + openssl req -x509 -nodes -new -sha256 -days 10240 -newkey rsa:2048 -keyout /etc/vsftpd.key -out /etc/vsftpd.pem -subj "/C=ZA/CN=localhost" + chmod 755 /etc/vsftpd.key + chmod 755 /etc/vsftpd.pem server_setup=''' listen=YES @@ -32,6 +35,8 @@ chroot_local_user=YES allow_writeable_chroot=YES''' additional_ssl_setup=''' +rsa_cert_file=/etc/vsftpd.pem +rsa_private_key_file=/etc/vsftpd.key ssl_enable=YES allow_anon_ssl=NO force_local_data_ssl=NO diff --git a/integration-tests/test_ftp.py b/integration-tests/test_ftp.py index 000faae7..94b4e037 100644 --- a/integration-tests/test_ftp.py +++ b/integration-tests/test_ftp.py @@ -1,6 +1,11 @@ from __future__ import unicode_literals import pytest from smart_open import open +import ssl +from functools import 
partial + +# localhost has self-signed cert, see ci_helpers/helpers.sh:create_ftp_ftps_servers +ssl.create_default_context = partial(ssl.create_default_context, cafile="/etc/vsftpd.pem") @pytest.fixture(params=[("ftp", 21), ("ftps", 90)]) @@ -81,4 +86,4 @@ def test_line_endings_binary(server_info): with open(f"{server_type}://user:123@localhost:{port_num}/file4", "rb") as f: for line in f: - assert B_CLRF in line \ No newline at end of file + assert B_CLRF in line diff --git a/setup.py b/setup.py index a9a4fc53..9e738bea 100644 --- a/setup.py +++ b/setup.py @@ -47,9 +47,12 @@ def read(fname): tests_require = all_deps + [ 'moto[server]', 'responses', - 'boto3', 'pytest', 'pytest-rerunfailures', + 'pytest_benchmark', + 'awscli', + 'pyopenssl', + 'numpy', ] setup( diff --git a/smart_open/azure.py b/smart_open/azure.py index 5cac221b..1c991f05 100644 --- a/smart_open/azure.py +++ b/smart_open/azure.py @@ -325,24 +325,22 @@ def readline(self, limit=-1): """Read up to and including the next newline. Returns the bytes read.""" if limit != -1: raise NotImplementedError('limits other than -1 not implemented yet') - the_line = io.BytesIO() + + # + # A single line may span multiple buffers. + # + line = io.BytesIO() while not (self._position == self._size and len(self._current_part) == 0): - # - # In the worst case, we're reading the unread part of self._current_part - # twice here, once in the if condition and once when calling index. - # - # This is sub-optimal, but better than the alternative: wrapping - # .index in a try..except, because that is slower. 
- # - remaining_buffer = self._current_part.peek() - if self._line_terminator in remaining_buffer: - next_newline = remaining_buffer.index(self._line_terminator) - the_line.write(self._read_from_buffer(next_newline + 1)) + line_part = self._current_part.readline(self._line_terminator) + line.write(line_part) + self._position += len(line_part) + + if line_part.endswith(self._line_terminator): break else: - the_line.write(self._read_from_buffer()) self._fill_buffer() - return the_line.getvalue() + + return line.getvalue() # # Internal methods. diff --git a/smart_open/ftp.py b/smart_open/ftp.py index 40d54ca3..a7212ecd 100644 --- a/smart_open/ftp.py +++ b/smart_open/ftp.py @@ -86,7 +86,7 @@ def convert_transport_params_to_args(transport_params): def _connect(hostname, username, port, password, secure_connection, transport_params): kwargs = convert_transport_params_to_args(transport_params) if secure_connection: - ssl_context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH) + ssl_context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) ftp = FTP_TLS(context=ssl_context, **kwargs) else: ftp = FTP(**kwargs) diff --git a/smart_open/gcs.py b/smart_open/gcs.py index c5ab29be..0ae34307 100644 --- a/smart_open/gcs.py +++ b/smart_open/gcs.py @@ -59,6 +59,7 @@ def open( buffer_size=None, min_part_size=_DEFAULT_MIN_PART_SIZE, client=None, # type: google.cloud.storage.Client + get_blob_kwargs=None, blob_properties=None, blob_open_kwargs=None, ): @@ -78,6 +79,9 @@ def open( The minimum part size for multipart uploads. For writing only. client: google.cloud.storage.Client, optional The GCS client to use when working with google-cloud-storage. + get_blob_kwargs: dict, optional + Additional keyword arguments to propagate to the bucket.get_blob + method of the google-cloud-storage library. For reading only. blob_properties: dict, optional Set properties on blob before writing. For writing only. 
blob_open_kwargs: dict, optional @@ -95,6 +99,7 @@ def open( _blob = Reader(bucket=bucket_id, key=blob_id, client=client, + get_blob_kwargs=get_blob_kwargs, blob_open_kwargs=blob_open_kwargs) elif mode in (constants.WRITE_BINARY, 'w', 'wt'): @@ -116,8 +121,11 @@ def Reader(bucket, buffer_size=None, line_terminator=None, client=None, + get_blob_kwargs=None, blob_open_kwargs=None): + if get_blob_kwargs is None: + get_blob_kwargs = {} if blob_open_kwargs is None: blob_open_kwargs = {} if client is None: @@ -128,7 +136,7 @@ def Reader(bucket, warn_deprecated('line_terminator') bkt = client.bucket(bucket) - blob = bkt.get_blob(key) + blob = bkt.get_blob(key, **get_blob_kwargs) if blob is None: raise google.cloud.exceptions.NotFound(f'blob {key} not found in {bucket}') diff --git a/smart_open/http.py b/smart_open/http.py index 7bbbe6f4..438ae0f4 100644 --- a/smart_open/http.py +++ b/smart_open/http.py @@ -50,7 +50,7 @@ def open_uri(uri, mode, transport_params): def open(uri, mode, kerberos=False, user=None, password=None, cert=None, - headers=None, timeout=None, buffer_size=DEFAULT_BUFFER_SIZE): + headers=None, timeout=None, session=None, buffer_size=DEFAULT_BUFFER_SIZE): """Implement streamed reader from a web site. Supports Kerberos and Basic HTTP authentication. @@ -73,6 +73,9 @@ def open(uri, mode, kerberos=False, user=None, password=None, cert=None, Any headers to send in the request. If ``None``, the default headers are sent: ``{'Accept-Encoding': 'identity'}``. To use no headers at all, set this variable to an empty dict, ``{}``. + session: object, optional + The requests Session object to use with http get requests. + Can be used for OAuth2 clients. buffer_size: int, optional The buffer size to use when performing I/O. 
@@ -86,7 +89,7 @@ def open(uri, mode, kerberos=False, user=None, password=None, cert=None, fobj = SeekableBufferedInputBase( uri, mode, buffer_size=buffer_size, kerberos=kerberos, user=user, password=password, cert=cert, - headers=headers, timeout=timeout, + headers=headers, session=session, timeout=timeout, ) fobj.name = os.path.basename(urllib.parse.urlparse(uri).path) return fobj @@ -97,7 +100,10 @@ def open(uri, mode, kerberos=False, user=None, password=None, cert=None, class BufferedInputBase(io.BufferedIOBase): def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE, kerberos=False, user=None, password=None, cert=None, - headers=None, timeout=None): + headers=None, session=None, timeout=None): + + self.session = session or requests + if kerberos: import requests_kerberos auth = requests_kerberos.HTTPKerberosAuth() @@ -116,7 +122,14 @@ def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE, self.timeout = timeout - self.response = requests.get( + self.response = session.get( + url, + auth=auth, + cert=cert, + stream=True, + headers=self.headers, + timeout=self.timeout, + ) if session is not None else requests.get( url, auth=auth, cert=cert, @@ -217,7 +230,7 @@ class SeekableBufferedInputBase(BufferedInputBase): def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE, kerberos=False, user=None, password=None, cert=None, - headers=None, timeout=None): + headers=None, session=None, timeout=None): """ If Kerberos is True, will attempt to use the local Kerberos credentials. 
If cert is set, will try to use a client certificate @@ -227,6 +240,8 @@ def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE, """ self.url = url + self.session = session or requests + if kerberos: import requests_kerberos self.auth = requests_kerberos.HTTPKerberosAuth() @@ -332,7 +347,7 @@ def _partial_request(self, start_pos=None): if start_pos is not None: self.headers.update({"range": smart_open.utils.make_range_string(start_pos)}) - response = requests.get( + response = self.session.get( self.url, auth=self.auth, stream=True, diff --git a/smart_open/s3.py b/smart_open/s3.py index 09433ad7..60ae2a99 100644 --- a/smart_open/s3.py +++ b/smart_open/s3.py @@ -1209,8 +1209,8 @@ def write(self, b): return length def terminate(self): - """Nothing to cancel in single-part uploads.""" - return + self._buf = None + logger.debug('%s: terminated singlepart upload', self) # # Internal methods. diff --git a/smart_open/tests/test_gcs.py b/smart_open/tests/test_gcs.py index d63b20cf..d8a4979d 100644 --- a/smart_open/tests/test_gcs.py +++ b/smart_open/tests/test_gcs.py @@ -45,6 +45,8 @@ def __init__(self, client, name=None): # self.client.register_bucket(self) + self.get_blob = mock.Mock(side_effect=self._get_blob) + def blob(self, blob_id, **kwargs): return self.blobs.get(blob_id, FakeBlob(blob_id, self, **kwargs)) @@ -57,7 +59,7 @@ def delete(self): def exists(self): return self._exists - def get_blob(self, blob_id): + def _get_blob(self, blob_id, **kwargs): try: return self.blobs[blob_id] except KeyError as e: @@ -300,6 +302,15 @@ def test_property_passthrough(self): for k, v in blob_properties.items(): self.assertEqual(getattr(b, k), v) + def test_get_blob_kwargs_passthrough(self): + get_blob_kwargs = {'generation': '1111111111111111'} + + with self.assertRaises(google.cloud.exceptions.NotFound): + smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME, get_blob_kwargs=get_blob_kwargs) + + self.client.bucket(BUCKET_NAME) \ + .get_blob.assert_called_once_with(BLOB_NAME, 
**get_blob_kwargs) + def test_default_open_kwargs(self): smart_open.gcs.Writer(BUCKET_NAME, BLOB_NAME) diff --git a/smart_open/tests/test_http.py b/smart_open/tests/test_http.py index f4f8338a..a70f86d7 100644 --- a/smart_open/tests/test_http.py +++ b/smart_open/tests/test_http.py @@ -15,30 +15,28 @@ import smart_open.http import smart_open.s3 import smart_open.constants - +import requests BYTES = b'i tried so hard and got so far but in the end it doesn\'t even matter' URL = 'http://localhost' HTTPS_URL = 'https://localhost' HEADERS = { - 'Content-Length': str(len(BYTES)), 'Accept-Ranges': 'bytes', } def request_callback(request, headers=HEADERS, data=BYTES): - try: - range_string = request.headers['range'] - except KeyError: - return (200, headers, data) + headers = headers.copy() + range_string = request.headers.get('range', 'bytes=0-') - start, end = range_string.replace('bytes=', '').split('-', 1) + start, end = range_string.replace('bytes=', '', 1).split('-', 1) start = int(start) - if end: - end = int(end) - else: - end = len(data) - return (200, headers, data[start:end]) + end = int(end) if end else len(data) + + data = data[start:end] + headers['Content-Length'] = str(len(data)) + + return (200, headers, data) @unittest.skipIf(os.environ.get('TRAVIS'), 'This test does not work on TravisCI for some reason') @@ -161,6 +159,15 @@ def test_timeout_attribute(self): assert hasattr(reader, 'timeout') assert reader.timeout == timeout + @responses.activate + def test_session_attribute(self): + session = requests.Session() + responses.add_callback(responses.GET, URL, callback=request_callback) + reader = smart_open.open(URL, "rb", transport_params={'session': session}) + assert hasattr(reader, 'session') + assert reader.session == session + assert reader.read() == BYTES + @responses.activate def test_seek_implicitly_enabled(numbytes=10): diff --git a/smart_open/tests/test_s3.py b/smart_open/tests/test_s3.py index ff44ad4f..78d32c0f 100644 --- 
a/smart_open/tests/test_s3.py +++ b/smart_open/tests/test_s3.py @@ -634,6 +634,43 @@ def test_writebuffer(self): assert actual == contents + def test_write_gz_using_context_manager(self): + """Does s3 multipart upload create a compressed file using context manager?""" + contents = b'get ready for a surprise' + with smart_open.open( + f's3://{BUCKET_NAME}/{WRITE_KEY_NAME}.gz', + mode="wb", + transport_params={ + "multipart_upload": True, + "min_part_size": 10, + } + ) as fout: + fout.write(contents) + + with smart_open.open(f's3://{BUCKET_NAME}/{WRITE_KEY_NAME}.gz', 'rb') as fin: + actual = fin.read() + + assert actual == contents + + def test_write_gz_not_using_context_manager(self): + """Does s3 multipart upload create a compressed file not using context manager but close()?""" + contents = b'get ready for a surprise' + fout = smart_open.open( + f's3://{BUCKET_NAME}/{WRITE_KEY_NAME}.gz', + mode="wb", + transport_params={ + "multipart_upload": True, + "min_part_size": 10, + } + ) + fout.write(contents) + fout.close() + + with smart_open.open(f's3://{BUCKET_NAME}/{WRITE_KEY_NAME}.gz', 'rb') as fin: + actual = fin.read() + + assert actual == contents + def test_write_gz_with_error(self): """Does s3 multipart upload abort for a failed compressed file upload?""" with self.assertRaises(ValueError): @@ -795,6 +832,26 @@ def test_str(self): with smart_open.s3.open(BUCKET_NAME, 'key', 'wb', multipart_upload=False) as fout: assert str(fout) == "smart_open.s3.SinglepartWriter('test-smartopen', 'key')" + def test_ensure_no_side_effects_on_exception(self): + class WriteError(Exception): + pass + + s3_resource = _resource("s3") + obj = s3_resource.Object(BUCKET_NAME, KEY_NAME) + + # wrap in closure to ease writer dereferencing + def _run(): + with smart_open.s3.open(BUCKET_NAME, obj.key, "wb", multipart_upload=False) as fout: + fout.write(b"this should not be written") + raise WriteError + + try: + _run() + except WriteError: + pass + finally: + 
self.assertRaises(s3_resource.meta.client.exceptions.NoSuchKey, obj.get) + ARBITRARY_CLIENT_ERROR = botocore.client.ClientError(error_response={}, operation_name='bar') diff --git a/smart_open/tests/test_smart_open.py b/smart_open/tests/test_smart_open.py index e3727271..789f44c6 100644 --- a/smart_open/tests/test_smart_open.py +++ b/smart_open/tests/test_smart_open.py @@ -607,6 +607,22 @@ def test_atplus(self): self.assertEqual(text, SAMPLE_TEXT * 2) +class CompressionRealFileSystemTests(RealFileSystemTests): + """Same as RealFileSystemTests but with a compressed file.""" + + def setUp(self): + with named_temporary_file(prefix='test', suffix='.zst', delete=False) as fout: + self.temp_file = fout.name + with smart_open.open(self.temp_file, 'wb') as fout: + fout.write(SAMPLE_BYTES) + + def test_aplus(self): + pass # transparent (de)compression unsupported for mode 'ab+' + + def test_atplus(self): + pass # transparent (de)compression unsupported for mode 'ab+' + + # # What exactly to patch here differs on _how_ we're opening the file. # See the _shortcut_open function for details. @@ -1441,7 +1457,8 @@ def gzip_compress(data, filename=None): buf = io.BytesIO() buf.name = filename with mock.patch('time.time', _MOCK_TIME): - gzip.GzipFile(fileobj=buf, mode='w').write(data) + with gzip.GzipFile(fileobj=buf, mode='w') as gz: + gz.write(data) return buf.getvalue() diff --git a/smart_open/utils.py b/smart_open/utils.py index 2be57d19..efbb9374 100644 --- a/smart_open/utils.py +++ b/smart_open/utils.py @@ -208,6 +208,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): class FileLikeProxy(wrapt.ObjectProxy): + __inner = ... 
# initialized before wrapt disallows __setattr__ on certain objects + def __init__(self, outer, inner): super().__init__(outer) self.__inner = inner @@ -221,3 +223,10 @@ def __exit__(self, *args, **kwargs): def __next__(self): return self.__wrapped__.__next__() + + def close(self): + try: + return self.__wrapped__.close() + finally: + if self.__inner != self.__wrapped__: # Don't close again if inner and wrapped are the same + self.__inner.close() diff --git a/smart_open/version.py b/smart_open/version.py index 661ddc24..78d54a1c 100644 --- a/smart_open/version.py +++ b/smart_open/version.py @@ -1,4 +1,4 @@ -__version__ = '7.0.4' +__version__ = '7.0.5' if __name__ == '__main__':