Implement Real-Time Action Chunking (RTC) for SmolVLA #1521

Draft: wants to merge 64 commits into base: main
Commits (64)
416a8b8
Merge together proto files and refactor Async inference
helper2424 Jul 10, 2025
09c7f34
Fixup for Async inference
helper2424 Jul 10, 2025
e50a2fc
Drop not reuqired changes
helper2424 Jul 10, 2025
10c688f
Fix tests
helper2424 Jul 10, 2025
e060896
Drop old async files
helper2424 Jul 10, 2025
6b6727f
Merge branch 'main' into user/helper2424/updated_merge_proto
michel-aractingi Jul 11, 2025
58a82d3
Drop chunk_size param
helper2424 Jul 15, 2025
2c8c73b
Merge branch 'main' of https://github.com/huggingface/lerobot into us…
helper2424 Jul 15, 2025
ef232a6
Merge branch 'main' of https://github.com/huggingface/lerobot into us…
helper2424 Jul 16, 2025
baa0fcd
Merge branch 'main' of https://github.com/huggingface/lerobot into us…
helper2424 Jul 18, 2025
60ea278
Fix versions
helper2424 Jul 18, 2025
77158e6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2025
53ba25b
Fix wrong fix
helper2424 Jul 19, 2025
c87951a
Refactor Async architecture
helper2424 Jul 11, 2025
7639101
Update client architecture
helper2424 Jul 11, 2025
a10dcb0
Fix sleeping logic
ben-z Jul 13, 2025
f325329
Add more types and debug runaway queue update
ben-z Jul 13, 2025
e0a6fb2
Debug robot client
ben-z Jul 13, 2025
2bf3252
Get observations in the policy client thread
ben-z Jul 13, 2025
b0f2082
Make the client logging more consistent
ben-z Jul 14, 2025
b43ccd7
Prepare for RTC implementation
ben-z Jul 16, 2025
5b95e6b
Implement RTC denoising with static d and s
ben-z Jul 17, 2025
acfa524
Remove debugging statements and move denoising to no_grad context to …
ben-z Jul 17, 2025
a65575c
Minor comment fix
ben-z Jul 17, 2025
c50bfb0
Hard-code RTC parameters for testing
ben-z Jul 17, 2025
f892e6f
Decrease logging in robot_client
ben-z Jul 17, 2025
b77fee0
Enable policy server to populate rtc_s and rtc_d
ben-z Jul 17, 2025
f4acc31
Optimize async client latency by using compression and tuning message…
ben-z Jul 17, 2025
eaf10c3
Improve latency by enabling compression in policy server
ben-z Jul 17, 2025
27f64da
Add option for model compilation to speed up inference
ben-z Jul 17, 2025
d5d4a7a
Reduce rtc_d now that we have lower latency
ben-z Jul 17, 2025
f6f1417
Tune end_s for chunk size 100
ben-z Jul 17, 2025
f215140
Set s_end to 75
ben-z Jul 18, 2025
7d033f1
Fix | None syntax
ben-z Jul 18, 2025
a21afa1
Add back chunk_size argument to send_bytes_in_chunks so different pro…
ben-z Jul 18, 2025
c8a888d
Make policy server work again after rebase
ben-z Jul 19, 2025
c500e92
Add inference_rtc_d and inference_rtc_soft_mask_length to SmolVLA par…
ben-z Jul 20, 2025
7c1e73e
Remove uneeded functionality
ben-z Jul 20, 2025
a5fea73
Rename steps_since_chunk_start to steps_since_last_chunk_start
ben-z Jul 20, 2025
5b160d0
Pass inference async stats to policies
ben-z Jul 20, 2025
8cd3060
Use async stats in smolvla
ben-z Jul 20, 2025
d9d374c
Fix policy server asyncstats import
ben-z Jul 20, 2025
4d97e2b
Fix smolvla emm
ben-z Jul 20, 2025
a79a96f
Use configurable soft mask length
ben-z Jul 20, 2025
f727ebb
Update log to show softmask length
ben-z Jul 20, 2025
ec78758
Use soft_mask instead of softmask in logging
ben-z Jul 20, 2025
49a05c9
Log soft mask length properly
ben-z Jul 20, 2025
9a055f1
Refactor smolvla to use the notation in the paper (t,d,s)
ben-z Jul 20, 2025
0b378e4
Fix bad merge
ben-z Jul 20, 2025
5c128ae
Use quotes to work around type issue
ben-z Jul 20, 2025
dca5ecd
Make all policies compatible with the new predict_action_chunk and se…
ben-z Jul 20, 2025
524d52b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 20, 2025
5fd21bc
Add docs for filter_args_recursive
ben-z Jul 20, 2025
3497da2
Rename s_end to s
ben-z Jul 20, 2025
2b930fd
Fix variable naming conflict
ben-z Jul 20, 2025
5fa967b
Fix softmask logic and use policy parameters properly
ben-z Jul 20, 2025
d1f66e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 20, 2025
5b0ec8a
Fix log prefix for policy server receive_bytes_in_chunks
ben-z Jul 20, 2025
c81a65e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 20, 2025
c06a46d
Add back warning for A_tau_d_err too high
ben-z Jul 20, 2025
ed5bd24
Add inference_rtc_debug flag for debug printing
ben-z Jul 20, 2025
1ca8fbe
Address ruff errors
ben-z Jul 20, 2025
ddbcc49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 20, 2025
cf28ede
Rename logged variable to reduce confusion
ben-z Jul 20, 2025
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -95,7 +95,7 @@ dependencies = [
pygame-dep = ["pygame>=2.5.1"]
placo-dep = ["placo>=0.9.6"]
transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency
grpcio-dep = ["grpcio==1.71.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]

# Motors
feetech = ["feetech-servo-sdk>=1.0.0"]
@@ -119,14 +119,14 @@ intelrealsense = [
# Policies
pi0 = ["lerobot[transformers-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "protobuf>=5.29.3", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]

# Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"]

# Development
docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"]
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "grpcio-tools==1.71.0"]
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"]
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]

76 changes: 57 additions & 19 deletions src/lerobot/configs/parser.py
@@ -19,11 +19,14 @@
from collections.abc import Sequence
from functools import wraps
from pathlib import Path
from typing import TypeVar

import draccus

from lerobot.utils.utils import has_method

T = TypeVar("T")

PATH_KEY = "path"
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"

@@ -151,6 +154,32 @@ def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]


def filter_args_recursive(field_name: str, args: Sequence[str] | None = None) -> tuple[list[str], list[str]]:
"""
Filters arguments for a given field and all its subfields.

Args:
field_name (str): The name of the field to filter arguments for.
args (Sequence[str] | None): The sequence of command-line arguments to be filtered.
Defaults to None.

Returns:
tuple[list[str], list[str]]: A tuple containing two lists:
- The first list contains arguments that start with the field name or subfield name.
- The second list contains arguments that do not start with the field name or subfield name.
"""
with_field = []
without_field = []

for arg in args:
if arg.startswith(f"--{field_name}.") or arg.startswith(f"--{field_name}="):
with_field.append(arg)
else:
without_field.append(arg)

return with_field, without_field


def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
"""
Filters command-line arguments related to fields with specific path arguments.
@@ -184,7 +213,11 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
return filtered_args


def wrap(config_path: Path | None = None):
def parse(
config_class: type[T],
config_path: Path | str | None = None,
args: Sequence[str] | None = None,
) -> T:
"""
HACK: Similar to draccus.wrap but does three additional things:
- Will remove '.path' arguments from CLI in order to process them later on.
@@ -194,7 +227,29 @@ def wrap(config_path: Path | None = None):
their own subclasses of config classes, so that draccus can find the right class to instantiate
from the CLI '.type' arguments
"""
cli_args = args or sys.argv[1:]
plugin_args = parse_plugin_args(PLUGIN_DISCOVERY_SUFFIX, cli_args)
for plugin_cli_arg, plugin_path in plugin_args.items():
try:
load_plugin(plugin_path)
except PluginLoadError as e:
# add the relevant CLI arg to the error message
raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
cli_args = filter_arg(plugin_cli_arg, cli_args)
config_path_cli = parse_arg("config_path", cli_args)
if has_method(config_class, "__get_path_fields__"):
path_fields = config_class.__get_path_fields__()
cli_args = filter_path_args(path_fields, cli_args)
if has_method(config_class, "from_pretrained") and config_path_cli:
cli_args = filter_arg("config_path", cli_args)
cfg = config_class.from_pretrained(config_path_cli, cli_args=cli_args)
else:
cfg = draccus.parse(config_class=config_class, config_path=config_path, args=cli_args)

return cfg


def wrap(config_path: Path | None = None):
def wrapper_outer(fn):
@wraps(fn)
def wrapper_inner(*args, **kwargs):
@@ -204,24 +259,7 @@ def wrapper_inner(*args, **kwargs):
cfg = args[0]
args = args[1:]
else:
cli_args = sys.argv[1:]
plugin_args = parse_plugin_args(PLUGIN_DISCOVERY_SUFFIX, cli_args)
for plugin_cli_arg, plugin_path in plugin_args.items():
try:
load_plugin(plugin_path)
except PluginLoadError as e:
# add the relevant CLI arg to the error message
raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
cli_args = filter_arg(plugin_cli_arg, cli_args)
config_path_cli = parse_arg("config_path", cli_args)
if has_method(argtype, "__get_path_fields__"):
path_fields = argtype.__get_path_fields__()
cli_args = filter_path_args(path_fields, cli_args)
if has_method(argtype, "from_pretrained") and config_path_cli:
cli_args = filter_arg("config_path", cli_args)
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
else:
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
cfg = parse(config_class=argtype, config_path=config_path)
response = fn(cfg, *args, **kwargs)
return response

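Note (illustrative): the refactor extracts the body of the `@wrap()` decorator into a standalone `parse()` helper and adds `filter_args_recursive()` for splitting CLI args by field. A minimal usage sketch, with hypothetical argument values, might look like this:

```python
# Sketch only; argument values and SomeConfigClass are hypothetical, not taken from this diff.
from lerobot.configs.parser import filter_args_recursive, parse

# Split CLI args into those addressed to `policy` (including subfields) and the rest.
args = ["--policy.type=smolvla", "--policy.inference_enable_rtc=true", "--fps=30"]
policy_args, other_args = filter_args_recursive("policy", args)
# policy_args == ["--policy.type=smolvla", "--policy.inference_enable_rtc=true"]
# other_args  == ["--fps=30"]

# parse() is the programmatic counterpart of the @wrap() decorator: it loads any
# CLI-declared plugins, handles `.path`/`config_path` arguments, then defers to
# draccus.parse to build the config instance.
# cfg = parse(SomeConfigClass, args=other_args)
```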
8 changes: 8 additions & 0 deletions src/lerobot/configs/types.py
@@ -40,3 +40,11 @@ def __getitem__(self, key: Any) -> Any: ...
class PolicyFeature:
type: FeatureType
shape: tuple


@dataclass
class AsyncStats:
# the number of ticks executed since the beginning of the last action chunk
steps_since_last_chunk_start: int
# round-trip inference latency in ticks.
inference_latency_steps: int
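Note (illustrative): `AsyncStats` packages the timing information measured by the async stack so the policy can reason about inference delay. A minimal construction sketch, assuming the server forwards it to the policy through the new keyword arguments (the keyword name below is an assumption):

```python
# Sketch only; the `async_stats` keyword name is an assumption based on the commit
# history ("Pass inference async stats to policies"), not read off this diff.
from lerobot.configs.types import AsyncStats

stats = AsyncStats(
    steps_since_last_chunk_start=12,  # ticks elapsed since the current chunk began executing
    inference_latency_steps=5,        # measured round-trip inference latency, in ticks
)

# The policy server would then pass these stats through to the policy, e.g.:
# action_chunk = policy.predict_action_chunk(batch, async_stats=stats)
```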
4 changes: 2 additions & 2 deletions src/lerobot/policies/act/modeling_act.py
@@ -108,7 +108,7 @@ def reset(self):
self._action_queue = deque([], maxlen=self.config.n_action_steps)

@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Select a single action given environment observations.

This method wraps `select_actions` in order to return one action at a time for execution in the
@@ -133,7 +133,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
return self._action_queue.popleft()

@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()

4 changes: 2 additions & 2 deletions src/lerobot/policies/diffusion/modeling_diffusion.py
@@ -100,7 +100,7 @@ def reset(self):
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)

@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Predict a chunk of actions given environment observations."""
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
@@ -112,7 +112,7 @@ def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
return actions

@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Select a single action given environment observations.

This method handles caching a history of observations and an action trajectory generated by the
4 changes: 2 additions & 2 deletions src/lerobot/policies/pi0/modeling_pi0.py
@@ -261,12 +261,12 @@ def get_optim_params(self) -> dict:
return self.parameters()

@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("Currently not implemented for PI0")

@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
"""Select a single action given environment observations.

This method wraps `select_actions` in order to return one action at a time for execution in the
4 changes: 2 additions & 2 deletions src/lerobot/policies/pi0fast/modeling_pi0fast.py
@@ -193,12 +193,12 @@ def _pi_aloha_encode_actions_inv(self, actions):
return actions

@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("Currently not implemented for PI0FAST")

@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Select a single action given environment observations.

This method wraps `select_actions` in order to return one action at a time for execution in the
4 changes: 2 additions & 2 deletions src/lerobot/policies/pretrained.py
@@ -172,7 +172,7 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
raise NotImplementedError

@abc.abstractmethod
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode.

Child classes using action chunking should use this method within `select_action` to form the action chunk
@@ -181,7 +181,7 @@ def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
raise NotImplementedError

@abc.abstractmethod
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Return one action to run in the environment (potentially in batch mode).

When the model uses a history of observations, or outputs a sequence of actions, this method deals
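Note (illustrative): widening every `select_action` / `predict_action_chunk` signature with `**kwargs` lets the async policy server pass extra inference-time context (such as `AsyncStats`) without breaking policies that ignore it. A self-contained sketch of the pattern, using a toy policy:

```python
# Toy policy, for illustration only; the `async_stats` keyword is an assumption.
import torch
from torch import Tensor


class ToyChunkingPolicy:
    def __init__(self, chunk_size: int = 50, action_dim: int = 6):
        self.chunk_size = chunk_size
        self.action_dim = action_dim

    @torch.no_grad()
    def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
        async_stats = kwargs.get("async_stats")  # None for synchronous callers
        if async_stats is not None:
            # A real policy could use async_stats.inference_latency_steps to pick
            # the RTC inference delay d; this toy policy ignores it.
            pass
        batch_size = next(iter(batch.values())).shape[0]
        # A real policy would run its model here; the toy returns zeros.
        return torch.zeros(batch_size, self.chunk_size, self.action_dim)
```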
4 changes: 2 additions & 2 deletions src/lerobot/policies/sac/modeling_sac.py
@@ -78,12 +78,12 @@ def reset(self):
pass

@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")

@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Select action for inference/evaluation"""

observations_features = None
4 changes: 2 additions & 2 deletions src/lerobot/policies/sac/reward_model/modeling_classifier.py
@@ -301,14 +301,14 @@ def get_optim_params(self):
"""Return optimizer parameters for the policy."""
return self.parameters()

def select_action(self, batch: dict[str, Tensor]) -> Tensor:
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""
This method is required by PreTrainedPolicy but not used for reward classifiers.
The reward classifier is not an actor and does not select actions.
"""
raise NotImplementedError("Reward classifiers do not select actions")

def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""
This method is required by PreTrainedPolicy but not used for reward classifiers.
The reward classifier is not an actor and does not produce action chunks.
12 changes: 12 additions & 0 deletions src/lerobot/policies/smolvla/configuration_smolvla.py
@@ -57,6 +57,9 @@ class SmolVLAConfig(PreTrainedConfig):
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False

# Whether to compile parts of the model using torch.compile. Improves inference speed but increases memory usage and startup time.
compile_model: bool = False

# Tokenizer
tokenizer_max_length: int = 48

@@ -101,6 +104,15 @@ class SmolVLAConfig(PreTrainedConfig):
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
max_period: float = 4.0

# Inference settings
inference_enable_rtc: bool = False # Whether to enable real-time action chunking (RTC): https://www.physicalintelligence.company/research/real_time_chunking
inference_rtc_d: int = (
-1
) # Inference delay (in action steps). If -1, it is set automatically based on roundtrip inference time.
inference_rtc_soft_mask_length: int = -1 # The length of the soft mask for RTC (in action steps). If -1, it is set automatically to chunk_size - d - t
inference_rtc_beta: float = 5.0 # RTC maximum guidance weight.
inference_rtc_debug: bool = False # Whether to enable debug mode for RTC. Will print debug information for RTC. RTC denoising will be slower.

def __post_init__(self):
super().__post_init__()

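Note (illustrative): the new `inference_rtc_*` options control how a freshly denoised chunk is reconciled with the remainder of the chunk currently being executed: the first `d` steps are already committed, the next `soft_mask_length` steps are softly constrained toward the old chunk, and the remaining steps are free, with `beta` capping the guidance strength. The sketch below is one plausible reading of that weighting (hard prefix, linear decay over the soft mask); the exact schedule used during denoising in this PR and in the RTC paper may differ:

```python
# Illustrative guidance-weight schedule; an assumption, not the PR's implementation.
import torch


def rtc_guidance_weights(chunk_size: int, d: int, soft_mask_length: int, beta: float) -> torch.Tensor:
    """Return a (chunk_size,) tensor of per-step guidance weights.

    - steps [0, d): already executing -> full weight beta (effectively frozen)
    - steps [d, d + soft_mask_length): soft mask, decaying linearly from beta toward 0
    - remaining steps: unconstrained -> weight 0
    """
    weights = torch.zeros(chunk_size)
    weights[:d] = beta
    s = min(soft_mask_length, max(0, chunk_size - d))
    if s > 0:
        decay = torch.linspace(1.0, 0.0, steps=s + 2)[1:-1]  # s values strictly in (0, 1)
        weights[d : d + s] = beta * decay
    return weights


# Values loosely inspired by the commit history (chunk_size 100, soft mask 75); d=10 is hypothetical.
w = rtc_guidance_weights(chunk_size=100, d=10, soft_mask_length=75, beta=5.0)
```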