More features around Beam. #731
Merged
New file (`@@ -0,0 +1,72 @@`):

````python
"""Beam module."""

from __future__ import annotations

from collections.abc import Mapping
import typing
from typing import Any

from etils import epath

from mlcroissant._src.datasets import Dataset

if typing.TYPE_CHECKING:
  import apache_beam as beam


def ReadFromCroissant(
    *,
    pipeline: beam.Pipeline,
    jsonld: epath.PathLike | Mapping[str, Any],
    record_set: str,
    mapping: Mapping[str, epath.PathLike] | None = None,
):
  """Returns an Apache Beam reader to generate the dataset using e.g. Spark.

  Example of usage:

  ```python
  import apache_beam as beam
  from apache_beam.options import pipeline_options
  import mlcroissant as mlc

  jsonld = "https://huggingface.co/api/datasets/ylecun/mnist/croissant"

  options = pipeline_options.PipelineOptions()
  with beam.Pipeline(options=options) as pipeline:
    mlc.ReadFromCroissant(
        pipeline=pipeline,
        jsonld=jsonld,
        record_set="default",
    )
  ```

  Only streamable datasets can be used with Beam. A streamable dataset is a
  dataset that can be generated by a linear sequence of operations - without
  joins, for example. This is the case for Hugging Face datasets. If there are
  branches, we'd need a more complex Beam pipeline.

  TODO(https://github.com/mlcommons/croissant/issues/733): handle branches.

  The sharding is done on the filtered files. This is currently optimized for
  Hugging Face datasets, so it raises an error if the dataset is not a Hugging
  Face dataset.

  Args:
    pipeline: A Beam pipeline.
    jsonld: A JSON object or a path to a Croissant file (URL, str or
      pathlib.Path).
    record_set: The name of the record set to generate.
    mapping: Mapping filename->filepath as a Python dict[str, str] to handle
      manual downloads. If `document.csv` is the FileObject and you downloaded
      it to `~/Downloads/document.csv`, you can specify
      `mapping={"document.csv": "~/Downloads/document.csv"}`.

  Returns:
    A Beam PCollection with all the records, where each element is a tuple
    containing a) a global index and b) the content of the record.

  Raises:
    ValueError: If the dataset is not streamable.
  """
  dataset = Dataset(jsonld=jsonld, mapping=mapping)
  return dataset.records(record_set).beam_reader(pipeline)
````
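Since `ReadFromCroissant` returns a PCollection of `(global_index, record)` tuples, downstream transforms can be chained onto it as usual. Here is a minimal sketch of consuming the result; the `"Drop index"` step and its lambda are illustrative additions, not part of this PR:

```python
import apache_beam as beam
from apache_beam.options import pipeline_options
import mlcroissant as mlc

jsonld = "https://huggingface.co/api/datasets/ylecun/mnist/croissant"

options = pipeline_options.PipelineOptions()
with beam.Pipeline(options=options) as pipeline:
    records = mlc.ReadFromCroissant(
        pipeline=pipeline,
        jsonld=jsonld,
        record_set="default",
    )
    # Each element is (global_index, record); keep only the record content.
    _ = records | "Drop index" >> beam.Map(lambda element: element[1])
```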
The PR also updates the Beam execution logic in a second file:

```diff
@@ -3,9 +3,12 @@
 from __future__ import annotations

 import collections
+from collections.abc import Iterable
 import concurrent.futures
+import functools
+import sys
 import typing
-from typing import Any
+from typing import Any, Generator

 from absl import logging
 import networkx as nx
@@ -22,6 +25,8 @@
 if typing.TYPE_CHECKING:
   import apache_beam as beam

+ElementWithIndex = tuple[int, Any]
+

 def execute_downloads(operations: Operations):
   """Executes all the downloads in the graph of operations."""
@@ -137,7 +142,7 @@ def read_all_files():
 def execute_operations_in_beam(
     pipeline: beam.Pipeline, record_set: str, operations: Operations
 ):
-  """See beam_reader docstring."""
+  """See ReadFromCroissant docstring."""
   import apache_beam as beam

   list_of_operations = _order_relevant_operations(operations, record_set)
@@ -149,12 +154,69 @@ def execute_operations_in_beam(
     files = operation(files)
     if isinstance(operation, FilterFiles):
       break
-  pipeline = pipeline | "Shard by files" >> beam.Create(files)
+  if not isinstance(files, Iterable):
+    raise ValueError("Could not filter files.")
+  files = list(files)  # Even for large datasets, this can be handled in RAM.
+
+  # We first shard by file and assign a shard_index.
+  pipeline = pipeline | "Shard by files with index" >> beam.Create(enumerate(files))
+  num_shards = len(files)
+
+  # We don't know in advance the number of records per shard, so we just
+  # allocate the maximum number, which is `sys.maxsize // num_shards`. Taking
+  # the practical case of large, evenly distributed datasets on Hugging Face,
+  # we can compute the following:
+  #   num_shards = number of Parquet files per config on Hugging Face
+  #     < 10 billion files
+  #   max_shard_size ~ 1 billion records per Parquet file
+  # So it seems we can use this trick without too many problems. We still
+  # raise a ValueError below if the limit is hit, and we ask the user to open
+  # a bug. A real solution to this problem would be to compute the shard sizes
+  # in parallel with generating the records.
+  # TODO(https://github.com/mlcommons/croissant/issues/732): Compute shard_sizes
+  # explicitly instead of relying on max_shard_size.
+  max_shard_size = sys.maxsize // num_shards
   while queue_of_operations:
     operation = queue_of_operations.popleft()
     if isinstance(operation, ReadFields):
-      beam_operation = beam.ParDo(operation)
+      beam_operation = beam.ParDo(
+          functools.partial(
+              _add_global_index,
+              operation=operation,
+              max_shard_size=max_shard_size,
+          )
+      )
     else:
-      beam_operation = beam.Map(operation)
+      beam_operation = beam.Map(
+          functools.partial(_pass_index, operation=operation)
+      )
     pipeline |= beam_operation
   return pipeline
+
+
+def _add_global_index(
+    element_with_index: ElementWithIndex,
+    operation: Operation,
+    max_shard_size: int,
+) -> Generator[ElementWithIndex, None, None]:
+  """Computes the global index given the shard size."""
+  shard_index, element = element_with_index
+  for index_in_shard, result in enumerate(operation(element)):
+    if index_in_shard >= max_shard_size:
+      raise ValueError(
+          "WARNING: This was very unlikely, but it seems we just hit this limit"
+          " in the code. Find another way to optimize execute_operations_in_beam."
+          " Please, open a PR on GitHub to make the maintainers aware of this"
+          " issue. A fix is to compute the actual shard_sizes rather than relying"
+          " on a heuristic (see comments above in code)."
+      )
+    new_index = max_shard_size * shard_index + index_in_shard
+    yield (new_index, result)
+
+
+def _pass_index(
+    element_with_index: tuple[int, Any], operation: Operation
+) -> ElementWithIndex:
+  """Passes the index to the next operation while executing the operation."""
+  index, element = element_with_index
+  return (index, operation(element))
```

Review thread on the `beam.ParDo` change:

Reviewer: Are we 100% sure that only ReadFields operations can be leaves?
Author: Yes! But it's more linked to the fact that ReadFields is a generator.
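To make the indexing trick concrete, here is a minimal standalone sketch (not part of the PR) of the arithmetic `_add_global_index` uses; `fake_read_fields` is a hypothetical stand-in for a `ReadFields`-like generator operation:

```python
import sys

def fake_read_fields(filename):
    """Hypothetical stand-in for ReadFields: yields several records per file."""
    for i in range(3):
        yield f"{filename}:record{i}"

files = ["shard0.parquet", "shard1.parquet"]
num_shards = len(files)
# Same heuristic as the PR: reserve sys.maxsize // num_shards indices per shard.
max_shard_size = sys.maxsize // num_shards

for shard_index, filename in enumerate(files):
    for index_in_shard, record in enumerate(fake_read_fields(filename)):
        # Each shard owns the disjoint index range starting at
        # shard_index * max_shard_size, so shards never collide.
        global_index = max_shard_size * shard_index + index_in_shard
        print(global_index, record)
```

Shard 0 yields indices 0, 1, 2, while shard 1 starts at `max_shard_size`, so the indices are globally unique even though no shard knows the others' record counts.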
Review thread on the streamability limitation:

Reviewer: So IIUC we will change this in future work to work with non-streamable datasets (e.g. future versions of HF croissants) -- I feel if this is right, this would deserve a mention here, or a link to an issue?
Author: Done.