Skip to content

Commit

Permalink
[RLlib] Replace all Mapping typehints with Dict. (ray-project#46474)
Browse files (browse the repository at this point in the history)
  • Loading branch information
sven1977 authored Jul 8, 2024
1 parent 9ebc5d9 commit 8d02655
Show file tree
Hide file tree
Showing 18 changed files with 43 additions and 48 deletions.
4 changes: 2 additions & 2 deletions rllib/algorithms/sac/torch/sac_torch_learner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Mapping
from typing import Dict

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.dqn.torch.dqn_rainbow_torch_learner import (
Expand Down Expand Up @@ -108,7 +108,7 @@ def compute_loss_for_module(
module_id: ModuleID,
config: SACConfig,
batch: NestedDict,
fwd_out: Mapping[str, TensorType]
fwd_out: Dict[str, TensorType]
) -> TensorType:
# Only for debugging.
deterministic = config._deterministic_loss
Expand Down
4 changes: 2 additions & 2 deletions rllib/core/models/specs/checker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import logging
from collections import abc
from typing import Union, Mapping, Any, Callable
from typing import Any, Callable, Dict, Union

from ray.rllib.core.models.specs.specs_base import Spec, TypeSpec
from ray.rllib.core.models.specs.specs_dict import SpecDict
Expand Down Expand Up @@ -127,7 +127,7 @@ def _validate(
*,
cls_instance: object,
method: Callable,
data: Mapping[str, Any],
data: Dict[str, Any],
spec: Spec,
filter: bool = False,
tag: str = "input",
Expand Down
4 changes: 2 additions & 2 deletions rllib/core/models/specs/specs_dict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Mapping, Any
from typing import Any, Dict, Union

from ray.rllib.utils.annotations import ExperimentalAPI, override
from ray.rllib.utils.nested_dict import NestedDict
Expand All @@ -18,7 +18,7 @@
"{} has type {} (expected type {})."
)

DATA_TYPE = Union[NestedDict[Any], Mapping[str, Any]]
DATA_TYPE = Union[NestedDict[Any], Dict[str, Any]]

IS_NOT_PROPERTY = "Spec {} must be a property of the class {}."

Expand Down
3 changes: 1 addition & 2 deletions rllib/core/rl_module/tf/tests/test_tf_rl_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import tempfile
import unittest
from typing import Mapping

import gymnasium as gym
import tensorflow as tf
Expand Down Expand Up @@ -53,7 +52,7 @@ def test_forward_train(self):
)
loss = -tf.math.reduce_mean(action_dist.logp(actions))

self.assertIsInstance(output, Mapping)
self.assertIsInstance(output, dict)

grads = tape.gradient(loss, module.trainable_variables)

Expand Down
3 changes: 1 addition & 2 deletions rllib/core/rl_module/torch/tests/test_torch_rl_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import tempfile
import unittest
from typing import Mapping
import gc

import gymnasium as gym
Expand Down Expand Up @@ -48,7 +47,7 @@ def test_forward_train(self):
)
output = module.forward_train({"obs": obs})

self.assertIsInstance(output, Mapping)
self.assertIsInstance(output, dict)
self.assertIn(Columns.ACTION_DIST_INPUTS, output)

action_dist_inputs = output[Columns.ACTION_DIST_INPUTS]
Expand Down
4 changes: 2 additions & 2 deletions rllib/core/testing/tf/bc_learner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import tensorflow as tf
from typing import Mapping, TYPE_CHECKING
from typing import Dict, TYPE_CHECKING

from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.tf.tf_learner import TfLearner
Expand All @@ -18,7 +18,7 @@ def compute_loss_for_module(
module_id: ModuleID,
config: "AlgorithmConfig",
batch: NestedDict,
fwd_out: Mapping[str, TensorType],
fwd_out: Dict[str, TensorType],
) -> TensorType:
BaseTestingLearner.compute_loss_for_module(
self,
Expand Down
14 changes: 7 additions & 7 deletions rllib/core/testing/tf/bc_module.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import tensorflow as tf
from typing import Any, Mapping
from typing import Any, Dict

from ray.rllib.core.columns import Columns
from ray.rllib.core.models.specs.typing import SpecType
Expand Down Expand Up @@ -54,30 +54,30 @@ def output_specs_inference(self) -> SpecType:
def output_specs_train(self) -> SpecType:
return [Columns.ACTION_DIST_INPUTS]

def _forward_shared(self, batch: NestedDict) -> Mapping[str, Any]:
def _forward_shared(self, batch: NestedDict) -> Dict[str, Any]:
# We can use a shared forward method because BC does not need to distinguish
# between train, inference, and exploration.
action_logits = self.policy(batch["obs"])
return {Columns.ACTION_DIST_INPUTS: action_logits}

@override(RLModule)
def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]:
def _forward_inference(self, batch: NestedDict) -> Dict[str, Any]:
return self._forward_shared(batch)

