[Visualizer] support v2 dataset #485

292 changes: 219 additions & 73 deletions lerobot/scripts/visualize_dataset_html.py
@@ -53,20 +53,27 @@
"""

import argparse
import csv
import logging
import shutil
from io import StringIO
from pathlib import Path
import re
import tempfile
import os

import tqdm
from flask import Flask, redirect, render_template, url_for
import numpy as np
from flask import Flask, redirect, render_template, url_for, request
from datasets import load_dataset

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.utils.utils import init_logging
from lerobot import available_datasets


def run_server(
dataset: LeRobotDataset,
episodes: list[int],
dataset: LeRobotDataset | dict | None,
episodes: list[int] | None,
host: str,
port: str,
static_folder: Path,
@@ -76,10 +83,54 @@ def run_server(
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache

@app.route("/")
def index():
# home page redirects to the first episode page
[dataset_namespace, dataset_name] = dataset.repo_id.split("/")
first_episode_id = episodes[0]
def homepage(dataset=dataset):
if dataset:
dataset_namespace, dataset_name = (
dataset.repo_id if isinstance(dataset, LeRobotDataset) else dataset["repo_id"]
).split("/")
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=0,
)
)

dataset_param, episode_param, time_param = None, None, None
all_params = request.args
if "dataset" in all_params:
dataset_param = all_params["dataset"]
if "episode" in all_params:
episode_param = int(all_params["episode"])
if "t" in all_params:
time_param = all_params["t"]

if dataset_param:
dataset_namespace, dataset_name = dataset_param.split("/")
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=episode_param if episode_param is not None else 0,
)
)

featured_datasets = [
"lerobot/aloha_static_cups_open",
"lerobot/columbia_cairlab_pusht_real",
"lerobot/taco_play",
]
return render_template(
"visualize_dataset_homepage.html",
featured_datasets=featured_datasets,
lerobot_datasets=available_datasets,
)

@app.route("/<string:dataset_namespace>/<string:dataset_name>")
def show_first_episode(dataset_namespace, dataset_name):
first_episode_id = 0
return redirect(
url_for(
"show_episode",
@@ -90,30 +141,78 @@ def index():
)

@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
def show_episode(dataset_namespace, dataset_name, episode_id):
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
repo_id = f"{dataset_namespace}/{dataset_name}"
try:
if dataset is None:
dataset = get_dataset_info(repo_id)
except FileNotFoundError:
return "Make sure your convert your LeRobotDataset to v2 & above."
dataset_version = (
dataset._version if isinstance(dataset, LeRobotDataset) else dataset["codebase_version"]
)
match = re.search(r"v(\d+)\.", dataset_version)
if match:
major_version = int(match.group(1))
if major_version < 2:
return "Make sure your convert your LeRobotDataset to v2 & above."

episode_data_csv_str = get_episode_data_csv_str(dataset, episode_id)
dataset_info = {
"repo_id": dataset.repo_id,
"num_samples": dataset.num_samples,
"num_episodes": dataset.num_episodes,
"fps": dataset.fps,
"repo_id": f"{dataset_namespace}/{dataset_name}",
"num_samples": dataset.num_samples
if isinstance(dataset, LeRobotDataset)
else dataset["total_frames"],
"num_episodes": dataset.num_episodes
if isinstance(dataset, LeRobotDataset)
else dataset["total_episodes"],
"fps": dataset.fps if isinstance(dataset, LeRobotDataset) else dataset["fps"],
}
video_paths = [dataset.get_video_file_path(episode_id, key) for key in dataset.video_keys]
tasks = dataset.episode_dicts[episode_id]["tasks"]
videos_info = [
{"url": url_for("static", filename=video_path), "filename": video_path.name}
for video_path in video_paths
]
if isinstance(dataset, LeRobotDataset):
video_paths = [dataset.get_video_file_path(episode_id, key) for key in dataset.video_keys]
videos_info = [
{"url": url_for("static", filename=video_path), "filename": video_path.name}
for video_path in video_paths
]
tasks = dataset.episode_dicts[episode_id]["tasks"]
else:
videos_info = [
{
"url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+ dataset["videos"]["videos_path"].format(
episode_chunk=int(episode_id) // dataset["chunks_size"],
video_key=video_key,
episode_index=episode_id,
),
"filename": video_key,
}
for video_key in dataset["video_keys"]
]
tasks_jsonl = load_dataset(
"json",
data_files=f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl",
split="train",
)
filtered_tasks_jsonl = tasks_jsonl.filter(lambda x: x["episode_index"] == episode_id)
tasks = filtered_tasks_jsonl["tasks"][0]

videos_info[0]["language_instruction"] = tasks

ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
if episodes is None:
episodes = list(
range(
dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset["total_episodes"]
)
)

return render_template(
"visualize_dataset_template.html",
episode_id=episode_id,
episodes=episodes,
dataset_info=dataset_info,
videos_info=videos_info,
ep_csv_url=ep_csv_url,
has_policy=False,
episode_data_csv_str=episode_data_csv_str,
)

app.run(host=host, port=port)
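For reference, a rough sketch (not part of the diff) of the URL contract implemented by the routes above; the host, port, repo id, and episode index are illustrative.

```python
# Illustrative only: how the new routes resolve, assuming the server runs on
# 127.0.0.1:9090 (host, port, repo id and episode index are made-up examples).
from urllib.parse import urlencode

base = "http://127.0.0.1:9090"

# Homepage started without a preloaded dataset: the repo id and episode can be
# passed as query parameters and the homepage redirects to the episode page.
print(f"{base}/?{urlencode({'dataset': 'lerobot/pusht', 'episode': 2})}")
# -> redirects to f"{base}/lerobot/pusht/episode_2"

print(f"{base}/lerobot/pusht")            # show_first_episode -> episode_0
print(f"{base}/lerobot/pusht/episode_2")  # show_episode renders the visualizer page
```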
@@ -124,46 +223,70 @@ def get_ep_csv_fname(episode_id: int):
return ep_csv_fname


def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
def get_episode_data_csv_str(dataset: LeRobotDataset | dict, episode_index):
"""Get a csv str containing timeseries data of an episode (e.g. state and action).
This file will be loaded by Dygraph javascript to plot data in real time."""
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]

has_state = "observation.state" in dataset.hf_dataset.features
has_action = "action" in dataset.hf_dataset.features
has_state = "observation.state" in (
dataset.hf_dataset.features if isinstance(dataset, LeRobotDataset) else dataset["keys"]
)
has_action = "action" in (
dataset.hf_dataset.features if isinstance(dataset, LeRobotDataset) else dataset["keys"]
)

# init header of csv with state and action names
header = ["timestamp"]
if has_state:
dim_state = dataset.shapes["observation.state"]
dim_state = (dataset.shapes if isinstance(dataset, LeRobotDataset) else dataset["shapes"])[
"observation.state"
]
header += [f"state_{i}" for i in range(dim_state)]
if has_action:
dim_action = dataset.shapes["action"]
dim_action = (dataset.shapes if isinstance(dataset, LeRobotDataset) else dataset["shapes"])["action"]
header += [f"action_{i}" for i in range(dim_action)]

columns = ["timestamp"]
if has_state:
columns += ["observation.state"]
if has_action:
columns += ["action"]

rows = []
data = dataset.hf_dataset.select_columns(columns)
for i in range(from_idx, to_idx):
row = [data[i]["timestamp"].item()]
if isinstance(dataset, LeRobotDataset):
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
columns = ["timestamp"]
if has_state:
row += data[i]["observation.state"].tolist()
columns += ["observation.state"]
if has_action:
row += data[i]["action"].tolist()
rows.append(row)
columns += ["action"]
data = dataset.hf_dataset.select(range(from_idx, to_idx)).select_columns(columns).with_format("numpy")
rows = np.hstack(
(np.expand_dims(data["timestamp"], axis=1), *[data[col] for col in columns[1:]])
).tolist()
else:
repo_id = dataset["repo_id"]
columns = ["timestamp"]
if "observation.state" in dataset["keys"]:
columns.append("observation.state")
if "action" in dataset["keys"]:
columns.append("action")
episode_parquet = load_dataset(
"parquet",
data_files=f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+ dataset["data_path"].format(
episode_chunk=int(episode_index) // dataset["chunks_size"], episode_index=episode_index
),
split="train",
)
d = episode_parquet.select_columns(columns).with_format("numpy")
data = d.to_pandas()
rows = np.hstack(
(np.expand_dims(data["timestamp"], axis=1), *[np.vstack(data[col]) for col in columns[1:]])
).tolist()

output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / file_name, "w") as f:
f.write(",".join(header) + "\n")
for row in rows:
row_str = [str(col) for col in row]
f.write(",".join(row_str) + "\n")
# Convert data to CSV string
csv_buffer = StringIO()
csv_writer = csv.writer(csv_buffer)
# Write header
csv_writer.writerow(header)
# Write data rows
csv_writer.writerows(rows)
csv_string = csv_buffer.getvalue()

return csv_string
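As a rough illustration (not from the diff), the string returned above for a hypothetical dataset with a 2-dimensional state and action would look like the following; all numbers are invented.

```python
# Illustrative only: shape of the CSV string produced by get_episode_data_csv_str for
# a hypothetical dataset with dim_state == 2 and dim_action == 2 (values are made up).
# csv.writer uses "\r\n" as its default line terminator.
example_csv_str = (
    "timestamp,state_0,state_1,action_0,action_1\r\n"
    "0.0,0.12,-0.03,0.1,-0.05\r\n"
    "0.1,0.14,-0.02,0.11,-0.04\r\n"
)
```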


def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
@@ -188,10 +311,21 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")


def get_dataset_info(repo_id: str) -> dict:
dataset_info = load_dataset(
"json",
data_files=f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json",
split="train",
)[0]
dataset_info["repo_id"] = repo_id
return dataset_info
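For orientation, a sketch of the meta/info.json fields that this script reads in hub mode; the field names are taken from the accesses in this diff, while the values and path templates are invented.

```python
# Illustrative only: subset of meta/info.json fields consumed by this script when
# --load-from-hf-hub is used (field names from the code above; values are made up).
example_dataset_info = {
    "repo_id": "lerobot/pusht",  # injected by get_dataset_info(), not stored in info.json
    "codebase_version": "v2.0",
    "fps": 10,
    "total_episodes": 50,
    "total_frames": 10_000,
    "keys": ["observation.state", "action"],
    "shapes": {"observation.state": 2, "action": 2},
    "video_keys": ["observation.image"],
    "image_keys": [],
    "chunks_size": 1000,
    "data_path": "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet",
    "videos": {"videos_path": "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"},
}
```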


def visualize_dataset_html(
repo_id: str,
repo_id: str | None = None,
root: Path | None = None,
episodes: list[int] = None,
load_from_hf_hub: bool = False,
episodes: list[int] | None = None,
output_dir: Path | None = None,
serve: bool = True,
host: str = "127.0.0.1",
@@ -200,13 +334,12 @@ def visualize_dataset_html(
) -> Path | None:
init_logging()

dataset = LeRobotDataset(repo_id, root=root)

if len(dataset.image_keys) > 0:
raise NotImplementedError(f"Image keys ({dataset.image_keys=}) are currently not supported.")
template_dir = Path(__file__).resolve().parent.parent / "templates"

if output_dir is None:
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
# Create a temporary base directory for the output (mkdtemp does not remove it automatically)
temp_base = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
output_dir = temp_base if not repo_id else os.path.join(temp_base, repo_id)

output_dir = Path(output_dir)
if output_dir.exists():
@@ -217,28 +350,35 @@

output_dir.mkdir(parents=True, exist_ok=True)

# Create a symlink from the dataset video folder containing mp4 files to the output directory
# so that the http server can get access to the mp4 files.
static_dir = output_dir / "static"
static_dir.mkdir(parents=True, exist_ok=True)
ln_videos_dir = static_dir / "videos"
if not ln_videos_dir.exists():
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())

template_dir = Path(__file__).resolve().parent.parent / "templates"
if not repo_id:
if serve:
run_server(
dataset=None,
episodes=None,
host=host,
port=port,
static_folder=static_dir,
template_folder=template_dir,
)
else:
dataset = LeRobotDataset(repo_id, root=root) if not load_from_hf_hub else get_dataset_info(repo_id)

if episodes is None:
episodes = list(range(dataset.num_episodes))
image_keys = dataset.image_keys if isinstance(dataset, LeRobotDataset) else dataset["image_keys"]
if len(image_keys) > 0:
raise NotImplementedError(f"Image keys ({image_keys=}) are currently not supported.")

logging.info("Writing CSV files")
for episode_index in tqdm.tqdm(episodes):
# write states and actions in a csv (it can be slow for big datasets)
ep_csv_fname = get_ep_csv_fname(episode_index)
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
# Create a symlink from the dataset video folder containing mp4 files to the output directory
# so that the http server can get access to the mp4 files.
if isinstance(dataset, LeRobotDataset):
ln_videos_dir = static_dir / "videos"
if not ln_videos_dir.exists():
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())

if serve:
run_server(dataset, episodes, host, port, static_dir, template_dir)
if serve:
run_server(dataset, episodes, host, port, static_dir, template_dir)


def main():
@@ -247,7 +387,7 @@ def main():
parser.add_argument(
"--repo-id",
type=str,
required=True,
default=None,
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
)
parser.add_argument(
@@ -256,6 +396,12 @@
default=None,
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
)
parser.add_argument(
"--load-from-hf-hub",
type=int,
default=0,
help="Load videos and parquet files from HF Hub rather than local system.",
)
parser.add_argument(
"--episodes",
type=int,
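A minimal usage sketch of the updated entry point, based on the function signature shown earlier in the diff; the repo id is illustrative and the second call assumes the dataset already follows the v2 layout on the Hub.

```python
# Illustrative usage (not part of the diff); parameter names come from the
# visualize_dataset_html signature shown above.
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html

# 1) No repo_id: serve the dataset-selection homepage (blocks while Flask runs).
visualize_dataset_html(repo_id=None, serve=True)

# 2) With load_from_hf_hub=True: read meta/parquet/videos from the Hub instead of
#    instantiating a local LeRobotDataset.
# visualize_dataset_html(repo_id="lerobot/pusht", load_from_hf_hub=True)
```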