
Remove config yaml for robot devices #594

Merged
Changes from all commits (40)
- 3bb5876 Add draccus, create MainConfig (aliberts, Dec 5, 2024)
- 6b4cdab Merge remote-tracking branch 'origin/main' into user/aliberts/2024_11… (aliberts, Dec 5, 2024)
- 660d65d WIP refactor train.py and ACT (aliberts, Dec 18, 2024)
- ca34677 Add policies training presets (aliberts, Dec 23, 2024)
- 0eb7cdc Update diffusion policy (aliberts, Dec 23, 2024)
- 32595f0 Add pusht and xarm env configs (aliberts, Dec 23, 2024)
- 66aa91a Update tdmpc (aliberts, Dec 23, 2024)
- 14de4d9 Update vqbet (aliberts, Dec 23, 2024)
- 514cdae Merge remote-tracking branch 'origin/main' into user/aliberts/2024_11… (aliberts, Dec 23, 2024)
- fa55e67 Fix poetry relax (aliberts, Dec 23, 2024)
- 6a2474e Switch to full dataclass (Not tested) (Cadene, Dec 24, 2024)
- ae89d9f tests motors, cameras, robots are passing (Cadene, Dec 25, 2024)
- 14a117e nit comment (Cadene, Dec 26, 2024)
- ba31014 Add feature types to envs (aliberts, Dec 27, 2024)
- 87d92f9 Add EvalPipelineConfig, parse features from envs (aliberts, Dec 27, 2024)
- 9248d73 Merge remote-tracking branch 'origin/user/aliberts/2024_11_30_remove_… (Cadene, Dec 27, 2024)
- 72e1463 Fix channel -> channels (Cadene, Dec 28, 2024)
- b3d20ae WIP (Cadene, Dec 28, 2024)
- e6ebf48 WIP before tests (Cadene, Jan 1, 2025)
- bfbde19 WIP fix tests, they are runnable (Cadene, Jan 1, 2025)
- b47de65 Fix some more tests (Cadene, Jan 1, 2025)
- d180ae8 Fix policties unit tests (Cadene, Jan 1, 2025)
- 25ec3e7 fix test policies (Cadene, Jan 1, 2025)
- 1199d65 Improve test (Cadene, Jan 2, 2025)
- acf44c4 WIP (Cadene, Jan 4, 2025)
- ce93213 Small fix (Cadene, Jan 6, 2025)
- 0c7126d fix (Cadene, Jan 6, 2025)
- 60a0b0c Merge remote-tracking branch 'origin/user/aliberts/2024_11_30_remove_… (Cadene, Jan 6, 2025)
- 13af6e5 Merge remote-tracking branch 'origin/user/aliberts/2024_11_30_remove_… (Cadene, Jan 6, 2025)
- 1319a83 fix test_record_and_replay_and_policy (Cadene, Jan 6, 2025)
- c0ecbee Fix image_transforms (Cadene, Jan 7, 2025)
- 3b60310 Add remove dataset download TODO (Cadene, Jan 8, 2025)
- 06d44c0 fix tdmpc (Cadene, Jan 8, 2025)
- 0ddd9b7 Fix bug in delta_indices (Cadene, Jan 8, 2025)
- f686e8a Remove normalize_inputs and use None instead (Cadene, Jan 8, 2025)
- e868178 diffusion backward works (Cadene, Jan 8, 2025)
- 028f280 comment out diffusion backward compatibility (Cadene, Jan 9, 2025)
- abd1a99 Merge remote-tracking branch 'origin/user/aliberts/2024_11_30_remove_… (Cadene, Jan 9, 2025)
- 9cdd778 All tests are passing (Cadene, Jan 9, 2025)
- 2c17859 Address Simon's comments (Cadene, Jan 9, 2025)
13 changes: 9 additions & 4 deletions examples/2_evaluate_pretrained_policy.py
@@ -1,6 +1,11 @@
"""
This script demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.

It requires the installation of the 'gym_pusht' simulation environment. Install it by running:
```bash
pip install -e ".[pusht]"
```
"""

from pathlib import Path
@@ -10,19 +15,19 @@
import imageio
import numpy
import torch
from huggingface_hub import snapshot_download

from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

# Create a directory to store the video of the evaluation
output_directory = Path("outputs/eval/example_pusht_diffusion")
output_directory.mkdir(parents=True, exist_ok=True)

# Download the diffusion policy for pusht environment
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
# Provide the [hugging face repo id](https://huggingface.co/lerobot/diffusion_pusht):
pretrained_policy_path = "lerobot/diffusion_pusht"
# OR a path to a local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")

# TODO(alibert, rcadene): fix this file
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
policy.eval()

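
For reference, the two supported loading paths now look like this (a minimal sketch; `from_pretrained` is assumed to follow the usual huggingface_hub semantics of accepting either a Hub repo id or a local directory):

```python
from pathlib import Path

from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

# Load from the Hub by repo id; the snapshot is downloaded and cached.
policy = DiffusionPolicy.from_pretrained("lerobot/diffusion_pusht")

# Or load from a local training output directory instead.
local_ckpt = Path("outputs/train/example_pusht_diffusion")
if local_ckpt.is_dir():
    policy = DiffusionPolicy.from_pretrained(local_ckpt)

policy.eval()
```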
1 change: 1 addition & 0 deletions examples/3_train_policy.py
@@ -40,6 +40,7 @@
# For this example, no arguments need to be passed because the defaults are set up for PushT.
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig()
cfg.parse_features_from_dataset(dataset.meta)
policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)
policy.train()
policy.to(device)
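
For intuition, here is a hypothetical sketch of what `parse_features_from_dataset` presumably does: derive the policy's input/output feature shapes from dataset metadata rather than from a hand-written config entry. The attribute and field names below are illustrative assumptions, not the actual API:

```python
# Hypothetical sketch only; the real logic lives in the policy config classes.
def parse_features_from_dataset_sketch(cfg, ds_meta) -> None:
    """Fill the config's shape maps from dataset metadata (names are assumed)."""
    for key, feature in ds_meta.features.items():
        if key.startswith("observation"):
            cfg.input_shapes[key] = list(feature["shape"])   # e.g. [3, 96, 96]
        elif key == "action":
            cfg.output_shapes[key] = list(feature["shape"])  # e.g. [2] for PushT
```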
1 change: 1 addition & 0 deletions examples/7_get_started_with_real_robot.md
@@ -946,6 +946,7 @@ fps = 30
device = "cuda" # TODO: On Mac, use "mps" or "cpu"

ckpt_path = "outputs/train/act_koch_test/checkpoints/last/pretrained_model"
# TODO(alibert, rcadene): fix this file
policy = ACTPolicy.from_pretrained(ckpt_path)
policy.to(device)

1 change: 1 addition & 0 deletions examples/advanced/2_calculate_validation_loss.py
@@ -24,6 +24,7 @@
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")

# TODO(alibert, rcadene): fix this file
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
policy.eval()
policy.to(device)
2 changes: 1 addition & 1 deletion examples/port_datasets/pusht_zarr.py
@@ -44,7 +44,7 @@
"dtype": None,
"shape": (3, 96, 96),
"names": [
"channel",
"channels",
"height",
"width",
],
19 changes: 0 additions & 19 deletions lerobot/__init__.py
@@ -58,7 +58,6 @@
],
"pusht": ["PushT-v0"],
"xarm": ["XarmLift-v0"],
"dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
}
available_envs = list(available_tasks_per_env.keys())

@@ -86,23 +85,6 @@
"lerobot/xarm_push_medium_image",
"lerobot/xarm_push_medium_replay_image",
],
"dora_aloha_real": [
"lerobot/aloha_static_battery",
"lerobot/aloha_static_candy",
"lerobot/aloha_static_coffee",
"lerobot/aloha_static_coffee_new",
"lerobot/aloha_static_cups_open",
"lerobot/aloha_static_fork_pick_up",
"lerobot/aloha_static_pingpong_test",
"lerobot/aloha_static_pro_pencil",
"lerobot/aloha_static_screw_driver",
"lerobot/aloha_static_tape",
"lerobot/aloha_static_thread_velcro",
"lerobot/aloha_static_towel",
"lerobot/aloha_static_vinh_cup",
"lerobot/aloha_static_vinh_cup_left",
"lerobot/aloha_static_ziploc_slide",
],
}

available_real_world_datasets = [
@@ -221,7 +203,6 @@
"xarm": ["tdmpc"],
"koch_real": ["act_koch_real"],
"aloha_real": ["act_aloha_real"],
"dora_aloha_real": ["act_aloha_real"],
}

env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
25 changes: 0 additions & 25 deletions lerobot/common/datasets/factory.py
@@ -70,31 +70,11 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
Returns:
The LeRobotDataset.
"""
# A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
if cfg.env.type != "dora":
if isinstance(cfg.dataset.repo_id, str):
dataset_repo_ids = [cfg.dataset.repo_id] # single dataset
elif isinstance(cfg.dataset.repo_id, list):
dataset_repo_ids = cfg.dataset.repo_id # multiple datasets
else:
raise ValueError(
"Expected cfg.dataset.repo_id to be either a single string to load one dataset or a list of "
"strings to load multiple datasets."
)

for dataset_repo_id in dataset_repo_ids:
if cfg.env.type not in dataset_repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
f"environment ({cfg.env.type=})."
)

image_transforms = (
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
)

if isinstance(cfg.dataset.repo_id, str):
# TODO (aliberts): add 'episodes' arg from config after removing hydra
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
dataset = LeRobotDataset(
@@ -122,10 +102,5 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
for key in dataset.meta.camera_keys:
for stats_type, stats in IMAGENET_STATS.items():
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
# for key, stats_dict in cfg.override_dataset_stats.items():
# for stats_type, listconfig in stats_dict.items():
# # example of stats_type: min, max, mean, std
# stats = OmegaConf.to_container(listconfig, resolve=True)
# dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)

return dataset
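
The loop kept at the end of `make_dataset` swaps each camera's dataset statistics for the standard ImageNet normalization constants. A self-contained sketch of that idea (the exact constants in the repo are assumed to be the usual ImageNet values):

```python
import torch

# Standard ImageNet normalization constants, shaped (c, 1, 1) so they
# broadcast over (c, h, w) image tensors.
IMAGENET_STATS = {
    "mean": [[[0.485]], [[0.456]], [[0.406]]],
    "std": [[[0.229]], [[0.224]], [[0.225]]],
}

stats = {"observation.image": {}}
for stats_type, values in IMAGENET_STATS.items():
    stats["observation.image"][stats_type] = torch.tensor(values, dtype=torch.float32)

img = torch.rand(3, 96, 96)
mean, std = stats["observation.image"]["mean"], stats["observation.image"]["std"]
normalized = (img - mean) / std
```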
132 changes: 34 additions & 98 deletions lerobot/common/datasets/transforms.py
@@ -66,16 +66,18 @@ def __init__(
self.n_subset = n_subset
self.random_order = random_order

self.selected_transforms = None

def forward(self, *inputs: Any) -> Any:
needs_unpacking = len(inputs) > 1

selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset)
if not self.random_order:
selected_indices = selected_indices.sort().values

selected_transforms = [self.transforms[i] for i in selected_indices]
self.selected_transforms = [self.transforms[i] for i in selected_indices]

for transform in selected_transforms:
for transform in self.selected_transforms:
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)

@@ -138,67 +140,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)


# TODO(aliberts): Remove
def get_image_transforms(
brightness_weight: float = 1.0,
brightness_min_max: tuple[float, float] | None = None,
contrast_weight: float = 1.0,
contrast_min_max: tuple[float, float] | None = None,
saturation_weight: float = 1.0,
saturation_min_max: tuple[float, float] | None = None,
hue_weight: float = 1.0,
hue_min_max: tuple[float, float] | None = None,
sharpness_weight: float = 1.0,
sharpness_min_max: tuple[float, float] | None = None,
max_num_transforms: int | None = None,
random_order: bool = False,
):
def check_value(name, weight, min_max):
if min_max is not None:
if len(min_max) != 2:
raise ValueError(
f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided."
)
if weight < 0.0:
raise ValueError(
f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})."
)

check_value("brightness", brightness_weight, brightness_min_max)
check_value("contrast", contrast_weight, contrast_min_max)
check_value("saturation", saturation_weight, saturation_min_max)
check_value("hue", hue_weight, hue_min_max)
check_value("sharpness", sharpness_weight, sharpness_min_max)

weights = []
transforms = []
if brightness_min_max is not None and brightness_weight > 0.0:
weights.append(brightness_weight)
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
if contrast_min_max is not None and contrast_weight > 0.0:
weights.append(contrast_weight)
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
if saturation_min_max is not None and saturation_weight > 0.0:
weights.append(saturation_weight)
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
if hue_min_max is not None and hue_weight > 0.0:
weights.append(hue_weight)
transforms.append(v2.ColorJitter(hue=hue_min_max))
if sharpness_min_max is not None and sharpness_weight > 0.0:
weights.append(sharpness_weight)
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))

n_subset = len(transforms)
if max_num_transforms is not None:
n_subset = min(n_subset, max_num_transforms)

if n_subset == 0:
return v2.Identity()
else:
# TODO(rcadene, aliberts): add v2.ToDtype float16?
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)


@dataclass
class ImageTransformConfig:
"""
@@ -234,79 +175,74 @@ class ImageTransformsConfig:
# By default, transforms are applied in Torchvision's suggested order (shown below).
# Set this to True to apply them in a random order.
random_order: bool = False
tfs: list[ImageTransformConfig] = field(
default_factory=lambda: [
ImageTransformConfig(
tfs: dict[str, ImageTransformConfig] = field(
default_factory=lambda: {
"brightness": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"brightness": (0.8, 1.2)},
),
ImageTransformConfig(
"contrast": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"contrast": (0.8, 1.2)},
),
ImageTransformConfig(
"saturation": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"saturation": (0.5, 1.5)},
),
ImageTransformConfig(
"hue": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"hue": (-0.05, 0.05)},
),
ImageTransformConfig(
"sharpness": ImageTransformConfig(
weight=1.0,
type="SharpnessJitter",
kwargs={"sharpness": (0.5, 1.5)},
),
]
}
)
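
Keying `tfs` by name rather than by list position lets callers address individual transforms directly; a sketch of the kind of override this enables (assuming the dataclass defaults shown above):

```python
# Assumes ImageTransformsConfig as defined above is in scope.
cfg = ImageTransformsConfig(enable=True)

# Disable one transform by name and retune another; no list indices involved.
cfg.tfs["sharpness"].weight = 0.0
cfg.tfs["brightness"].kwargs = {"brightness": (0.5, 1.5)}
```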


def make_transform_from_config(cfg: ImageTransformConfig):
if cfg.type == "Identity":
return v2.Identity(**cfg.kwargs)
elif cfg.type == "ColorJitter":
return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
else:
raise ValueError(f"Transform '{cfg.type}' is not valid.")


Review comment (Collaborator), on lines +209 to +219: "Reminder: Move this in tests/ as it's only used for that purpose"

class ImageTransforms(Transform):
"""A class to compose image transforms based on configuration."""

_registry = {
"Identity": v2.Identity,
"ColorJitter": v2.ColorJitter,
"SharpnessJitter": SharpnessJitter,
}

def __init__(self, cfg: ImageTransformsConfig) -> None:
super().__init__()
self._cfg = cfg

weights = []
transforms = []
for tf_cfg in cfg.tfs:
self.weights = []
self.transforms = {}
for tf_name, tf_cfg in cfg.tfs.items():
if tf_cfg.weight <= 0.0:
continue

transform_cls = self._registry.get(tf_cfg.type)
if transform_cls is None:
available_transforms = ", ".join(self._registry.keys())
raise ValueError(
f"Transform '{tf_cfg.type}' not found in the registry. "
f"Available transforms are: {available_transforms}"
)

# Instantiate the transform
transform_instance = transform_cls(**tf_cfg.kwargs)
transforms.append(transform_instance)
weights.append(tf_cfg.weight)
self.transforms[tf_name] = make_transform_from_config(tf_cfg)
self.weights.append(tf_cfg.weight)

n_subset = min(len(transforms), cfg.max_num_transforms)
n_subset = min(len(self.transforms), cfg.max_num_transforms)
if n_subset == 0 or not cfg.enable:
self.transform = v2.Identity()
self.tf = v2.Identity()
else:
self.transform = RandomSubsetApply(
transforms=transforms,
p=weights,
self.tf = RandomSubsetApply(
transforms=list(self.transforms.values()),
p=self.weights,
n_subset=n_subset,
random_order=cfg.random_order,
)

def forward(self, *inputs: Any) -> Any:
return self.transform(*inputs)
return self.tf(*inputs)
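
Putting it together, the refactored `ImageTransforms` composes like this (a usage sketch, assuming the config defaults above):

```python
import torch

# Assumes ImageTransformsConfig and ImageTransforms as defined above.
cfg = ImageTransformsConfig(enable=True, max_num_transforms=3)
img_tf = ImageTransforms(cfg)

frame = torch.rand(3, 480, 640)
augmented = img_tf(frame)

# The named sub-transforms remain inspectable after construction.
print(list(img_tf.transforms))  # e.g. ['brightness', 'contrast', ...]
```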
2 changes: 1 addition & 1 deletion lerobot/common/datasets/utils.py
@@ -449,7 +449,7 @@ def check_delta_timestamps(
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
delta_indices = {}
for key, delta_ts in delta_timestamps.items():
delta_indices[key] = (torch.tensor(delta_ts) * fps).long().tolist()
delta_indices[key] = [round(d * fps) for d in delta_ts]

return delta_indices

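
The switch from tensor truncation to Python's `round` matters because `.long()` truncates toward zero, while a delta that does not land exactly on a frame boundary should map to the nearest frame index (this is presumably the bug referenced in commit 0ddd9b7). A sketch of the behavioral difference; the -0.05/0.05 deltas are illustrative values, not taken from the PR:

```python
import torch

fps = 30
delta_ts = [-0.05, 0.0, 0.05]  # 1.5 frames on either side of the current frame

old = (torch.tensor(delta_ts) * fps).long().tolist()  # truncates toward zero
new = [round(d * fps) for d in delta_ts]               # rounds to nearest frame

print(old)  # [-1, 0, 1]
print(new)  # [-2, 0, 2]
```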
8 changes: 4 additions & 4 deletions lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py
@@ -26,13 +26,13 @@
from textwrap import dedent

from lerobot import available_datasets
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset, parse_robot_config
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset
from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig

LOCAL_DIR = Path("data/")

ALOHA_CONFIG = Path("lerobot/configs/robot/aloha.yaml")
ALOHA_MOBILE_INFO = {
"robot_config": parse_robot_config(ALOHA_CONFIG),
"robot_config": AlohaRobotConfig(),
Review comment (Collaborator), on the line above: "We could add a use_shadow_motors option here"

"license": "mit",
"url": "https://mobile-aloha.github.io/",
"paper": "https://arxiv.org/abs/2401.02117",
@@ -45,7 +45,7 @@
}""").lstrip(),
}
ALOHA_STATIC_INFO = {
"robot_config": parse_robot_config(ALOHA_CONFIG),
"robot_config": AlohaRobotConfig(),
"license": "mit",
"url": "https://tonyzhaozh.github.io/aloha/",
"paper": "https://arxiv.org/abs/2304.13705",
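
This substitution is the heart of the PR: robot devices are now described by Python dataclasses parsed by draccus instead of YAML files, so the converter can instantiate `AlohaRobotConfig()` directly. A rough sketch of the shape such a dataclass might take; the field names here are illustrative assumptions, not the real class:

```python
from dataclasses import dataclass

# Hypothetical sketch; the real class lives in
# lerobot/common/robot_devices/robots/configs.py.
@dataclass
class AlohaRobotConfigSketch:
    robot_type: str = "aloha"
    # Nested dataclasses for motor buses and cameras would go here; draccus
    # can then build or override the whole tree from CLI arguments.
    max_relative_target: int | None = 5
```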