Skip to content

Commit

Permalink
Close gzipped files properly (#6893)
Browse files Browse the repository at this point in the history
* Update compression.py

* Update compression.py

* Update compression.py

* Update compression.py

* Update compression.py
  • Loading branch information
lhoestq authored May 13, 2024
1 parent 871eabc commit ddb6a28
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/datasets/filesystems/compression.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from functools import partial
from typing import Optional

import fsspec
Expand Down Expand Up @@ -32,9 +33,11 @@ def __init__(
target_options (:obj:``dict``, optional): Kwargs passed when instantiating the target FS.
"""
super().__init__(self, **kwargs)
self.fo = fo.__fspath__() if hasattr(fo, "__fspath__") else fo
# always open as "rb" since fsspec can then use the TextIOWrapper to make it work for "r" mode
self.file = fsspec.open(
fo,
self._open_with_fsspec = partial(
fsspec.open,
self.fo,
mode="rb",
protocol=target_protocol,
compression=self.compression,
Expand All @@ -45,7 +48,7 @@ def __init__(
},
**(target_options or {}),
)
self.compressed_name = os.path.basename(self.file.path.split("::")[0])
self.compressed_name = os.path.basename(self.fo.split("::")[0])
self.uncompressed_name = (
self.compressed_name[: self.compressed_name.rindex(".")]
if "." in self.compressed_name
Expand All @@ -60,11 +63,12 @@ def _strip_protocol(cls, path):

def _get_dirs(self):
if self.dir_cache is None:
f = {**self.file.fs.info(self.file.path), "name": self.uncompressed_name}
f = {**self._open_with_fsspec().fs.info(self.fo), "name": self.uncompressed_name}
self.dir_cache = {f["name"]: f}

def cat(self, path: str):
return self.file.open().read()
with self._open_with_fsspec().open() as f:
return f.read()

def _open(
self,
Expand All @@ -77,8 +81,8 @@ def _open(
):
path = self._strip_protocol(path)
if mode != "rb":
raise ValueError(f"Tried to read with mode {mode} on file {self.file.path} opened with mode 'rb'")
return self.file.open()
raise ValueError(f"Tried to read with mode {mode} on file {self.fo} opened with mode 'rb'")
return self._open_with_fsspec().open()


class Bz2FileSystem(BaseCompressedFileFileSystem):
Expand Down

0 comments on commit ddb6a28

Please sign in to comment.