
[RLlib] PPO with LSTM, shared vf layers, and custom tokenizer: KeyError: 'infos' in SampleBatch._batch_slice #45666

@jfuechsl

Description

What happened + What you expected to happen

Running PPO with an LSTM, shared value-function layers, and a custom tokenizer (set via a custom catalog) fails during training with the following error:

  File "ppo_lstm_encoder_sample.py", line 104, in run
    result = algo.train()
             ^^^^^^^^^^^^
  File "lib/python3.11/site-packages/ray/tune/trainable/trainable.py", line 331, in train
    raise skipped from exception_cause(skipped)
  File "lib/python3.11/site-packages/ray/tune/trainable/trainable.py", line 328, in train
    result = self.step()
             ^^^^^^^^^^^
  File "lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 873, in step
    train_results, train_iter_ctx = self._run_one_training_iteration()
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 3154, in _run_one_training_iteration
    results = self.training_step()
              ^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo.py", line 428, in training_step
    return self._training_step_old_and_hybrid_api_stacks()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo.py", line 578, in _training_step_old_and_hybrid_api_stacks
    train_results = self.learner_group.update_from_batch(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/ray/rllib/core/learner/learner_group.py", line 267, in update_from_batch
    return self._update(
           ^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/ray/rllib/core/learner/learner_group.py", line 373, in _update
    _learner_update(
  File "lib/python3.11/site-packages/ray/rllib/core/learner/learner_group.py", line 350, in _learner_update
    return learner.update_from_batch(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/ray/rllib/core/learner/learner.py", line 1114, in update_from_batch
    return self._update_from_batch_or_episodes(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_learner.py", line 68, in _update_from_batch_or_episodes
    return super()._update_from_batch_or_episodes(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/ray/rllib/core/learner/learner.py", line 1367, in _update_from_batch_or_episodes
    for tensor_minibatch in batch_iter(batch, minibatch_size, num_iters):
  File "lib/python3.11/site-packages/ray/rllib/utils/minibatch_utils.py", line 145, in __iter__
    minibatch[module_id] = concat_samples(samples_to_concat)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/ray/rllib/policy/sample_batch.py", line 1638, in concat_samples
    *[s[k] for s in concated_samples],
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/ray/rllib/policy/sample_batch.py", line 1638, in <listcomp>
    *[s[k] for s in concated_samples],
      ~^^^
  File "lib/python3.11/site-packages/ray/rllib/policy/sample_batch.py", line 951, in __getitem__
    value = dict.__getitem__(self, key)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyError: 'infos'

Seems like a regression from #44082.
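The traceback suggests that one of the per-module minibatch slices handed to concat_samples no longer carries the infos column while another slice still does, so indexing every slice with a shared key set fails. Below is a minimal sketch of that symptom with hypothetical batches (assuming concat_samples indexes each slice with the keys of the first one); it is meant to illustrate the KeyError, not the exact internal state of minibatch_utils.py:

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch, concat_samples

# Hypothetical slices: one still carries "infos", the other does not.
slice_with_infos = SampleBatch(
    {"obs": np.zeros((4, 2), dtype=np.float32), "infos": np.array([{}] * 4)}
)
slice_without_infos = SampleBatch({"obs": np.ones((4, 2), dtype=np.float32)})

# Mirrors the concat_samples(samples_to_concat) call from the traceback;
# raises KeyError: 'infos' because the second slice lacks that column.
concat_samples([slice_with_infos, slice_without_infos])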

Versions / Dependencies

Ray 2.23.0
Python 3.11

Reproduction script

import ray
from typing import Dict
from dataclasses import dataclass
from gymnasium import Space

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.examples.envs.classes.repeat_after_me_env import RepeatAfterMeEnv
from ray.rllib.core.models.base import Encoder, ENCODER_OUT
from ray.rllib.core.models.configs import ModelConfig
from ray.rllib.core.models.torch.base import TorchModel
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.models.specs.specs_base import TensorSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.policy.sample_batch import SampleBatch
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env

torch, nn = try_import_torch()


class CustomTorchTokenizer(TorchModel, Encoder):
    def __init__(self, config) -> None:
        TorchModel.__init__(self, config)
        Encoder.__init__(self, config)
        self.net = nn.Sequential(
            nn.Linear(config.input_dims[0], config.output_dims[0]),
        )

    # Since we use this model as a tokenizer, we need to define its output
    # dimensions so that we know the input dim for the recurrent cells that follow.
    def get_output_specs(self):
        # In this example, the output dim will be 64, but we still fetch it from
        # config so that this code is more reusable.
        output_dim = self.config.output_dims[0]
        return SpecDict(
            {ENCODER_OUT: TensorSpec("b, d", d=output_dim, framework="torch")}
        )

    def _forward(self, inputs: dict, **kwargs):
        return {ENCODER_OUT: self.net(inputs[SampleBatch.OBS])}


@dataclass
class CustomTokenizerConfig(ModelConfig):
    output_dims: tuple = None

    def build(self, framework):
        if framework == "torch":
            return CustomTorchTokenizer(self)
        else:
            raise ValueError(f"Unsupported framework `{framework}`!")


class CustomPPOCatalog(PPOCatalog):
    @classmethod
    def get_tokenizer_config(
        cls,
        observation_space: Space,
        model_config_dict: Dict,
        view_requirements: Dict[str, ViewRequirement] | None = None,
    ) -> ModelConfig:
        return CustomTokenizerConfig(
            input_dims=observation_space.shape,
            output_dims=(64,),
        )


def run():
    ray.init()
    register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))

    model_config_dict = {
        "use_lstm": True,
        "lstm_cell_size": 32,
        "vf_share_layers": True,
    }
    rlm_spec = SingleAgentRLModuleSpec(
        catalog_class=CustomPPOCatalog,
        model_config_dict=model_config_dict,
    )

    algo = (
        PPOConfig()
        .experimental(_disable_preprocessor_api=True)
        .api_stack(enable_rl_module_and_learner=True)
        .resources(num_gpus=0)
        .env_runners(num_env_runners=0)
        .environment(
            env="RepeatAfterMeEnv",
            env_config={"continuous": True, "repeat_delay": 2},
        )
        .training(
            mini_batch_size_per_learner=8,
            train_batch_size=8,
        )
        .rl_module(rl_module_spec=rlm_spec)
        .build()
    )

    for i in range(100):
        result = algo.train()
        # print(pretty_print(result))
        print(
            f"Iteration {i+1:03d}: {result['env_runners']['episode_reward_mean']:.2f}"
        )

if __name__ == "__main__":
    run()

Issue Severity

None

Labels

P2 (Important issue, but not time-critical), bug (Something that is supposed to be working, but isn't), rllib (RLlib related issues), rllib-models (An issue related to RLlib (default or custom) Models), rllib-newstack