@override(RLModule)
def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]:
def _forward_exploration(self, batch: NestedDict) -> Dict[str, Any]:
return self._forward_shared(batch)

@override(RLModule)
def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]:
def _forward_train(self, batch: NestedDict) -> Dict[str, Any]:
return self._forward_shared(batch)

@override(RLModule)
def get_state(self, inference_only: bool = False) -> Mapping[str, Any]:
def get_state(self, inference_only: bool = False) -> Dict[str, Any]:
return {"policy": self.policy.get_weights()}

@override(RLModule)
def set_state(self, state: Mapping[str, Any]) -> None:
def set_state(self, state: Dict[str, Any]) -> None:
self.policy.set_weights(state["policy"])


Expand Down
4 changes: 2 additions & 2 deletions rllib/core/testing/torch/bc_learner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from typing import Mapping, TYPE_CHECKING
from typing import Dict, TYPE_CHECKING

from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.torch.torch_learner import TorchLearner
Expand All @@ -18,7 +18,7 @@ def compute_loss_for_module(
module_id: ModuleID,
config: "AlgorithmConfig",
batch: NestedDict,
fwd_out: Mapping[str, TensorType],
fwd_out: Dict[str, TensorType],
) -> TensorType:
BaseTestingLearner.compute_loss_for_module(
self,
Expand Down
8 changes: 4 additions & 4 deletions rllib/core/testing/torch/bc_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Mapping
from typing import Any, Dict

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleConfig
Expand Down Expand Up @@ -55,17 +55,17 @@ def output_specs_train(self) -> SpecType:
return [Columns.ACTION_DIST_INPUTS]

@override(RLModule)
def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]:
def _forward_inference(self, batch: NestedDict) -> Dict[str, Any]:
with torch.no_grad():
return self._forward_train(batch)

@override(RLModule)
def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]:
def _forward_exploration(self, batch: NestedDict) -> Dict[str, Any]:
with torch.no_grad():
return self._forward_train(batch)

@override(RLModule)
def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]:
def _forward_train(self, batch: NestedDict) -> Dict[str, Any]:
action_logits = self.policy(batch["obs"])
return {Columns.ACTION_DIST_INPUTS: action_logits}

Expand Down
4 changes: 2 additions & 2 deletions rllib/examples/learners/train_w_bc_finetune_w_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
This example shows how to pretrain an RLModule using behavioral cloning from offline
data and, thereafter, continue training it online with PPO (fine-tuning).
"""
from typing import Dict

import gymnasium as gym
import shutil
import tempfile
import torch
from typing import Mapping

import ray
from ray import tune
Expand Down Expand Up @@ -49,7 +49,7 @@ def __init__(
self.distribution_cls = distribution_cls

def forward(
self, batch: Mapping[str, torch.Tensor]
self, batch: Dict[str, torch.Tensor]
) -> torch.distributions.Distribution:
"""Return an action distribution output by the policy network.
Expand Down
4 changes: 2 additions & 2 deletions rllib/examples/rl_modules/classes/random_rlm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pathlib
from typing import Mapping, Any
from typing import Any, Dict

import gymnasium as gym
import numpy as np
Expand Down Expand Up @@ -59,7 +59,7 @@ def from_model_config(
observation_space: gym.Space,
action_space: gym.Space,
*,
model_config_dict: Mapping[str, Any],
model_config_dict: Dict[str, Any],
) -> "RLModule":
return cls(action_space)

Expand Down
4 changes: 2 additions & 2 deletions rllib/models/tf/tf_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import gymnasium as gym
import tree
import numpy as np
from typing import Optional, List, Mapping, Iterable, Dict
from typing import Dict, Iterable, List, Optional
import abc


