Remove local_files_only and use codebase_version instead of branches #734

2 changes: 1 addition & 1 deletion examples/port_datasets/pusht_zarr.py
@@ -223,5 +223,5 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
main(raw_dir, repo_id=repo_id, mode=mode)

# Uncomment if you want to load the local dataset and explore it
# dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
# dataset = LeRobotDataset(repo_id=repo_id)
# breakpoint()
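With `local_files_only` removed, loading a local dataset relies on the cache-first lookup introduced in this PR: files already on disk are used directly, and the hub is only contacted when something is missing. A minimal sketch of exploring the converted dataset, assuming it is in the default cache (the repo id is illustrative):

```
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Cache-first loading: if the files for this repo_id are already on disk,
# no hub request is made; otherwise they are pulled at the pinned revision.
dataset = LeRobotDataset("lerobot/pusht")  # illustrative repo id
print(dataset.meta.total_episodes)
```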
40 changes: 40 additions & 0 deletions lerobot/common/datasets/backward_compatibility.py
@@ -0,0 +1,40 @@
import packaging.version

V2_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.

We introduced a new format in v2.0 which is not backward compatible with v1.x.
Please use our conversion script. Modify the following command with your own task description:
```
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
--repo-id {repo_id} \\
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
```

A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
sweatshirt.", ...

If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""

V21_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.
While the current version of LeRobot is backward compatible with it, your dataset still uses global
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
```
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
```

If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""


class BackwardCompatibilityError(Exception):
def __init__(self, repo_id: str, version: packaging.version.Version):
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
super().__init__(message)
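`check_version_compatibility` itself is not part of this file; a plausible sketch of how these messages and the exception could fit together, assuming the check compares the dataset's version tag against the running codebase (signature and thresholds are assumptions):

```
import logging

import packaging.version

def check_version_compatibility(repo_id: str, dataset_version: str, codebase_version: str) -> None:
    # Hypothetical sketch, not the implementation from this PR.
    v_dataset = packaging.version.parse(dataset_version)
    v_codebase = packaging.version.parse(codebase_version)
    if v_dataset.major < v_codebase.major:
        # v1.x datasets must be converted with the v1 -> v2 script.
        raise BackwardCompatibilityError(repo_id, v_dataset)
    if v_dataset < packaging.version.parse("v2.1"):
        # v2.0 datasets still load, but with global rather than per-episode stats.
        logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_dataset))
```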
4 changes: 2 additions & 2 deletions lerobot/common/datasets/factory.py
@@ -83,15 +83,15 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
)

if isinstance(cfg.dataset.repo_id, str):
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only)
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, revision=cfg.dataset.revision)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
dataset = LeRobotDataset(
cfg.dataset.repo_id,
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
local_files_only=cfg.dataset.local_files_only,
)
else:
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
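The factory flow can be reproduced by hand: read the metadata first, derive `delta_timestamps`, then build the dataset at the same revision. A sketch with illustrative values, where a hand-written dict stands in for `resolve_delta_timestamps`:

```
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata

ds_meta = LeRobotDatasetMetadata("lerobot/pusht", revision="v2.1")  # illustrative repo id/revision
# Stand-in for resolve_delta_timestamps(cfg.policy, ds_meta): the current
# action plus the next 7, spaced by the dataset's frame period.
delta_timestamps = {"action": [t / ds_meta.fps for t in range(8)]}
dataset = LeRobotDataset(
    "lerobot/pusht",
    delta_timestamps=delta_timestamps,
    revision="v2.1",
)
```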
129 changes: 73 additions & 56 deletions lerobot/common/datasets/lerobot_dataset.py
@@ -16,7 +16,6 @@
import logging
import os
import shutil
from functools import cached_property
from pathlib import Path
from typing import Callable

@@ -27,6 +26,8 @@
import torch.utils
from datasets import load_dataset
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
from packaging import version

from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
@@ -41,14 +42,13 @@
check_frame_features,
check_timestamps_sync,
check_version_compatibility,
create_branch,
create_empty_dataset_info,
create_lerobot_dataset_card,
get_delta_indices,
get_episode_data_index,
get_features_from_robot,
get_hf_features_from_features,
get_hub_safe_version,
get_safe_revision,
hf_transform_to_torch,
load_episodes,
load_episodes_stats,
@@ -79,30 +79,35 @@ def __init__(
self,
repo_id: str,
root: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
force_cache_sync: bool = False,
):
self.repo_id = repo_id
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
self.local_files_only = local_files_only

# Load metadata
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
try:
if force_cache_sync:
raise FileNotFoundError
self.load_metadata()
except (FileNotFoundError, NotADirectoryError):
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.revision = get_safe_revision(self.repo_id, self.revision)
self.pull_from_repo(allow_patterns="meta/")
self.load_metadata()

check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)

def load_metadata(self):
self.info = load_info(self.root)
self.stats = load_stats(self.root)
self.tasks, self.task_to_task_index = load_tasks(self.root)
self.episodes = load_episodes(self.root)
try:
self.episodes_stats = load_episodes_stats(self.root)
self.stats = aggregate_stats(list(self.episodes_stats.values()))
except FileNotFoundError:
logging.warning(
f"""'episodes_stats.jsonl' not found. Using global dataset stats for each episode instead.
Convert your dataset stats to the new format using this command:
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={self.repo_id} """
)
if version.parse(self._version) < version.parse("v2.1"):
self.stats = load_stats(self.root)
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
else:
self.episodes_stats = load_episodes_stats(self.root)
self.stats = aggregate_stats(list(self.episodes_stats.values()))
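`backward_compatible_episodes_stats` is not shown in this diff. Given how it is called, a plausible reading is that pre-v2.1 datasets, which only ship global stats, reuse that single stats dict for every episode (a sketch, not the actual helper):

```
def backward_compatible_episodes_stats(stats: dict, episodes: dict) -> dict:
    # Assumed behavior: every episode index points at the global stats dict, so
    # aggregate_stats() and per-episode lookups behave the same on v2.0 datasets.
    return {ep_idx: stats for ep_idx in episodes}
```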

def pull_from_repo(
self,
@@ -112,17 +117,12 @@ def pull_from_repo(
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self._hub_version,
revision=self.revision,
local_dir=self.root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
local_files_only=self.local_files_only,
)

@cached_property
def _hub_version(self) -> str | None:
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
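`get_safe_revision` replaces `get_hub_safe_version`, but its body is outside this hunk. A sketch of the behavior it plausibly needs, verifying that the requested revision exists on the hub before `snapshot_download` fetches it (the fallback policy here is an assumption):

```
import logging

from huggingface_hub import HfApi

def get_safe_revision(repo_id: str, revision: str) -> str:
    # Hypothetical sketch: confirm the branch/tag exists on the hub first.
    refs = HfApi().list_repo_refs(repo_id, repo_type="dataset")
    names = {ref.name for ref in refs.branches} | {ref.name for ref in refs.tags}
    if revision in names:
        return revision
    logging.warning(f"Revision '{revision}' not found on '{repo_id}', falling back to 'main'.")
    return "main"
```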

@property
def _version(self) -> str:
"""Codebase version used to create this dataset."""
@@ -342,7 +342,7 @@ def create(
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
obj.local_files_only = True
obj.revision = None
return obj


@@ -355,8 +355,9 @@ def __init__(
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
tolerance_s: float = 1e-4,
revision: str | None = None,
force_cache_sync: bool = False,
download_videos: bool = True,
local_files_only: bool = False,
video_backend: str | None = None,
):
"""
@@ -366,7 +367,7 @@ def __init__(
- On your local disk in the 'root' folder. This is typically the case when you recorded your
dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
with 'root' will load your dataset directly from disk. This can happen while you're offline (no
internet connection), in that case, use local_files_only=True.
internet connection).

- On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
@@ -448,11 +449,15 @@ def __init__(
timestamps is separated from the next by 1/fps +/- tolerance_s. This also applies to frames
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
multiples of 1/fps. Defaults to 1e-4.
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
commit hash. Defaults to current codebase version tag.
force_cache_sync (bool, optional): Flag to sync and refresh local files first. If False and the
requested files are already present in the local cache, they are used directly, which is
faster; however, they might not be in sync with the version on the hub, especially if you
specified 'revision'. Defaults to False.
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to
True.
local_files_only (bool, optional): Flag to use local files only. If True, no requests to the hub
will be made. Defaults to False.
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
"""
@@ -463,9 +468,9 @@ def __init__(
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else "pyav"
self.delta_indices = None
self.local_files_only = local_files_only

# Unused attributes
self.image_writer = None
@@ -474,17 +479,24 @@ def __init__(
self.root.mkdir(exist_ok=True, parents=True)

# Load metadata
self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
if self.episodes is not None and self.meta._version == CODEBASE_VERSION:
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
)
if self.episodes is not None and version.parse(self.meta._version) >= version.parse("v2.1"):
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
self.stats = aggregate_stats(episodes_stats)

# Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)

# Load actual data
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()
try:
if force_cache_sync:
raise FileNotFoundError
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
self.hf_dataset = self.load_hf_dataset()
except (AssertionError, FileNotFoundError, NotADirectoryError):
self.revision = get_safe_revision(self.repo_id, self.revision)
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()

self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)

# Check timestamps
@@ -501,7 +513,6 @@ def __init__(
def push_to_hub(
self,
branch: str | None = None,
create_card: bool = True,
tags: list | None = None,
license: str | None = "apache-2.0",
push_videos: bool = True,
@@ -528,7 +539,13 @@ def push_to_hub(
exist_ok=True,
)
if branch:
create_branch(repo_id=self.repo_id, branch=branch, repo_type="dataset")
hub_api.create_branch(
repo_id=self.repo_id,
branch=branch,
revision=self.revision,
repo_type="dataset",
exist_ok=True,
)

hub_api.upload_folder(
repo_id=self.repo_id,
@@ -538,15 +555,12 @@ def push_to_hub(
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
if create_card:
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)

if not branch:
create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
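With the standalone `create_branch` helper and the `create_card` flag gone, branches are created through `HfApi` from the loaded revision, and a dataset card is only generated when the repo has no README yet. A usage sketch (repo id and branch name are illustrative):

```
dataset = LeRobotDataset("your-name/your-dataset")  # illustrative repo id
# Push to the default revision; a card is created only if none exists yet.
dataset.push_to_hub(tags=["robotics"], license="apache-2.0")
# Or create a branch off the loaded revision and push there.
dataset.push_to_hub(branch="fix-stats")  # hypothetical branch name
```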

def pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,
@@ -555,11 +569,10 @@ def pull_from_repo(
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.meta._hub_version,
revision=self.revision,
local_dir=self.root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
local_files_only=self.local_files_only,
)

def download_episodes(self, download_videos: bool = True) -> None:
@@ -573,17 +586,23 @@ def download_episodes(self, download_videos: bool = True) -> None:
files = None
ignore_patterns = None if download_videos else "videos/"
if self.episodes is not None:
files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
if len(self.meta.video_keys) > 0 and download_videos:
video_files = [
str(self.meta.get_video_file_path(ep_idx, vid_key))
for vid_key in self.meta.video_keys
for ep_idx in self.episodes
]
files += video_files
files = self.get_episodes_file_paths()

self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)

def get_episodes_file_paths(self) -> list[str]:
episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
if len(self.meta.video_keys) > 0:
video_files = [
str(self.meta.get_video_file_path(ep_idx, vid_key))
for vid_key in self.meta.video_keys
for ep_idx in episodes
]
fpaths += video_files

return fpaths
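This helper gives `__init__` a cheap cache-completeness check before any hub call, as used in the try/except above. The same check done by hand (repo id illustrative):

```
dataset = LeRobotDataset("lerobot/pusht")  # illustrative repo id
missing = [f for f in dataset.get_episodes_file_paths() if not (dataset.root / f).is_file()]
print(f"{len(missing)} episode files missing from the local cache")
```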

def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if self.episodes is None:
@@ -991,7 +1010,7 @@ def create(
)
obj.repo_id = obj.meta.repo_id
obj.root = obj.meta.root
obj.local_files_only = obj.meta.local_files_only
obj.revision = None
obj.tolerance_s = tolerance_s
obj.image_writer = None

@@ -1033,7 +1052,6 @@ def __init__(
delta_timestamps: dict[list[float]] | None = None,
tolerances_s: dict | None = None,
download_videos: bool = True,
local_files_only: bool = False,
video_backend: str | None = None,
):
super().__init__()
@@ -1051,7 +1069,6 @@ def __init__(
delta_timestamps=delta_timestamps,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
local_files_only=local_files_only,
video_backend=video_backend,
)
for repo_id in repo_ids