Misc type system fixes and refactoring #130

Merged · 21 commits · Dec 12, 2023
Changes from 19 commits
11 changes: 4 additions & 7 deletions phantom/agents.py
@@ -1,3 +1,4 @@
import inspect
from abc import ABC
from collections import defaultdict
from itertools import chain
@@ -6,6 +7,7 @@
Callable,
DefaultDict,
Dict,
Generic,
List,
Optional,
Sequence,
@@ -74,7 +76,7 @@ def __init__(

for name in dir(self):
if name not in ["observation_space", "action_space"]:
attr = getattr(self, name)
attr = inspect.getattr_static(self, name)
if callable(attr) and hasattr(attr, "_message_type"):
self.__handlers[attr._message_type].append(getattr(self, name))
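
A note on the getattr change above: inspect.getattr_static looks up attributes without running the descriptor protocol, so scanning dir(self) for handlers no longer triggers property getters that may depend on state set later in __init__. A minimal, self-contained sketch of the difference (the Example class and its property are invented for illustration, not part of phantom):

import inspect


class Example:
    @property
    def expensive(self):
        # A plain getattr() during the handler scan would execute this getter.
        raise RuntimeError("getter should not run while scanning for handlers")

    def on_message(self, msg):
        return None

    # Stand-in for phantom's handler decorator, which tags methods with a message type.
    on_message._message_type = str


obj = Example()

for name in dir(obj):
    attr = inspect.getattr_static(obj, name)  # raw attribute lookup, no getter call
    if callable(attr) and hasattr(attr, "_message_type"):
        print(name, "handles", attr._message_type)  # -> on_message handles <class 'str'>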

@@ -178,7 +180,7 @@ def __repr__(self) -> str:
return f"[{self.__class__.__name__} {self.id}]"


class StrategicAgent(Agent):
class StrategicAgent(Agent, Generic[Action, Observation]):
"""
Representation of a behavioural agent in the network.

@@ -215,13 +217,8 @@ def __init__(

if action_decoder is not None:
self.action_space = action_decoder.action_space
elif "action_space" not in dir(self):
self.action_space = None

if observation_encoder is not None:
self.observation_space = observation_encoder.observation_space
elif "observation_space" not in dir(self):
self.observation_space = None

def encode_observation(self, ctx: Context) -> Observation:
"""
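On the Generic[Action, Observation] change: parameterising StrategicAgent lets a subclass declare its concrete action and observation types for static type checkers. A hedged sketch of the idea using simplified stand-in classes (TraderAgent and the array types are invented here, not taken from the PR):

from typing import Any, Generic, TypeVar

import numpy as np

Action = TypeVar("Action")
Observation = TypeVar("Observation")


class StrategicAgentSketch(Generic[Action, Observation]):
    """Simplified stand-in for phantom's StrategicAgent."""

    def encode_observation(self, ctx: Any) -> Observation:
        raise NotImplementedError


class TraderAgent(StrategicAgentSketch[np.ndarray, np.ndarray]):
    """A type checker now knows this agent observes and acts with numpy arrays."""

    def encode_observation(self, ctx: Any) -> np.ndarray:
        return np.zeros(3)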
81 changes: 31 additions & 50 deletions phantom/env.py
@@ -1,16 +1,6 @@
from typing import (
Any,
Dict,
List,
Mapping,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
)

import gymnasium as gym
from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence, Set, Tuple

from gymnasium.utils import seeding

from .agents import Agent, StrategicAgent
from .context import Context
@@ -22,10 +12,15 @@
from .views import AgentView, EnvView


class PhantomEnv(gym.Env):
class PhantomEnv:
"""
Base Phantom environment.

This class follows the gym/gymnasium reset/step paradigm. It does not inherit
from ``gym.Env`` as the API differs due to being a multi-agent environment. It
aligns more closely with the RLlib ``MultiAgentEnv`` class but does not inherit
from it so as not to be tied to RLlib.

Usage:
>>> env = PhantomEnv({ ... })
>>> env.reset()
@@ -80,40 +75,27 @@ def __init__(
# list contains only one reference to each sampler instance.
self._samplers: List[Sampler] = []

if env_supertype is not None:
if isinstance(env_supertype, dict):
env_supertype = self.Supertype(**env_supertype)
else:
assert isinstance(env_supertype, self.Supertype)

# The env will manage sampling the supertype values
env_supertype._managed = True
def build_supertype(supertype, entity) -> Supertype:
if isinstance(supertype, dict):
assert hasattr(entity, "Supertype")
supertype = entity.Supertype(**supertype)
elif hasattr(entity, "Supertype"):
assert isinstance(supertype, entity.Supertype)

# Extract samplers from env supertype dict
for value in env_supertype.__dict__.values():
for value in supertype.__dict__.values():
if isinstance(value, Sampler) and value not in self._samplers:
self._samplers.append(value)

self.env_supertype = env_supertype
supertype._managed = True
return supertype

if agent_supertypes is not None:
for agent_id, agent_supertype in agent_supertypes.items():
if isinstance(agent_supertype, dict):
agent_supertype = self.agents[agent_id].Supertype(**agent_supertype)
# TODO: fix, temporarily disabled as AgentClass.Supertype changed to __main__.Supertype
# else:
# assert isinstance(agent_supertype, self.agents[agent_id].Supertype)

# The env will manage sampling the supertype values
agent_supertype._managed = True

# Extract samplers from agent supertype dict
for value in agent_supertype.__dict__.values():
if isinstance(value, Sampler) and value not in self._samplers:
self._samplers.append(value)
if env_supertype is not None:
self.env_supertype = build_supertype(env_supertype, self)

agent = self.network.agents[agent_id]
agent.supertype = agent_supertype
for agent_id, agent_supertype in (agent_supertypes or {}).items():
agent = self.network.agents[agent_id]
agent.supertype = build_supertype(agent_supertype, agent)

# Generate initial sampled values in samplers
for sampler in self._samplers:
@@ -163,7 +145,7 @@ def non_strategic_agent_ids(self) -> List[AgentID]:
"""Return a list of the IDs of the agents that do not take actions."""
return [a.id for a in self.agents.values() if not isinstance(a, StrategicAgent)]

def view(self, agent_views: Dict[AgentID, AgentView]) -> EnvView:
def view(self, agent_views: Dict[AgentID, Optional[AgentView]]) -> EnvView:
"""Return an immutable view to the environment's public state."""
return EnvView(self.current_step, self.current_step / self.num_steps)

@@ -203,7 +185,8 @@ def reset(
"""
logger.log_reset()

super().reset(seed=seed, options=options)
if seed is not None:
self._np_random, seed = seeding.np_random(seed)

# Reset the clock
self._current_step = 0
@@ -224,7 +207,7 @@ def reset(
self._truncations = set()

# Generate all contexts for agents taking actions
self._make_ctxs(self.strategic_agent_ids)
self._ctxs = self._make_ctxs(self.strategic_agent_ids)

# Generate initial observations for agents taking actions
obs = {
@@ -256,7 +239,7 @@ def step(self, actions: Mapping[AgentID, Any]) -> "PhantomEnv.Step":
logger.log_start_decoding_actions()

# Generate contexts for all agents taking actions / generating messages
self._make_ctxs(self.agent_ids)
self._ctxs = self._make_ctxs(self.agent_ids)

# Decode action/generate messages for agents and send to the network
self._handle_acting_agents(self.agent_ids, actions)
@@ -275,6 +258,7 @@ def step(self, actions: Mapping[AgentID, Any]) -> "PhantomEnv.Step":
continue

ctx = self._ctxs[aid]
assert isinstance(ctx.agent, StrategicAgent)

obs = ctx.agent.encode_observation(ctx)
if obs is not None:
@@ -302,9 +286,6 @@ def step(self, actions: Mapping[AgentID, Any]) -> "PhantomEnv.Step":

return self.Step(observations, rewards, terminations, truncations, infos)

def render(self) -> None:
return None

def is_terminated(self) -> bool:
"""Implements the logic to decide when the episode is terminated."""
return len(self._terminations) == len(self.strategic_agents)
@@ -335,13 +316,13 @@ def _handle_acting_agents(
for receiver_id, message in messages:
self.network.send(aid, receiver_id, message)

def _make_ctxs(self, agent_ids: Sequence[AgentID]) -> None:
def _make_ctxs(self, agent_ids: Sequence[AgentID]) -> Dict[AgentID, Context]:
"""Internal method."""
env_view = self.view(
{agent_id: agent.view() for agent_id, agent in self.agents.items()}
)

self._ctxs = {
return {
aid: self.network.context_for(aid, env_view)
for aid in agent_ids
if aid not in self._terminations and aid not in self._truncations
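Since PhantomEnv no longer subclasses gym.Env, reset cannot delegate seeding to super().reset(seed=...); the diff instead seeds a generator directly via gymnasium.utils.seeding. A standalone sketch of the same pattern (the MinimalEnv class is invented for illustration):

from typing import Optional

from gymnasium.utils import seeding


class MinimalEnv:
    def reset(self, seed: Optional[int] = None) -> None:
        if seed is not None:
            # np_random returns a (numpy Generator, seed actually used) pair,
            # mirroring what gym.Env.reset(seed=...) does internally.
            self._np_random, seed = seeding.np_random(seed)


env = MinimalEnv()
env.reset(seed=42)
print(env._np_random.random())  # reproducible given the same seed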
4 changes: 2 additions & 2 deletions phantom/env_wrappers.py
@@ -11,7 +11,7 @@

import gymnasium as gym

from .agents import Agent, AgentID
from .agents import Agent, AgentID, StrategicAgent
from .env import PhantomEnv
from .policy import Policy

@@ -62,7 +62,7 @@ def __init__(
policies = list(other_policies.keys()) + [agent_id]

for agent in self._env.agents.values():
if agent.action_space is not None and agent.id not in policies:
if isinstance(agent, StrategicAgent) and agent.id not in policies:
raise ValueError(
f"Agent '{agent_id}' has not been defined a policy via the 'other_policies' parameter of SingleAgentEnvAdapter"
)
31 changes: 20 additions & 11 deletions phantom/fsm.py
@@ -1,6 +1,9 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple

from gymnasium.utils import seeding

from .agents import StrategicAgent
from .env import PhantomEnv
from .network import Network
from .supertype import Supertype
@@ -45,10 +48,10 @@ class FSMStage:
def __init__(
self,
stage_id: StageID,
acting_agents: Sequence[AgentID],
rewarded_agents: Optional[Sequence[AgentID]] = None,
next_stages: Optional[Sequence[StageID]] = None,
handler: Optional[Callable[[], StageID]] = None,
acting_agents: List[AgentID],
rewarded_agents: Optional[List[AgentID]] = None,
next_stages: Optional[List[StageID]] = None,
handler: Optional[Callable[..., StageID]] = None,
) -> None:
self.id = stage_id
self.acting_agents = acting_agents
@@ -186,7 +189,7 @@ def is_fsm_deterministic(self) -> bool:
"""Returns true if all stages are followed by exactly one stage."""
return all(len(s.next_stages) == 1 for s in self._stages.values())

def view(self, agent_views: Dict[AgentID, AgentView]) -> FSMEnvView:
def view(self, agent_views: Dict[AgentID, Optional[AgentView]]) -> FSMEnvView:
"""Return an immutable view to the FSM environment's public state."""
return FSMEnvView(
self.current_step, self.current_step / self.num_steps, self.current_stage
@@ -213,6 +216,9 @@ def reset(
"""
logger.log_reset()

if seed is not None:
self._np_random, seed = seeding.np_random(seed)

# Reset the clock and stage
self._current_step = 0
self._current_stage = self.initial_stage
@@ -236,7 +242,7 @@ def reset(

# Generate all contexts for agents taking actions
acting_agents = self._stages[self.current_stage].acting_agents
self._make_ctxs(
self._ctxs = self._make_ctxs(
[aid for aid in acting_agents if aid in self.strategic_agent_ids]
)

@@ -270,7 +276,7 @@ def step(self, actions: Mapping[AgentID, Any]) -> PhantomEnv.Step:
logger.log_start_decoding_actions()

# Generate contexts for all agents taking actions / generating messages
self._make_ctxs(self.agent_ids)
self._ctxs = self._make_ctxs(self.agent_ids)

# Decode action/generate messages for agents and send to the network
acting_agents = self._stages[self.current_stage].acting_agents
@@ -301,7 +307,9 @@ def step(self, actions: Mapping[AgentID, Any]) -> PhantomEnv.Step:
# function.
next_stage = env_handler(self)

if next_stage not in self._stages[self.current_stage].next_stages:
current_stage = self._stages[self.current_stage]

if next_stage not in current_stage.next_stages:
raise FSMRuntimeError(
f"FiniteStateMachineEnv attempted invalid transition from '{self.current_stage}' to {next_stage}"
)
@@ -312,18 +320,19 @@ def step(self, actions: Mapping[AgentID, Any]) -> PhantomEnv.Step:
truncations: Dict[AgentID, bool] = {}
infos: Dict[AgentID, Dict[str, Any]] = {}

if self._stages[self.current_stage].rewarded_agents is None:
if current_stage.rewarded_agents is None:
rewarded_agents = self.strategic_agent_ids
next_acting_agents = self.strategic_agent_ids
else:
rewarded_agents = self._stages[self.current_stage].rewarded_agents
rewarded_agents = current_stage.rewarded_agents
next_acting_agents = self._stages[next_stage].acting_agents

for aid in self.strategic_agent_ids:
if aid in self._terminations or aid in self._truncations:
continue

ctx = self._ctxs[aid]
assert isinstance(ctx.agent, StrategicAgent)

if aid in next_acting_agents:
obs = ctx.agent.encode_observation(ctx)
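On the FSMStage changes: the handler hint widens to Callable[..., StageID] because, as the step logic above shows, the env invokes it as env_handler(self) and checks the returned stage ID against the stage's declared next_stages. A hedged sketch of a two-stage machine (stage IDs, agent IDs, and the cycle are invented; only the FSMStage keyword arguments come from the diff):

from phantom.fsm import FSMStage


def bidding_handler(env) -> str:
    # Called as env_handler(env); must return an ID listed in next_stages,
    # otherwise FiniteStateMachineEnv raises FSMRuntimeError.
    return "clearing"


def clearing_handler(env) -> str:
    return "bidding"


stages = [
    FSMStage(
        stage_id="bidding",
        acting_agents=["bidder_1", "bidder_2"],
        next_stages=["clearing"],
        handler=bidding_handler,
    ),
    FSMStage(
        stage_id="clearing",
        acting_agents=["auctioneer"],
        next_stages=["bidding"],
        handler=clearing_handler,
    ),
]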