An unofficial PyTorch dataloader for the Open-X Embodiment Datasets.
This README will guide you through integrating the Open-X Embodiment Datasets into your PyTorch project. For a native TensorFlow experience, please check the official repo.
- Check available datasets and their corresponding metadata in the dataset spreadsheet.
- **Warning**: the images in `utokyo_saytap_converted_externally_to_rlds` appear to be corrupted.
- Set your preferred download destination `download_dst` in `generate_download_script.py` and confirm the datasets you want to download. By default, the Python script will create a shell script that downloads all 53 datasets, amounting to a total size of approximately 4.5 TB.
- Follow this guide to set up `gsutil`.
- Generate the shell script and start the download:

  ```shell
  python3 generate_download_script.py
  chmod +x download.sh
  ./download.sh
  ```
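For orientation, the generated `download.sh` boils down to one `gsutil` copy command per dataset. The sketch below shows the general idea, not the actual contents of `generate_download_script.py`; the dataset names, versions, bucket path, and `download_dst` value are illustrative assumptions:

```python
# Sketch: assemble a shell script with one `gsutil` copy command per
# dataset. Dataset names/versions and the destination are examples only.
datasets = {
    'asu_table_top_converted_externally_to_rlds': '0.1.0',
    'berkeley_cable_routing': '0.1.0',
}
download_dst = 'datasets'  # your preferred download destination

lines = ['#!/bin/bash', f'mkdir -p {download_dst}']
for name, version in datasets.items():
    src = f'gs://gresearch/robotics/{name}/{version}'
    lines.append(f'gsutil -m cp -r {src} {download_dst}/{name}')

script = '\n'.join(lines) + '\n'
print(script)  # write this out as download.sh, then chmod +x
```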
This section was last updated on 1/19/2024.
- Install Python dependencies:

  ```shell
  pip3 install -r requirements.txt
  ```
- If your machine has enough RAM to hold the whole dataset, you can initialize the dataset with `class OpenXDataset(Dataset)` in `open_x_dataset_pytorch.py`. A quick example:

  ```python
  d = OpenXDataset(
      tf_dir='datasets/asu_table_top_converted_externally_to_rlds/0.1.0/',
      fetch_pattern=r'.*image.*',
      sample_length=8,
  )
  print(d)
  ```
  - `tf_dir`: full directory containing the downloaded dataset, including the version number.
  - `fetch_pattern`: regular expression used to specify the data you wish to retrieve. Defaults to `r'steps*'`. The example above only retrieves visual observations.
  - `sample_length`: number of transitions per sample. If set to `2`, the returned sample will be $[s_1, s_2]$.
  The last several lines of the output of the code above:

  ```
  ==========
  Total episodes: 110
  Total samples: 1433503
  ==========
  Output keys:
  - steps/observation/image
  Masked keys:
  - steps/observation/state_vel
  - steps/ground_truth_states/bread
  - steps/is_first
  - steps/ground_truth_states/coke
  - steps/ground_truth_states/cube
  - steps/language_embedding
  - steps/is_terminal
  - steps/is_last
  - steps/discount
  - steps/ground_truth_states/EE
  - steps/language_instruction
  - steps/ground_truth_states/pepsi
  - steps/ground_truth_states/milk
  - steps/observation/state
  - steps/goal_object
  - steps/action
  - episode_metadata/file_path
  - steps/ground_truth_states/bottle
  - steps/action_delta
  - steps/action_inst
  - steps/reward
  ```
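  To preview which keys a given `fetch_pattern` would select, you can test the regular expression against the flattened key names yourself. The key names below are taken from the example output above; matching via `re.search` is an assumption about how the loader applies the pattern:

  ```python
  import re

  # A few flattened feature keys from the example dataset.
  keys = [
      'steps/observation/image',
      'steps/observation/state',
      'steps/action',
      'episode_metadata/file_path',
  ]

  def select(pattern, keys):
      """Return the keys matched by the regex (assuming re.search semantics)."""
      return [k for k in keys if re.search(pattern, k)]

  print(select(r'.*image.*', keys))  # only the visual observation
  print(select(r'steps*', keys))     # the default pattern: all 'steps/...' keys
  ```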
  `__getitem__()` returns a dictionary where the keys correspond to `fetch_pattern`. The associated value for each key will be either a tensor of size `(sample_length, *original feature shape)`[^1] or a list with `sample_length` elements.
- If the machine does not have enough RAM, use `class IterableOpenXDataset(IterableDataset)` in `open_x_dataset_pytorch.py` instead. It takes the same input parameters as the class above, though it does not maintain the total number of samples.

  ```python
  d = IterableOpenXDataset(
      tf_dir='datasets/asu_table_top_converted_externally_to_rlds/0.1.0/',
      fetch_pattern=r'.*image.*',
      sample_length=8,
  )
  print(d)
  ```
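  Since an iterable dataset exposes no `__len__`, a `DataLoader` over it cannot shuffle and simply iterates until the stream is exhausted. A self-contained sketch with a stand-in class (the `FakeIterableOpenX` name, sample counts, and tensor shapes are illustrative):

  ```python
  import torch
  from torch.utils.data import IterableDataset, DataLoader

  class FakeIterableOpenX(IterableDataset):
      """Stand-in that streams dict samples like IterableOpenXDataset."""
      def __init__(self, n_samples=10, sample_length=8):
          self.n_samples = n_samples
          self.sample_length = sample_length

      def __iter__(self):
          for _ in range(self.n_samples):
              yield {'steps/observation/image':
                     torch.zeros(self.sample_length, 3, 64, 64)}

  # No __len__, so shuffle=True is unavailable; iterate until exhaustion.
  loader = DataLoader(FakeIterableOpenX(), batch_size=4)
  n_batches = sum(1 for _ in loader)
  print(n_batches)  # 10 samples / batch_size 4 -> 3 batches (last is partial)
  ```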
- Filter out invalid episodes according to the dataset format.
I really appreciate the substantial open-sourcing effort contributed by the creators of this extensive dataset. Thanks to Jinghuan Shang for valuable discussions.
[^1]: When the feature is an image, the tensor will have a shape of `(sample_length, C, H, W)` instead.