Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions maro/simulator/abs_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from abc import ABC, abstractmethod
from enum import IntEnum
from typing import List
from typing import List, Optional, Tuple

from maro.backends.frame import SnapshotList
from maro.event_buffer import EventBuffer
Expand Down Expand Up @@ -46,26 +46,27 @@ def __init__(
disable_finished_events: bool,
options: dict
):
self._tick = start_tick
self._scenario = scenario
self._topology = topology
self._start_tick = start_tick
self._durations = durations
self._snapshot_resolution = snapshot_resolution
self._max_snapshots = max_snapshots
self._decision_mode = decision_mode
self._business_engine_cls = business_engine_cls
self._additional_options = options

self._business_engine: AbsBusinessEngine = None
self._event_buffer: EventBuffer = None
self._tick: int = start_tick
self._scenario: str = scenario
self._topology: str = topology
self._start_tick: int = start_tick
self._durations: int = durations
self._snapshot_resolution: int = snapshot_resolution
self._max_snapshots: int = max_snapshots
self._decision_mode: DecisionMode = decision_mode
self._business_engine_cls: type = business_engine_cls
self._disable_finished_events: bool = disable_finished_events
self._additional_options: dict = options

self._business_engine: Optional[AbsBusinessEngine] = None
self._event_buffer: Optional[EventBuffer] = None

@property
def business_engine(self):
def business_engine(self) -> AbsBusinessEngine:
return self._business_engine

@abstractmethod
def step(self, action):
def step(self, action) -> Tuple[Optional[dict], Optional[List[object]], Optional[bool]]:
"""Push the environment to next step with action.

Args:
Expand All @@ -77,12 +78,12 @@ def step(self, action):
pass

@abstractmethod
def dump(self):
def dump(self) -> None:
"""Dump environment for restore."""
pass

@abstractmethod
def reset(self):
def reset(self) -> None:
"""Reset environment."""
pass

Expand Down Expand Up @@ -111,6 +112,7 @@ def tick(self) -> int:
pass

@property
@abstractmethod
def frame_index(self) -> int:
"""int: Frame index in snapshot list for current tick, USE this for snapshot querying."""
pass
Expand All @@ -127,7 +129,7 @@ def snapshot_list(self) -> SnapshotList:
"""SnapshotList: Current snapshot list, a snapshot list contains all the snapshots of frame at each tick."""
pass

def set_seed(self, seed: int):
def set_seed(self, seed: int) -> None:
"""Set random seed used by simulator.

NOTE:
Expand All @@ -147,10 +149,12 @@ def metrics(self) -> dict:
"""
return {}

@abstractmethod
def get_finished_events(self) -> list:
"""list: All events finished so far."""
pass

@abstractmethod
def get_pending_events(self, tick: int) -> list:
"""list: Pending events at certain tick.

Expand Down
48 changes: 24 additions & 24 deletions maro/simulator/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
from collections import Iterable
from importlib import import_module
from inspect import getmembers, isclass
from typing import List
from typing import Generator, List, Optional, Tuple

from maro.backends.frame import FrameBase, SnapshotList
from maro.data_lib.dump_csv_converter import DumpConverter
from maro.event_buffer import EventBuffer, EventState
from maro.event_buffer import ActualEvent, CascadeEvent, EventBuffer, EventState
from maro.streamit import streamit
from maro.utils.exception.simulator_exception import BusinessEngineNotFoundError

from .abs_core import AbsEnv, DecisionMode
from .scenarios.abs_business_engine import AbsBusinessEngine
from .utils import seed as sim_seed
from .utils import random
from .utils.common import tick_to_frame_index


Expand Down Expand Up @@ -47,17 +47,16 @@ def __init__(
business_engine_cls: type = None, disable_finished_events: bool = False,
record_finished_events: bool = False,
record_file_path: str = None,
options: dict = {}
):
options: Optional[dict] = None
) -> None:
super().__init__(
scenario, topology, start_tick, durations,
snapshot_resolution, max_snapshots, decision_mode, business_engine_cls,
disable_finished_events, options
disable_finished_events, options if options is not None else {}
)

