
Commit 17070e4

Add config for RTC

1 parent 36f3ee6

16 files changed: +1140 -242 lines

src/lerobot/configs/types.py

Lines changed: 7 additions & 0 deletions

@@ -40,3 +40,10 @@ def __getitem__(self, key: Any) -> Any: ...
 class PolicyFeature:
     type: FeatureType
     shape: tuple
+
+
+class RTCAttentionSchedule(str, Enum):
+    ZEROS = "ZEROS"
+    ONES = "ONES"
+    LINEAR = "LINEAR"
+    EXP = "EXP"
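
The diff adds the enum but not the code that consumes it. As a rough illustration only, here is a hypothetical helper showing one plausible mapping from each schedule name to per-step prefix-attention weights over an overlap of length n; prefix_attention_weights and its semantics are invented for this sketch, not part of the commit:

# Hypothetical sketch -- not from this commit. Assumed semantics: weight 1.0
# means "freeze this overlapping step to the previous chunk", 0.0 means
# "regenerate it freely".
import numpy as np

from lerobot.configs.types import RTCAttentionSchedule


def prefix_attention_weights(schedule: RTCAttentionSchedule, n: int) -> np.ndarray:
    steps = np.arange(n)
    if schedule == RTCAttentionSchedule.ZEROS:
        return np.zeros(n)      # ignore the previous chunk entirely
    if schedule == RTCAttentionSchedule.ONES:
        return np.ones(n)       # keep every overlapping step
    if schedule == RTCAttentionSchedule.LINEAR:
        return 1.0 - steps / n  # linear decay from 1 toward 0
    if schedule == RTCAttentionSchedule.EXP:
        return np.exp(-steps)   # exponential decay
    raise ValueError(f"unknown schedule: {schedule}")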

src/lerobot/policies/act/modeling_act.py

Lines changed: 2 additions & 2 deletions

@@ -108,7 +108,7 @@ def reset(self):
         self._action_queue = deque([], maxlen=self.config.n_action_steps)

     @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
         """Select a single action given environment observations.

         This method wraps `select_actions` in order to return one action at a time for execution in the
@@ -133,7 +133,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         return self._action_queue.popleft()

     @torch.no_grad()
-    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+    def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         self.eval()
src/lerobot/policies/diffusion/modeling_diffusion.py

Lines changed: 2 additions & 2 deletions

@@ -100,7 +100,7 @@ def reset(self):
         self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)

     @torch.no_grad()
-    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+    def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         # stack n latest observations from the queue
         batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
@@ -112,7 +112,7 @@ def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         return actions

     @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
         """Select a single action given environment observations.

         This method handles caching a history of observations and an action trajectory generated by the
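
Aside: the `torch.stack` context line in this hunk is where the diffusion policy assembles its observation history. A minimal standalone sketch of the shape it produces (toy dimensions and a single key, chosen for illustration):

# Each per-key deque holds the n_obs_steps most recent observations; stacking
# along dim=1 yields (batch, n_obs_steps, ...) tensors for the model.
from collections import deque

import torch

n_obs_steps = 2
queues = {"observation.state": deque(maxlen=n_obs_steps)}
for _ in range(n_obs_steps):
    queues["observation.state"].append(torch.zeros(1, 7))  # (batch=1, state_dim=7)

stacked = {k: torch.stack(list(q), dim=1) for k, q in queues.items()}
print(stacked["observation.state"].shape)  # torch.Size([1, 2, 7])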

src/lerobot/policies/pi0/modeling_pi0.py

Lines changed: 2 additions & 2 deletions

@@ -360,12 +360,12 @@ def from_pretrained(cls, *args, **kwargs):
         return super().from_pretrained(*args, **kwargs)

     @torch.no_grad()
-    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+    def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         raise NotImplementedError("Currently not implemented for PI0")

     @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
         """Select a single action given environment observations.

         This method wraps `select_actions` in order to return one action at a time for execution in the

src/lerobot/policies/pi0fast/modeling_pi0fast.py

Lines changed: 2 additions & 2 deletions

@@ -204,12 +204,12 @@ def _pi_aloha_encode_actions_inv(self, actions):
         return actions

     @torch.no_grad()
-    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+    def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         raise NotImplementedError("Currently not implemented for PI0FAST")

     @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
         """Select a single action given environment observations.

         This method wraps `select_actions` in order to return one action at a time for execution in the

src/lerobot/policies/pretrained.py

Lines changed: 2 additions & 2 deletions

@@ -181,7 +181,7 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
         raise NotImplementedError

     @abc.abstractmethod
-    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+    def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
         """Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode.

         Child classes using action chunking should use this method within `select_action` to form the action chunk
@@ -190,7 +190,7 @@ def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
         raise NotImplementedError

     @abc.abstractmethod
-    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
         """Return one action to run in the environment (potentially in batch mode).

         When the model uses a history of observations, or outputs a sequence of actions, this method deals
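
Widening these abstract signatures is what makes the per-policy changes above consistent with the base class. A hypothetical sketch of the call-site pattern this enables; `prev_chunk` and `inference_delay` are invented names for illustration, not arguments defined by this commit:

from torch import Tensor


def run_step(policy, batch: dict[str, Tensor], **rtc_kwargs) -> Tensor:
    # RTC-aware callers can pass extra context, e.g.
    # run_step(policy, batch, prev_chunk=chunk, inference_delay=2);
    # policies that do not use RTC accept and ignore the extra kwargs.
    return policy.select_action(batch, **rtc_kwargs)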

src/lerobot/policies/rtc/README.md

Lines changed: 0 additions & 9 deletions

@@ -15,15 +15,6 @@ RTC can be integrated with any policy that supports flow matching for chunking:
 - **SmolVLA**: Vision-language-action model with RTC support
 - **Pi0**: Action prediction model with adaptive chunking

-## Configuration
-
-RTC behavior is configured through the `AdaptiveInferenceConfig` class located in `src/lerobot/policies/rtc_config.py`. Key parameters include:
-
-- `soft_mask_length`: Number of actions to soft mask in overlap regions
-- `beta`: Maximum guidance weight for prefix attention
-- `prefix_attention_schedule`: Attention weight scheduling strategy
-- `guidance_scale`: Scale factor for guidance correction
-
 ## Original Implementation

 This implementation is based on Physical Intelligence's Kinetix RTC:
New file (path not shown in this excerpt)

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Real Time Chunking (RTC) and Bidirectional Decoding (BID) configuration classes.
+
+Based on:
+- Real Time Chunking: https://www.physicalintelligence.company/research/real_time_chunking
+"""
+
+from dataclasses import dataclass
+
+from lerobot.configs.types import RTCAttentionSchedule
+
+
+@dataclass
+class RTCConfig:
+    """Configuration for Real Time Chunking (RTC) inference.
+
+    RTC improves real-time inference by treating chunk generation as an inpainting problem,
+    strategically handling overlapping timesteps between action chunks using prefix attention.
+    """
+
+    # Core RTC settings
+    prefix_attention_schedule: RTCAttentionSchedule = RTCAttentionSchedule.LINEAR
+    max_guidance_weight: float = 5.0
+    execution_horizon: int = 10
+
+    def __post_init__(self):
+        """Validate RTC configuration parameters."""
+        if self.max_guidance_weight <= 0:
+            raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
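
A brief usage sketch of the new dataclass, assuming RTCConfig has been imported from the new module above (whose path is not shown here):

from lerobot.configs.types import RTCAttentionSchedule

cfg = RTCConfig(
    prefix_attention_schedule=RTCAttentionSchedule.EXP,
    max_guidance_weight=2.5,
    execution_horizon=8,
)
print(cfg.prefix_attention_schedule.value)  # "EXP"

# Validation in __post_init__ rejects non-positive guidance weights:
# RTCConfig(max_guidance_weight=0.0) raises ValueError.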