Expand Down Expand Up @@ -499,7 +499,7 @@ def required_input_dim(space: gym.Space, input_lens: List[int], **kwargs) -> int
def from_logits(
cls,
logits: tf.Tensor,
child_distribution_cls_struct: Union[Mapping, Iterable],
child_distribution_cls_struct: Union[Dict, Iterable],
input_lens: Union[Dict, List[int]],
space: gym.Space,
**kwargs,
Expand Down
4 changes: 2 additions & 2 deletions rllib/models/torch/torch_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""
import gymnasium as gym
import numpy as np
from typing import Optional, List, Mapping, Iterable, Dict
from typing import Dict, Iterable, List, Optional
import tree
import abc

Expand Down Expand Up @@ -613,7 +613,7 @@ def required_input_dim(
def from_logits(
cls,
logits: torch.Tensor,
child_distribution_cls_struct: Union[Mapping, Iterable],
child_distribution_cls_struct: Union[Dict, Iterable],
input_lens: Union[Dict, List[int]],
space: gym.Space,
**kwargs,
Expand Down
5 changes: 2 additions & 3 deletions rllib/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
Collection,
Dict,
List,
Mapping,
Optional,
Tuple,
Type,
Expand Down Expand Up @@ -1811,7 +1810,7 @@ def __repr__(self):

@OldAPIStack
def get_gym_space_from_struct_of_tensors(
value: Union[Mapping, Tuple, List, TensorType],
value: Union[Dict, Tuple, List, TensorType],
batched_input=True,
) -> gym.Space:
start_idx = 1 if batched_input else 0
Expand All @@ -1827,7 +1826,7 @@ def get_gym_space_from_struct_of_tensors(

@OldAPIStack
def get_gym_space_from_struct_of_spaces(value: Union[Dict, Tuple]) -> gym.spaces.Dict:
if isinstance(value, Mapping):
if isinstance(value, dict):
return gym.spaces.Dict(
{k: get_gym_space_from_struct_of_spaces(v) for k, v in value.items()}
)
Expand Down
7 changes: 3 additions & 4 deletions rllib/policy/sample_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,8 +1315,8 @@ class MultiAgentBatch:
"""A batch of experiences from multiple agents in the environment.
Attributes:
policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
ids to SampleBatches of experiences.
policy_batches (Dict[PolicyID, SampleBatch]): Dict mapping policy IDs to
SampleBatches of experiences.
count: The number of env steps in this batch.
"""

Expand All @@ -1325,8 +1325,7 @@ def __init__(self, policy_batches: Dict[PolicyID, SampleBatch], env_steps: int):
"""Initialize a MultiAgentBatch instance.
Args:
policy_batches: Mapping from policy
ids to SampleBatches of experiences.
policy_batches: Dict mapping policy IDs to SampleBatches of experiences.
env_steps: The number of environment steps in the environment
this batch contains. This will be less than the number of
transitions this batch contains across all policies in total.
Expand Down
8 changes: 4 additions & 4 deletions rllib/utils/actor_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import sys
import time
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import ray
from ray.actor import ActorHandle
Expand Down Expand Up @@ -261,14 +261,14 @@ def __init__(
self._next_id = init_id

# Actors are stored in a map and indexed by a unique (int) ID.
self._actors: Mapping[int, ActorHandle] = {}
self._remote_actor_states: Mapping[int, self._ActorState] = {}
self._actors: Dict[int, ActorHandle] = {}
self._remote_actor_states: Dict[int, self._ActorState] = {}
self._restored_actors = set()
self.add_actors(actors or [])

# Maps outstanding async requests to the IDs of the actor IDs that
# are executing them.
self._in_flight_req_to_actor_id: Mapping[ray.ObjectRef, int] = {}
self._in_flight_req_to_actor_id: Dict[ray.ObjectRef, int] = {}

self._max_remote_requests_in_flight_per_actor = (
max_remote_requests_in_flight_per_actor
Expand Down
4 changes: 2 additions & 2 deletions rllib/utils/debug/summary.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import pprint
from typing import Any, Mapping
from typing import Any

from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import DeveloperAPI
Expand All @@ -26,7 +26,7 @@ def summarize(obj: Any) -> Any:


def _summarize(obj):
if isinstance(obj, Mapping):
if isinstance(obj, dict):
return {k: _summarize(v) for k, v in obj.items()}
elif hasattr(obj, "_asdict"):
return {
Expand Down
3 changes: 1 addition & 2 deletions rllib/utils/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
Callable,
Dict,
List,
Mapping,
Optional,
Tuple,
Type,
Expand Down Expand Up @@ -183,7 +182,7 @@ def local_policy_inference(
reward: Optional[float] = None,
terminated: Optional[bool] = None,
truncated: Optional[bool] = None,
info: Optional[Mapping] = None,
info: Optional[Dict] = None,
explore: bool = None,
timestep: Optional[int] = None,
) -> TensorStructType:
Expand Down

0 comments on commit 8d02655

Please sign in to comment.