Skip to content

Commit

Permalink
Align filename prefix splitting with WebDataset library (huggingface#…
Browse files Browse the repository at this point in the history
…7151)

* Align filename prefix splitting with WebDataset library

* Fix import
  • Loading branch information
albertvillanova authored Sep 16, 2024
1 parent 43b1fe1 commit d70c902
Showing 1 changed file with 33 additions and 19 deletions.
52 changes: 33 additions & 19 deletions src/datasets/packaged_modules/webdataset/webdataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
import json
import re
from itertools import islice
from typing import Any, Callable, Dict, List

Expand Down Expand Up @@ -28,25 +29,26 @@ def _get_pipeline_from_tar(cls, tar_path, tar_iterator):
fs: fsspec.AbstractFileSystem = fsspec.filesystem("memory")
streaming_download_manager = datasets.StreamingDownloadManager()
for filename, f in tar_iterator:
if "." in filename:
example_key, field_name = filename.split(".", 1)
if current_example and current_example["__key__"] != example_key:
yield current_example
current_example = {}
current_example["__key__"] = example_key
current_example["__url__"] = tar_path
current_example[field_name.lower()] = f.read()
if field_name.split(".")[-1] in SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL:
fs.write_bytes(filename, current_example[field_name.lower()])
extracted_file_path = streaming_download_manager.extract(f"memory://{filename}")
with fsspec.open(extracted_file_path) as f:
current_example[field_name.lower()] = f.read()
fs.delete(filename)
data_extension = xbasename(extracted_file_path).split(".")[-1]
else:
data_extension = field_name.split(".")[-1]
if data_extension in cls.DECODERS:
current_example[field_name] = cls.DECODERS[data_extension](current_example[field_name])
example_key, field_name = base_plus_ext(filename)
if example_key is None:
continue
if current_example and current_example["__key__"] != example_key:
yield current_example
current_example = {}
current_example["__key__"] = example_key
current_example["__url__"] = tar_path
current_example[field_name.lower()] = f.read()
if field_name.split(".")[-1] in SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL:
fs.write_bytes(filename, current_example[field_name.lower()])
extracted_file_path = streaming_download_manager.extract(f"memory://{filename}")
with fsspec.open(extracted_file_path) as f:
current_example[field_name.lower()] = f.read()
fs.delete(filename)
data_extension = xbasename(extracted_file_path).split(".")[-1]
else:
data_extension = field_name.split(".")[-1]
if data_extension in cls.DECODERS:
current_example[field_name] = cls.DECODERS[data_extension](current_example[field_name])
if current_example:
yield current_example

Expand Down Expand Up @@ -121,6 +123,18 @@ def _generate_examples(self, tar_paths, tar_iterators):
yield f"{tar_idx}_{example_idx}", example


# Source: https://github.com/webdataset/webdataset/blob/87bd5aa41602d57f070f65a670893ee625702f2f/webdataset/tariterators.py#L25
def base_plus_ext(path):
"""Split off all file extensions.
Returns base, allext.
"""
match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path)
if not match:
return None, None
return match.group(1), match.group(2)


# Obtained with:
# ```
# import PIL.Image
Expand Down

0 comments on commit d70c902

Please sign in to comment.