self._name = f'{self._scenario}:{self._topology}' if business_engine_cls is None \
else business_engine_cls.__name__
self._business_engine: AbsBusinessEngine = None

self._event_buffer = EventBuffer(disable_finished_events, record_finished_events, record_file_path)

Expand All @@ -72,12 +71,12 @@ def __init__(

if "enable-dump-snapshot" in self._additional_options:
parent_path = self._additional_options["enable-dump-snapshot"]
self._converter = DumpConverter(parent_path, self._business_engine._scenario_name)
self._converter = DumpConverter(parent_path, self._business_engine.scenario_name)
self._converter.reset_folder_path()

self._streamit_episode = 0

def step(self, action):
def step(self, action) -> Tuple[Optional[dict], Optional[List[object]], Optional[bool]]:
"""Push the environment to next step with action.

Args:
Expand All @@ -93,15 +92,15 @@ def step(self, action):

return metrics, decision_event, _is_done

def dump(self):
def dump(self) -> None:
"""Dump environment for restore.

NOTE:
Not implemented.
"""
return

def reset(self, keep_seed: bool = False):
def reset(self, keep_seed: bool = False) -> None:
"""Reset environment.

Args:
Expand All @@ -114,10 +113,10 @@ def reset(self, keep_seed: bool = False):

self._event_buffer.reset()

if ("enable-dump-snapshot" in self._additional_options) and (self._business_engine._frame is not None):
if "enable-dump-snapshot" in self._additional_options and self._business_engine.frame is not None:
dump_folder = self._converter.get_new_snapshot_folder()

self._business_engine._frame.dump(dump_folder)
self._business_engine.frame.dump(dump_folder)
self._converter.start_processing(self.configs)
self._converter.dump_descsion_events(self._decision_events, self._start_tick, self._snapshot_resolution)
self._business_engine.dump(dump_folder)
Expand Down Expand Up @@ -173,7 +172,7 @@ def agent_idx_list(self) -> List[int]:
"""List[int]: Agent index list that related to this environment."""
return self._business_engine.get_agent_idx_list()

def set_seed(self, seed: int):
def set_seed(self, seed: int) -> None:
"""Set random seed used by simulator.

NOTE:
Expand All @@ -184,7 +183,7 @@ def set_seed(self, seed: int):
"""

if seed is not None:
sim_seed(seed)
random.seed(seed)

@property
def metrics(self) -> dict:
Expand All @@ -196,19 +195,19 @@ def metrics(self) -> dict:

return self._business_engine.get_metrics()

def get_finished_events(self):
def get_finished_events(self) -> List[ActualEvent]:
"""List[Event]: All events finished so far."""
return self._event_buffer.get_finished_events()

def get_pending_events(self, tick):
def get_pending_events(self, tick) -> List[ActualEvent]:
"""Pending events at certain tick.

Args:
tick (int): Specified tick to query.
"""
return self._event_buffer.get_pending_events(tick)

def _init_business_engine(self):
def _init_business_engine(self) -> None:
"""Initialize business engine object.

NOTE:
Expand Down Expand Up @@ -238,7 +237,7 @@ def _init_business_engine(self):
if business_class is None:
raise BusinessEngineNotFoundError()

self._business_engine = business_class(
self._business_engine: AbsBusinessEngine = business_class(
event_buffer=self._event_buffer,
topology=self._topology,
start_tick=self._start_tick,
Expand All @@ -248,10 +247,8 @@ def _init_business_engine(self):
additional_options=self._additional_options
)

def _simulate(self):
def _simulate(self) -> Generator[Tuple[dict, List[object], bool], object, None]:
"""This is the generator to wrap each episode process."""
is_end_tick = False

self._streamit_episode += 1

streamit.episode(self._streamit_episode)
Expand Down Expand Up @@ -297,8 +294,10 @@ def _simulate(self):

# NOTE: decision event always be a CascadeEvent
# We just append the action into sub event of first pending cascade event.
pending_events[0].state = EventState.EXECUTING
pending_events[0].add_immediate_event(action_event, is_head=True)
event = pending_events[0]
assert isinstance(event, CascadeEvent)
event.state = EventState.EXECUTING
event.add_immediate_event(action_event, is_head=True)
else:
# For joint mode, we will assign actions from beginning to end.
# Then mark others pending events to finished if not sequential action mode.
Expand All @@ -314,6 +313,7 @@ def _simulate(self):
pending_event.state = EventState.EXECUTING
action_event = self._event_buffer.gen_action_event(self._tick, action)

assert isinstance(pending_event, CascadeEvent)
pending_event.add_immediate_event(action_event, is_head=True)

# Check the end tick of the simulation to decide if we should end the simulation.
Expand Down
39 changes: 25 additions & 14 deletions maro/simulator/scenarios/abs_business_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional

from maro.backends.frame import FrameBase, SnapshotList
from maro.event_buffer import EventBuffer
Expand Down Expand Up @@ -31,23 +32,23 @@ class AbsBusinessEngine(ABC):
max_tick (int): Max tick of this business engine.
snapshot_resolution (int): Frequency to take a snapshot.
max_snapshots(int): Max number of in-memory snapshots, default is None that means max number of snapshots.
addition_options (dict): Additional options for this business engine from outside.
additional_options (dict): Additional options for this business engine from outside.
"""

def __init__(
self, scenario_name: str, event_buffer: EventBuffer, topology: str,
start_tick: int, max_tick: int, snapshot_resolution: int, max_snapshots: int,
additional_options: dict = None
):
self._scenario_name = scenario_name
self._topology = topology
self._event_buffer = event_buffer
self._start_tick = start_tick
self._max_tick = max_tick
self._snapshot_resolution = snapshot_resolution
self._max_snapshots = max_snapshots
self._additional_options = additional_options
self._config_path = None
self._scenario_name: str = scenario_name
self._topology: str = topology
self._event_buffer: EventBuffer = event_buffer
self._start_tick: int = start_tick
self._max_tick: int = max_tick
self._snapshot_resolution: int = snapshot_resolution
self._max_snapshots: int = max_snapshots
self._additional_options: dict = additional_options
self._config_path: Optional[str] = None

assert start_tick >= 0
assert max_tick > start_tick
Expand All @@ -65,6 +66,15 @@ def snapshots(self) -> SnapshotList:
"""SnapshotList: Snapshot list of current frame, this is used to expose querying interface for outside."""
pass

@property
def scenario_name(self) -> str:
return self._scenario_name

@abstractmethod
def get_agent_idx_list(self) -> List[int]:
"""Get a list of agent index."""
pass

def frame_index(self, tick: int) -> int:
"""Helper method for child class, used to get index of frame in snapshot list for specified tick.

Expand All @@ -89,7 +99,7 @@ def calc_max_snapshots(self) -> int:
return self._max_snapshots if self._max_snapshots is not None \
else total_frames(self._start_tick, self._max_tick, self._snapshot_resolution)

def update_config_root_path(self, business_engine_file_path: str):
def update_config_root_path(self, business_engine_file_path: str) -> None:
"""Helper method used to update the config path with business engine path if you
follow the way to load configuration file as built-in scenarios.

Expand Down Expand Up @@ -125,7 +135,7 @@ def __init__(self, *args, **kwargs):
self._config_path = os.path.join(be_file_path, "topologies", self._topology)

@abstractmethod
def step(self, tick: int):
def step(self, tick: int) -> None:
"""Method that is called at each tick, usually used to trigger business logic at current tick.

Args:
Expand All @@ -134,12 +144,13 @@ def step(self, tick: int):
pass

@property
@abstractmethod
def configs(self) -> dict:
"""dict: Configurations of this business engine."""
pass

@abstractmethod
def reset(self, keep_seed: bool = False):
def reset(self, keep_seed: bool = False) -> None:
"""Reset states business engine."""
pass

Expand Down Expand Up @@ -183,7 +194,7 @@ def get_metrics(self) -> dict:
"""
return {}

def dump(self, folder: str):
def dump(self, folder: str) -> None:
"""Dump something from business engine.

Args:
Expand Down