Skip to content

Commit

Permalink
Quickly validate envs when using RLlib (#128)
Browse files Browse the repository at this point in the history
* Quickly validate envs when using RLlib

* Fix quick validation when using telemetry

* Add validate method to environment classes
  • Loading branch information
jaredvann authored Dec 12, 2023
1 parent fe9bb80 commit 1122754
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 33 deletions.
9 changes: 9 additions & 0 deletions phantom/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,15 @@ def is_truncated(self) -> bool:

return is_at_max_step or len(self._truncations) == len(self.strategic_agents)

def validate(self) -> None:
"""
Validate the environment by executing a number of steps that sufficiently covers
the features of the environment.
"""
obs, _ = self.reset()
actions = {aid: self.agents[aid].action_space.sample() for aid in obs}
self.step(actions)

def _handle_acting_agents(
self, agent_ids: Sequence[AgentID], actions: Mapping[AgentID, Any]
) -> None:
Expand Down
14 changes: 14 additions & 0 deletions phantom/fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,17 @@ def step(self, actions: Mapping[AgentID, Any]) -> PhantomEnv.Step:
rewards = {aid: self._rewards[aid] for aid in observations}

return self.Step(observations, rewards, terminations, truncations, infos)

def validate(self) -> None:
"""
Validate the environment by executing a number of steps that sufficiently covers
the features of the environment.
"""
obs, _ = self.reset()

for _ in range(self.num_steps):
actions = {aid: self.agents[aid].action_space.sample() for aid in obs}
obs, _, done, _, _ = self.step(actions)

if done["__all__"]:
break
14 changes: 14 additions & 0 deletions phantom/stackelberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,17 @@ def step(self, actions: Mapping[AgentID, Any]) -> PhantomEnv.Step:
}

return self.Step(observations, rewards, terminations, truncations, infos)

def validate(self) -> None:
"""
Validate the environment by executing a number of steps that sufficiently covers
the features of the environment.
"""
obs, _ = self.reset()

for _ in range(2):
actions = {aid: self.agents[aid].action_space.sample() for aid in obs}
obs, _, done, _, _ = self.step(actions)

if done["__all__"]:
break
34 changes: 29 additions & 5 deletions phantom/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def __init__(self) -> None:

self._current_episode = None

self._paused: bool = False

def configure_print_logging(
self,
enable: Union[bool, None] = None,
Expand Down Expand Up @@ -337,9 +339,13 @@ def log_msg_recv(self, message: Message) -> None:

def log_metrics(self, env: "PhantomEnv") -> None:
if self._current_episode is not None:
self._current_episode["steps"][-1]["metrics"] = {
name: metric.extract(env) for name, metric in self._file_metrics.items()
}
if self._file_metrics is not None:
self._current_episode["steps"][-1]["metrics"] = {
name: metric.extract(env)
for name, metric in self._file_metrics.items()
}
else:
self._current_episode["steps"][-1]["metrics"] = {}

if self._enable_print and self._print_metrics is not None:
print(_t(1) + colored("METRICS:", color="cyan"))
Expand All @@ -356,6 +362,16 @@ def log_episode_done(self) -> None:
if self._enable_print:
print(_t(1) + colored("EPISODE DONE", color="green", attrs=["bold"]))

def pause(self) -> "TelemetryLogger.PauseContextManager":
class PauseContextManager:
def __enter__(self2):
self._paused = True

def __exit__(self2, exc_type, exc_val, exc_tb):
self._paused = False

return PauseContextManager()

def _print_msg(self, message: Message, string: str) -> None:
if self._should_print_msg(message):
route_str = f"{message.sender_id: >10} --> {message.receiver_id: <10}"
Expand All @@ -369,7 +385,11 @@ def _print_msg(self, message: Message, string: str) -> None:
)

def _write_episode_to_file(self) -> None:
if self._output_file is not None and self._current_episode is not None:
if (
not self._paused
and self._output_file is not None
and self._current_episode is not None
):
json.dump(
self._current_episode,
self._output_file,
Expand All @@ -382,7 +402,8 @@ def _write_episode_to_file(self) -> None:

def _should_print_msg(self, message: Message) -> bool:
return (
self._enable_print
not self._paused
and self._enable_print
and self._print_messages
and (
isinstance(self._print_messages, bool)
Expand Down Expand Up @@ -417,6 +438,9 @@ class NumpyArrayEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, np.ndarray):
return o.tolist()
if isinstance(o, np.number):
return o.item()

return json.JSONEncoder.default(self, o)


Expand Down
17 changes: 6 additions & 11 deletions phantom/utils/rllib/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,7 @@
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import (
Any,
Dict,
Generator,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
)
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Type, Union

import cloudpickle
import ray
Expand All @@ -28,6 +18,7 @@
from ...metrics import Metric, logging_helper
from ...policy import Policy
from ...types import AgentID
from ... import telemetry
from ..rollout import Rollout, Step
from .. import (
collect_instances_of_type_with_paths,
Expand Down Expand Up @@ -187,6 +178,10 @@ def rollout(
"Cannot use non-determinisic FSM when policy_inference_batch_size > 1"
)

with telemetry.logger.pause():
env.validate()
env.reset()

num_workers_ = (os.cpu_count() - 1) if num_workers is None else num_workers

print(
Expand Down
22 changes: 9 additions & 13 deletions phantom/utils/rllib/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,7 @@
from datetime import datetime
from inspect import isclass
from pathlib import Path
from typing import (
Any,
Dict,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
)
from typing import Any, Dict, List, Mapping, Optional, Tuple, Type, Union

import cloudpickle
import gymnasium as gym
Expand All @@ -31,6 +22,7 @@
from ...metrics import Metric, logging_helper
from ...policy import Policy
from ...types import AgentID
from ... import telemetry
from .. import check_env_config, rich_progress, show_pythonhashseed_warning
from .wrapper import RLlibEnvWrapper

Expand Down Expand Up @@ -133,10 +125,12 @@ class will use the same fixed/learnt policy.

check_env_config(env_config)

ray.init(ignore_reinit_error=True, **(ray_config or {}))

env = env_class(**env_config)
env.reset()
with telemetry.logger.pause():
env.validate()
env.reset()

ray.init(ignore_reinit_error=True, **(ray_config or {}))

policy_specs: Dict[str, rllib.policy.policy.PolicySpec] = {}
policy_mapping: Dict[AgentID, str] = {}
Expand Down Expand Up @@ -307,6 +301,8 @@ def __call__(self) -> "RLlibMetricLogger":


def make_rllib_wrapped_policy_class(policy_class: Type[Policy]) -> Type[rllib.Policy]:
"""Internal function"""

class RLlibPolicyWrapper(rllib.Policy):
# NOTE:
# If the action space is larger than -1.0 < x < 1.0, RLlib will attempt to
Expand Down
11 changes: 7 additions & 4 deletions scripts/view_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,14 @@ def load_data(file: str):
if infos == 0:
st.text("None")

if show_metrics and len(step["metrics"]) > 0:
if show_metrics:
st.subheader("Metrics:")
df = pd.DataFrame(step["metrics"].items())
df.columns = ["Metric", "Value"]
st.table(df)
if len(step["metrics"]) > 0:
df = pd.DataFrame(step["metrics"].items())
df.columns = ["Metric", "Value"]
st.table(df)
else:
st.text("None")

if show_fsm and "fsm_current_stage" in step and "fsm_next_stage" in step:
st.subheader("FSM Transition:")
Expand Down

0 comments on commit 1122754

Please sign in to comment.