Better public API for ReadFromCroissant.
marcenacp committed Sep 4, 2024
1 parent f6bcc3b commit 52dc04b
Showing 5 changed files with 78 additions and 36 deletions.
2 changes: 2 additions & 0 deletions python/mlcroissant/mlcroissant/__init__.py
@@ -1,6 +1,7 @@
"""Defines the public interface to the `mlcroissant` package."""

from mlcroissant._src import torch
from mlcroissant._src.beam import ReadFromCroissant
from mlcroissant._src.core import constants
from mlcroissant._src.core.constants import DataType
from mlcroissant._src.core.constants import EncodingFormat
@@ -44,6 +45,7 @@
    "Organization",
    "Person",
    "Rdf",
    "ReadFromCroissant",
    "Records",
    "RecordSet",
    "Source",
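With the re-export above, the Beam reader is now reachable from the package root instead of the private `_src` path. A minimal sketch of the new import (assuming `mlcroissant` at this commit is installed):

```python
import mlcroissant as mlc

# Before this commit the reader only lived under the private module path
# mlcroissant._src.beam; the public alias is now enough.
reader = mlc.ReadFromCroissant
```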
68 changes: 68 additions & 0 deletions python/mlcroissant/mlcroissant/_src/beam.py
@@ -0,0 +1,68 @@
"""Beam module."""

from __future__ import annotations

from collections.abc import Mapping
import typing
from typing import Any

from etils import epath

from mlcroissant._src.datasets import Dataset

if typing.TYPE_CHECKING:
    import apache_beam as beam


def ReadFromCroissant(
    *,
    pipeline: beam.Pipeline,
    jsonld: epath.PathLike | Mapping[str, Any],
    record_set: str,
    mapping: Mapping[str, epath.PathLike] | None = None,
):
    """Returns an Apache Beam reader to generate the dataset using, e.g., Spark.

    Example of usage:

    ```python
    import apache_beam as beam
    from apache_beam.options.pipeline_options import PipelineOptions
    import mlcroissant as mlc

    jsonld = "https://huggingface.co/api/datasets/ylecun/mnist/croissant"
    pipeline_options = PipelineOptions()
    with beam.Pipeline(options=pipeline_options) as pipeline:
        ReadFromCroissant(
            pipeline=pipeline,
            jsonld=jsonld,
            record_set="default",
        )
    ```

    Only streamable datasets can be used with Beam. A streamable dataset is a
    dataset that can be generated by a linear sequence of operations (without
    joins, for example). This is the case for Hugging Face datasets. If there
    are branches, we'd need a more complex Beam pipeline.

    The sharding is done on the filtered files. This is currently optimized for
    Hugging Face datasets, so it raises an error if the dataset is not a Hugging
    Face dataset.

    Args:
        pipeline: A Beam pipeline.
        jsonld: A JSON object or a path to a Croissant file (URL, str or
            pathlib.Path).
        record_set: The name of the record set to generate.
        mapping: Mapping filename->filepath as a Python dict[str, str] to handle
            manual downloads. If `document.csv` is the FileObject and you
            downloaded it to `~/Downloads/document.csv`, you can specify
            `mapping={"document.csv": "~/Downloads/document.csv"}`.

    Returns:
        A Beam PCollection with all the records.

    Raises:
        ValueError: If the dataset is not streamable.
    """
    dataset = Dataset(jsonld=jsonld, mapping=mapping)
    return dataset.records(record_set).beam_reader(pipeline)
5 changes: 5 additions & 0 deletions python/mlcroissant/mlcroissant/_src/core/optional.py
@@ -96,5 +96,10 @@ def torchdata_datapipes(cls) -> types.ModuleType:
        """Cached torchdata module."""
        return _try_import("torchdata.datapipes", package_name="torchdata")

    @cached_class_property
    def beam(cls):
        """Cached Apache Beam module."""
        return _try_import("apache_beam", package_name="apache_beam")


deps = OptionalDependencies
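A short sketch of how the new cached property is meant to be consumed; the first access imports and caches `apache_beam`, so the dependency is only required when Beam features are actually used:

```python
from mlcroissant._src.core.optional import deps

# The first attribute access runs _try_import("apache_beam") and caches the
# module on the class; subsequent accesses reuse it. If apache_beam is not
# installed, _try_import raises instead of failing at mlcroissant import time.
beam = deps.beam
print(beam.__name__)  # -> "apache_beam"
```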
36 changes: 1 addition & 35 deletions python/mlcroissant/mlcroissant/_src/datasets.py
@@ -168,41 +168,7 @@ def __iter__(self):
        )

    def beam_reader(self, pipeline: beam.Pipeline):
        """Returns an Apache Beam reader to generate the dataset using e.g. Spark.

        Example of usage:

        ```python
        import apache_beam as beam
        from apache_beam.options.pipeline_options import PipelineOptions
        import mlcroissant as mlc

        dataset = mlc.Dataset(
            jsonld="https://huggingface.co/api/datasets/ylecun/mnist/croissant",
        )
        pipeline_options = PipelineOptions()
        with beam.Pipeline(options=pipeline_options) as pipeline:
            _ = dataset.records("mnist").beam_reader(pipeline)
        ```

        Only streamable datasets can be used with Beam. A streamable dataset is a
        dataset that can be generated by a linear sequence of operations - without
        joins for example. This is the case for Hugging Face datasets. If there
        are branches, we'd need a more complex Beam pipeline.

        The sharding is done on the filtered files. This is currently optimized
        for Hugging Face datasets, so it raises an error if the dataset is not a
        Hugging Face dataset.

        Args:
            A Beam pipeline.

        Returns:
            A Beam PCollection with all the records.

        Raises:
            A ValueError if the dataset is not streamable.
        """
        """See ReadFromCroissant docstring."""
        operations = self._filter_interesting_operations(self.filters)
        execute_downloads(operations)
        if not _is_streamable_dataset(operations):
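The truncated guard above rejects non-streamable datasets. As a rough illustration of the idea (not the library's actual `_is_streamable_dataset`), a dataset is streamable when its operation graph is a simple chain, which could be checked like this on a networkx digraph:

```python
import networkx as nx

def is_linear_chain(graph: nx.DiGraph) -> bool:
    # Hypothetical check: a join-free, branch-free graph where every
    # operation has at most one predecessor and at most one successor.
    # Branches or joins would require a more complex Beam pipeline, as
    # the docstring notes.
    return all(
        graph.in_degree(node) <= 1 and graph.out_degree(node) <= 1
        for node in graph.nodes
    )
```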
@@ -12,6 +12,7 @@
import pandas as pd

from mlcroissant._src.core.issues import GenerationError
from mlcroissant._src.core.optional import deps
from mlcroissant._src.operation_graph.base_operation import Operation
from mlcroissant._src.operation_graph.base_operation import Operations
from mlcroissant._src.operation_graph.operations import FilterFiles
@@ -137,7 +138,7 @@ def read_all_files():
def execute_operations_in_beam(
    pipeline: beam.Pipeline, record_set: str, operations: Operations
):
    """See beam_reader docstring."""
    """See ReadFromCroissant docstring."""
    import apache_beam as beam

    list_of_operations = _order_relevant_operations(operations, record_set)
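Because the operation sequence is linear, translating it into Beam amounts to sharding on the filtered files and chaining one transform per operation. A hypothetical sketch of that shape (`files` and the one-to-many `operation` callables are illustrative, not the library's internals):

```python
import apache_beam as beam

def build_linear_pipeline(pipeline, files, operations):
    # Shard on the filtered files: each file becomes one Beam element.
    collection = pipeline | "ShardFiles" >> beam.Create(files)
    # A join-free graph translates directly into chained transforms;
    # FlatMap lets one file fan out into many records.
    for i, operation in enumerate(operations):
        collection = collection | f"Op{i}" >> beam.FlatMap(operation)
    return collection
```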
