From d70c9020868cbda7d51fc5662026626f565b06cb Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Mon, 16 Sep 2024 17:26:34 +0200 Subject: [PATCH] Align filename prefix splitting with WebDataset library (#7151) * Align filename prefix splitting with WebDataset library * Fix import --- .../packaged_modules/webdataset/webdataset.py | 52 ++++++++++++------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/src/datasets/packaged_modules/webdataset/webdataset.py b/src/datasets/packaged_modules/webdataset/webdataset.py index 864c81d409c..c04f3ba4639 100644 --- a/src/datasets/packaged_modules/webdataset/webdataset.py +++ b/src/datasets/packaged_modules/webdataset/webdataset.py @@ -1,5 +1,6 @@ import io import json +import re from itertools import islice from typing import Any, Callable, Dict, List @@ -28,25 +29,26 @@ def _get_pipeline_from_tar(cls, tar_path, tar_iterator): fs: fsspec.AbstractFileSystem = fsspec.filesystem("memory") streaming_download_manager = datasets.StreamingDownloadManager() for filename, f in tar_iterator: - if "." in filename: - example_key, field_name = filename.split(".", 1) - if current_example and current_example["__key__"] != example_key: - yield current_example - current_example = {} - current_example["__key__"] = example_key - current_example["__url__"] = tar_path - current_example[field_name.lower()] = f.read() - if field_name.split(".")[-1] in SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL: - fs.write_bytes(filename, current_example[field_name.lower()]) - extracted_file_path = streaming_download_manager.extract(f"memory://{filename}") - with fsspec.open(extracted_file_path) as f: - current_example[field_name.lower()] = f.read() - fs.delete(filename) - data_extension = xbasename(extracted_file_path).split(".")[-1] - else: - data_extension = field_name.split(".")[-1] - if data_extension in cls.DECODERS: - current_example[field_name] = cls.DECODERS[data_extension](current_example[field_name]) + example_key, field_name = base_plus_ext(filename) + if example_key is None: + continue + if current_example and current_example["__key__"] != example_key: + yield current_example + current_example = {} + current_example["__key__"] = example_key + current_example["__url__"] = tar_path + current_example[field_name.lower()] = f.read() + if field_name.split(".")[-1] in SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL: + fs.write_bytes(filename, current_example[field_name.lower()]) + extracted_file_path = streaming_download_manager.extract(f"memory://{filename}") + with fsspec.open(extracted_file_path) as f: + current_example[field_name.lower()] = f.read() + fs.delete(filename) + data_extension = xbasename(extracted_file_path).split(".")[-1] + else: + data_extension = field_name.split(".")[-1] + if data_extension in cls.DECODERS: + current_example[field_name] = cls.DECODERS[data_extension](current_example[field_name]) if current_example: yield current_example @@ -121,6 +123,18 @@ def _generate_examples(self, tar_paths, tar_iterators): yield f"{tar_idx}_{example_idx}", example +# Source: https://github.com/webdataset/webdataset/blob/87bd5aa41602d57f070f65a670893ee625702f2f/webdataset/tariterators.py#L25 +def base_plus_ext(path): + """Split off all file extensions. + + Returns base, allext. + """ + match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path) + if not match: + return None, None + return match.group(1), match.group(2) + + # Obtained with: # ``` # import PIL.Image