Add Distributed, Parallel Dataset Merging #2391
```diff
@@ -15,8 +15,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 import logging
 import shutil
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path

 import pandas as pd
```
```diff
@@ -107,6 +109,7 @@ def update_meta_data(
     dst_meta,
     meta_idx,
     data_idx,
+    data_file_map,
     videos_idx,
 ):
     """Updates metadata DataFrame with new chunk, file, and timestamp indices.
```
```diff
@@ -127,8 +130,25 @@ def update_meta_data(
     df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
     df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
-    df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
-    df["data/file_index"] = df["data/file_index"] + data_idx["file"]
+    # Remap data chunk/file indices per-source-file using the actual destination
+    # file chosen during data aggregation. A flat offset is incorrect when
+    # multiple source files are concatenated into a single destination file.
+    if data_file_map:
+        new_data_chunk = []
+        new_data_file = []
+        for idx in df.index:
+            src_chunk = int(df.at[idx, "data/chunk_index"])  # original source file location
+            src_file = int(df.at[idx, "data/file_index"])  # original source file location
+            dst_chunk, dst_file = data_file_map.get(
+                (src_chunk, src_file), (src_chunk + data_idx["chunk"], src_file + data_idx["file"])
+            )
+            new_data_chunk.append(dst_chunk)
+            new_data_file.append(dst_file)
+        df["data/chunk_index"] = new_data_chunk
+        df["data/file_index"] = new_data_file
+    else:
+        df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
+        df["data/file_index"] = df["data/file_index"] + data_idx["file"]
     for key, video_idx in videos_idx.items():
         # Store original video file indices before updating
         orig_chunk_col = f"videos/{key}/chunk_index"
```
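To make the remapping above concrete, here is a small self-contained illustration. The data is invented; only the shapes of `data_file_map` and `data_idx` mirror the diff, and the list-comprehension rewrite is just one equivalent way to apply the same lookup.

```python
# Toy illustration (not code from this PR): rewrite per-episode (chunk, file)
# pointers through a lookup built while copying source data files.
import pandas as pd

df = pd.DataFrame(
    {
        "data/chunk_index": [0, 0, 1],
        "data/file_index": [0, 1, 0],
    }
)

# (src_chunk, src_file) -> (dst_chunk, dst_file) recorded during aggregation.
# Two source files landed in the same destination file (2, 3), which a flat
# offset could not express.
data_file_map = {(0, 0): (2, 3), (0, 1): (2, 3), (1, 0): (2, 4)}
data_idx = {"chunk": 2, "file": 3}  # flat-offset fallback, as in the diff

remapped = [
    data_file_map.get((c, f), (c + data_idx["chunk"], f + data_idx["file"]))
    for c, f in zip(df["data/chunk_index"], df["data/file_index"])
]
df["data/chunk_index"], df["data/file_index"] = zip(*remapped)
print(df)
#    data/chunk_index  data/file_index
# 0                 2                3
# 1                 2                3
# 2                 2                4
```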
```diff
@@ -166,7 +186,7 @@ def update_meta_data(
     return df


-def aggregate_datasets(
+def _aggregate_datasets(
     repo_ids: list[str],
     aggr_repo_id: str,
     roots: list[Path] | None = None,
```
```diff
@@ -175,39 +195,24 @@ def aggregate_datasets(
     video_files_size_in_mb: float | None = None,
     chunk_size: int | None = None,
 ):
-    """Aggregates multiple LeRobot datasets into a single unified dataset.
-
-    This is the main function that orchestrates the aggregation process by:
-    1. Loading and validating all source dataset metadata
-    2. Creating a new destination dataset with unified tasks
-    3. Aggregating videos, data, and metadata from all source datasets
-    4. Finalizing the aggregated dataset with proper statistics
-
-    Args:
-        repo_ids: List of repository IDs for the datasets to aggregate.
-        aggr_repo_id: Repository ID for the aggregated output dataset.
-        roots: Optional list of root paths for the source datasets.
-        aggr_root: Optional root path for the aggregated dataset.
-        data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
-        video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
-        chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
+    """Serial aggregation kernel: combines datasets into a destination dataset.
+
+    This function performs a single-process aggregation. It assumes it is the
+    sole writer for its destination `aggr_root`.
     """
     logging.info("Start aggregate_datasets")

-    if data_files_size_in_mb is None:
-        data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
-    if video_files_size_in_mb is None:
-        video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
-    if chunk_size is None:
-        chunk_size = DEFAULT_CHUNK_SIZE
-
-    all_metadata = (
-        [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
-        if roots is None
-        else [
-            LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
-        ]
-    )
+    # Build metadata objects, supporting a per-dataset "root" that may be None.
+    # When root is provided we load from the local filesystem, otherwise from Hub cache.
+    if roots is None:
+        all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
+    else:
+        all_metadata = [
+            (
+                LeRobotDatasetMetadata(repo_id, root=root)
+                if root is not None
+                else LeRobotDatasetMetadata(repo_id)
+            )
+            for repo_id, root in zip(repo_ids, roots, strict=False)
+        ]
     fps, robot_type, features = validate_all_metadata(all_metadata)
     video_keys = [key for key in features if features[key]["dtype"] == "video"]
```
```diff
@@ -237,9 +242,11 @@ def aggregate_datasets(

     for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
         videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
-        data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
+        data_idx, data_file_map = aggregate_data(
+            src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size
+        )

-        meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
+        meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, data_file_map, videos_idx)

         dst_meta.info["total_episodes"] += src_meta.total_episodes
         dst_meta.info["total_frames"] += src_meta.total_frames
```
```diff
@@ -248,6 +255,168 @@ def aggregate_datasets(
     logging.info("Aggregation complete.")


+def aggregate_datasets(
+    repo_ids: list[str],
+    aggr_repo_id: str,
+    roots: list[Path] | None = None,
+    aggr_root: Path | None = None,
+    data_files_size_in_mb: float | None = None,
+    video_files_size_in_mb: float | None = None,
+    chunk_size: int | None = None,
+    num_workers: int | None = None,
+    tmp_root: Path | None = None,
+):
+    """Aggregates multiple LeRobot datasets into a single unified dataset.
+
+    This is the main function that orchestrates the aggregation process by:
+    1. Loading and validating all source dataset metadata
+    2. Creating a new destination dataset with unified tasks
+    3. Aggregating videos, data, and metadata from all source datasets
+    4. Finalizing the aggregated dataset with proper statistics
+
+    Args:
+        repo_ids: List of repository IDs for the datasets to aggregate.
+        aggr_repo_id: Repository ID for the aggregated output dataset.
+        roots: Optional list of root paths for the source datasets.
+        aggr_root: Optional root path for the aggregated dataset.
+        data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
+        video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
+        chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
+        num_workers: When > 1, performs a tree-based parallel reduction using a thread pool
+        tmp_root: Optional base directory to store intermediate reduction outputs
+    """
+    logging.info("Start aggregate_datasets")
+
+    if data_files_size_in_mb is None:
+        data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
+    if video_files_size_in_mb is None:
+        video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
+    if chunk_size is None:
+        chunk_size = DEFAULT_CHUNK_SIZE
+
+    if num_workers is None or num_workers <= 1:
+        # Run aggregation sequentially
+        _aggregate_datasets(
+            repo_ids=repo_ids,
+            aggr_repo_id=aggr_repo_id,
+            aggr_root=aggr_root,
+            roots=roots,
+            data_files_size_in_mb=data_files_size_in_mb,
+            video_files_size_in_mb=video_files_size_in_mb,
+            chunk_size=chunk_size,
+        )
+
+    # Uses a parallel fan-out/fan-in strategy when num_workers is provided
+    elif num_workers > 1:
```
A review suggestion simplifies the second branch of the dispatch:

```diff
-    elif num_workers > 1:
+    else:
```
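The hunk above is truncated before the body of the parallel branch. As a rough, generic sketch of what a thread-pool tree reduction over intermediate results can look like (this is not the PR's implementation; `merge_pair` is a hypothetical stand-in for merging two partially aggregated datasets):

```python
# Illustrative sketch only - not code from this PR. Leaves are merged pairwise
# into intermediate results until one remains; each level's merges run
# concurrently in a thread pool.
from concurrent.futures import ThreadPoolExecutor


def tree_reduce(items, merge_pair, num_workers):
    """Repeatedly merge adjacent pairs of `items` until a single item remains.

    `merge_pair(a, b)` is a hypothetical callable that combines two
    intermediate results (e.g. two partially aggregated datasets).
    """
    level = list(items)
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        while len(level) > 1:
            pairs = [(level[i], level[i + 1]) for i in range(0, len(level) - 1, 2)]
            futures = [pool.submit(merge_pair, a, b) for a, b in pairs]
            merged = [f.result() for f in futures]
            if len(level) % 2 == 1:  # odd leftover carries over to the next level
                merged.append(level[-1])
            level = merged
    return level[0]


# Toy usage: "merging" integers by addition stands in for merging datasets.
print(tree_reduce([1, 2, 3, 4, 5], lambda a, b: a + b, num_workers=4))  # 15
```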
```diff
@@ -234,6 +234,7 @@ def merge_datasets(
     datasets: list[LeRobotDataset],
     output_repo_id: str,
     output_dir: str | Path | None = None,
+    num_workers: int | None = None,
 ) -> LeRobotDataset:
     """Merge multiple LeRobotDatasets into a single dataset.
```
```diff
@@ -257,6 +258,7 @@ def merge_datasets(
         aggr_repo_id=output_repo_id,
         roots=roots,
         aggr_root=output_dir,
+        num_workers=num_workers,
     )

     merged_dataset = LeRobotDataset(
```
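For context, a call through the public API might look like the following; the repository IDs, output location, and import paths are placeholders and assumptions, not taken from the PR.

```python
# Hypothetical usage - repo IDs, paths, and import locations are assumptions.
from lerobot.datasets.dataset_tools import merge_datasets
from lerobot.datasets.lerobot_dataset import LeRobotDataset

datasets = [
    LeRobotDataset("user/dataset_a"),
    LeRobotDataset("user/dataset_b"),
]

merged = merge_datasets(
    datasets,
    output_repo_id="user/dataset_merged",
    output_dir="/tmp/dataset_merged",
    num_workers=4,  # >1 requests the parallel path; None or 1 keeps the serial path
)
```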
```diff
@@ -329,7 +331,7 @@ def modify_features(

     if repo_id is None:
         repo_id = f"{dataset.repo_id}_modified"
-    output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
+    output_dir = Path(output_dir, exists_ok=True) if output_dir is not None else HF_LEROBOT_HOME / repo_id
```
A review suggestion reverts this line (`pathlib.Path` does not accept an `exists_ok` keyword):

```diff
-    output_dir = Path(output_dir, exists_ok=True) if output_dir is not None else HF_LEROBOT_HOME / repo_id
+    output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
```
Review comment on the dispatch logic: The sequential path (when `num_workers` is None or <= 1) does not return a value or log completion, but the parallel path returns explicitly at line 417. This creates inconsistent behavior. Either add a return statement after line 307, or move the logging and return after the entire if/elif block so both paths behave consistently.
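A minimal sketch of the reviewer's second option, using stand-in functions in place of the real serial and parallel paths: run whichever branch applies, then log and return once after the if/else so both paths behave the same.

```python
# Sketch of the suggested restructure; the helpers below are stand-ins,
# not functions from this PR.
import logging


def _aggregate_serial(**kwargs):
    return "serial-result"  # stand-in for _aggregate_datasets(...)


def _aggregate_parallel(**kwargs):
    return "parallel-result"  # stand-in for the tree-reduction branch


def aggregate_datasets_sketch(num_workers=None, **kwargs):
    if num_workers is None or num_workers <= 1:
        result = _aggregate_serial(**kwargs)
    else:
        result = _aggregate_parallel(num_workers=num_workers, **kwargs)

    # Shared epilogue: both paths now log completion and return a value.
    logging.info("Aggregation complete.")
    return result


print(aggregate_datasets_sketch())               # serial path
print(aggregate_datasets_sketch(num_workers=4))  # parallel path
```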