[RLlib] PPO with LSTM, shared vf layers, and custom tokenizer: KeyError: 'infos' in SampleBatch._batch_slice #45666
Open
Description
What happened + What you expected to happen
Running PPO with an LSTM, shared value-function layers (vf_share_layers=True), and a custom tokenizer (set via a custom PPOCatalog) results in the following error:
File "ppo_lstm_encoder_sample.py", line 104, in run
result = algo.train()
^^^^^^^^^^^^
File "lib/python3.11/site-packages/ray/tune/trainable/trainable.py", line 331, in train
raise skipped from exception_cause(skipped)
File "lib/python3.11/site-packages/ray/tune/trainable/trainable.py", line 328, in train
result = self.step()
^^^^^^^^^^^
File "lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 873, in step
train_results, train_iter_ctx = self._run_one_training_iteration()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 3154, in _run_one_training_iteration
results = self.training_step()
^^^^^^^^^^^^^^^^^^^^
File "lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo.py", line 428, in training_step
return self._training_step_old_and_hybrid_api_stacks()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo.py", line 578, in _training_step_old_and_hybrid_api_stacks
train_results = self.learner_group.update_from_batch(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "lib/python3.11/site-packages/ray/rllib/core/learner/learner_group.py", line 267, in update_from_batch
return self._update(
^^^^^^^^^^^^^
File "lib/python3.11/site-packages/ray/rllib/core/learner/learner_group.py", line 373, in _update
_learner_update(
File "lib/python3.11/site-packages/ray/rllib/core/learner/learner_group.py", line 350, in _learner_update
return learner.update_from_batch(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "lib/python3.11/site-packages/ray/rllib/core/learner/learner.py", line 1114, in update_from_batch
return self._update_from_batch_or_episodes(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_learner.py", line 68, in _update_from_batch_or_episodes
return super()._update_from_batch_or_episodes(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "lib/python3.11/site-packages/ray/rllib/core/learner/learner.py", line 1367, in _update_from_batch_or_episodes
for tensor_minibatch in batch_iter(batch, minibatch_size, num_iters):
File "lib/python3.11/site-packages/ray/rllib/utils/minibatch_utils.py", line 145, in __iter__
minibatch[module_id] = concat_samples(samples_to_concat)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "lib/python3.11/site-packages/ray/rllib/policy/sample_batch.py", line 1638, in concat_samples
*[s[k] for s in concated_samples],
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "lib/python3.11/site-packages/ray/rllib/policy/sample_batch.py", line 1638, in <listcomp>
*[s[k] for s in concated_samples],
~^^^
File "lib/python3.11/site-packages/ray/rllib/policy/sample_batch.py", line 951, in __getitem__
value = dict.__getitem__(self, key)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyError: 'infos'
Seems like a regression from #44082.
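For reference, here is a minimal sketch (separate from the full reproduction script below) of the failure mode the traceback suggests: concat_samples() indexes every batch with the keys it encounters, so when one of the slices produced by SampleBatch._batch_slice has dropped the "infos" column while another slice still carries it, the lookup raises KeyError: 'infos'. The batch contents here are made up purely for illustration.

import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch, concat_samples

# One slice still carries "infos", the other does not (illustrative data only).
with_infos = SampleBatch({"obs": np.zeros((2, 4)), "infos": [{}, {}]})
without_infos = SampleBatch({"obs": np.ones((2, 4))})

try:
    concat_samples([with_infos, without_infos])
except KeyError as e:
    print(f"Raised KeyError: {e}")  # -> 'infos'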
Versions / Dependencies
Ray 2.23.0
Python 3.11
Reproduction script
import ray
from typing import Dict
from dataclasses import dataclass
from gymnasium import Space
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.examples.envs.classes.repeat_after_me_env import RepeatAfterMeEnv
from ray.rllib.core.models.base import Encoder, ENCODER_OUT
from ray.rllib.core.models.configs import ModelConfig
from ray.rllib.core.models.torch.base import TorchModel
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.models.specs.specs_base import TensorSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.policy.sample_batch import SampleBatch
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env
torch, nn = try_import_torch()
class CustomTorchTokenizer(TorchModel, Encoder):
    def __init__(self, config) -> None:
        TorchModel.__init__(self, config)
        Encoder.__init__(self, config)
        self.net = nn.Sequential(
            nn.Linear(config.input_dims[0], config.output_dims[0]),
        )

    # Since we use this model as a tokenizer, we need to define its output
    # dimensions so that we know the input dim for the recurrent cells that follow.
    def get_output_specs(self):
        # In this example, the output dim will be 64, but we still fetch it from
        # config so that this code is more reusable.
        output_dim = self.config.output_dims[0]
        return SpecDict(
            {ENCODER_OUT: TensorSpec("b, d", d=output_dim, framework="torch")}
        )

    def _forward(self, inputs: dict, **kwargs):
        return {ENCODER_OUT: self.net(inputs[SampleBatch.OBS])}

@dataclass
class CustomTokenizerConfig(ModelConfig):
    output_dims: tuple = None

    def build(self, framework):
        if framework == "torch":
            return CustomTorchTokenizer(self)
        else:
            raise ValueError(f"Unsupported framework `{framework}`!")


class CustomPPOCatalog(PPOCatalog):
    @classmethod
    def get_tokenizer_config(
        cls,
        observation_space: Space,
        model_config_dict: Dict,
        view_requirements: Dict[str, ViewRequirement] | None = None,
    ) -> ModelConfig:
        return CustomTokenizerConfig(
            input_dims=observation_space.shape,
            output_dims=(64,),
        )

def run():
    ray.init()
    register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))
    model_config_dict = {
        "use_lstm": True,
        "lstm_cell_size": 32,
        "vf_share_layers": True,
    }
    rlm_spec = SingleAgentRLModuleSpec(
        catalog_class=CustomPPOCatalog,
        model_config_dict=model_config_dict,
    )
    algo = (
        PPOConfig()
        .experimental(_disable_preprocessor_api=True)
        .api_stack(enable_rl_module_and_learner=True)
        .resources(num_gpus=0)
        .env_runners(num_env_runners=0)
        .environment(
            env="RepeatAfterMeEnv",
            env_config={"continuous": True, "repeat_delay": 2},
        )
        .training(
            mini_batch_size_per_learner=8,
            train_batch_size=8,
        )
        .rl_module(rl_module_spec=rlm_spec)
        .build()
    )
    for i in range(100):
        result = algo.train()
        # print(pretty_print(result))
        print(
            f"Iteration {i+1:03d}: {result['env_runners']['episode_reward_mean']:.2f}"
        )


if __name__ == "__main__":
    run()
Issue Severity
None