
Commit 6fa3fed

Draft v2.0 for V2
1 parent 3e38eb7 commit 6fa3fed

28 files changed: +3805 −85 lines

examples/rl/cim/env_sampler.py

Lines changed: 3 additions & 3 deletions
@@ -32,7 +32,7 @@ def get_state(self, tick=None):
         vessel_snapshots, port_snapshots = self.env.snapshot_list["vessels"], self.env.snapshot_list["ports"]
         port_idx, vessel_idx = self.event.port_idx, self.event.vessel_idx
         ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)]
-        future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
+        future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
         state = np.concatenate([
             port_snapshots[ticks : [port_idx] + list(future_port_list) : port_attributes],
             vessel_snapshots[tick : vessel_idx : vessel_attributes]
@@ -55,7 +55,7 @@ def get_env_actions(self, action_by_agent):
         vsl_snapshots = self.env.snapshot_list["vessels"]
         vsl_space = vsl_snapshots[self.env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float("inf")

-        model_action = action["action"] if isinstance(action, dict) else action
+        model_action = action["action"] if isinstance(action, dict) else action
         percent = abs(action_space[model_action])
         zero_action_idx = len(action_space) / 2  # index corresponding to value zero.
         if model_action < zero_action_idx:
@@ -112,5 +112,5 @@ def get_env_sampler():
         get_policy_func_dict=policy_func_dict,
         agent2policy=agent2policy,
         reward_eval_delay=reward_shaping_conf["time_window"],
-        parallel_inference=True
+        parallel_inference=False
     )

examples/rl/cim_v2/README.md

Lines changed: 9 additions & 0 deletions
# Container Inventory Management

This example demonstrates the use of MARO's RL toolkit to optimize container inventory management. The scenario consists of a set of ports, each acting as a learning agent, and vessels that transfer empty containers among them. Each port must decide 1) whether to load or discharge containers when a vessel arrives and 2) how many containers to load or discharge. The objective is to minimize the overall container shortage over a certain period of time. In this folder you can find:
* ``config.py``, which contains environment and policy configurations for the scenario;
* ``env_sampler.py``, which defines state, action and reward shaping in the ``CIMEnvSampler`` class;
* ``policies.py``, which defines the Q-net for DQN and the network components for Actor-Critic;
* ``callbacks.py``, which defines routines to be invoked at the end of training or evaluation episodes.

The scripts for running the learning workflows can be found under ``examples/rl/workflows``. See the ``README`` under ``examples/rl`` for details about the general applicability of these scripts. We recommend that you follow this example to write your own scenarios.
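
For orientation, the snippet below is a minimal sketch of how the names exported by this folder (see ``__init__.py`` below) can be pulled together, assuming ``maro`` is installed and the repository root is on ``PYTHONPATH``; the rollout loop itself lives in the workflow scripts and is not shown here.

```python
# Hedged sketch: wiring up the cim_v2 exports by hand.
# The mapping shown in the comments assumes the toy.4p_ssdd_l0.0 topology
# from config.py (four ports) and algorithm == "ac".
from examples.rl.cim_v2 import agent2policy, get_env_sampler, policy_func_dict

env_sampler = get_env_sampler()   # CIMEnvSampler wrapping Env(**env_conf)
print(agent2policy)               # e.g. {0: "ac.0", 1: "ac.1", 2: "ac.2", 3: "ac.3"}
print(list(policy_func_dict))     # presumably one policy creator per agent, keyed "ac.<agent>"
```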

examples/rl/cim_v2/__init__.py

Lines changed: 8 additions & 0 deletions
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .callbacks import post_collect, post_evaluate
from .env_sampler import agent2policy, get_env_sampler
from .policies import policy_func_dict

__all__ = ["agent2policy", "post_collect", "post_evaluate", "get_env_sampler", "policy_func_dict"]

examples/rl/cim_v2/callbacks.py

Lines changed: 33 additions & 0 deletions
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import time
from os import makedirs
from os.path import dirname, join, realpath

log_dir = join(dirname(realpath(__file__)), "log", str(time.time()))
makedirs(log_dir, exist_ok=True)


def post_collect(trackers, ep, segment):
    # print the env metric from each rollout worker
    for tracker in trackers:
        print(f"env summary (episode {ep}, segment {segment}): {tracker['env_metric']}")

    # print the average env metric
    if len(trackers) > 1:
        metric_keys, num_trackers = trackers[0]["env_metric"].keys(), len(trackers)
        avg_metric = {key: sum(tr["env_metric"][key] for tr in trackers) / num_trackers for key in metric_keys}
        print(f"average env summary (episode {ep}, segment {segment}): {avg_metric}")


def post_evaluate(trackers, ep):
    # print the env metric from each rollout worker
    for tracker in trackers:
        print(f"env summary (episode {ep}): {tracker['env_metric']}")

    # print the average env metric
    if len(trackers) > 1:
        metric_keys, num_trackers = trackers[0]["env_metric"].keys(), len(trackers)
        avg_metric = {key: sum(tr["env_metric"][key] for tr in trackers) / num_trackers for key in metric_keys}
        print(f"average env summary (episode {ep}): {avg_metric}")

examples/rl/cim_v2/config.py

Lines changed: 125 additions & 0 deletions
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
from torch.optim import Adam, RMSprop

from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy


env_conf = {
    "scenario": "cim",
    "topology": "toy.4p_ssdd_l0.0",
    "durations": 560
}

port_attributes = ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"]
vessel_attributes = ["empty", "full", "remaining_space"]

state_shaping_conf = {
    "look_back": 7,
    "max_ports_downstream": 2
}

action_shaping_conf = {
    "action_space": [(i - 10) / 10 for i in range(21)],
    "finite_vessel_space": True,
    "has_early_discharge": True
}

reward_shaping_conf = {
    "time_window": 99,
    "fulfillment_factor": 1.0,
    "shortage_factor": 1.0,
    "time_decay": 0.97
}

# state dimension derived from the state-shaping configuration above
state_dim = (
    (state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_ports_downstream"] + 1) * len(port_attributes)
    + len(vessel_attributes)
)

############################################## POLICIES ###############################################

algorithm = "ac"

# DQN settings
q_net_conf = {
    "input_dim": state_dim,
    "hidden_dims": [256, 128, 64, 32],
    "output_dim": len(action_shaping_conf["action_space"]),
    "activation": torch.nn.LeakyReLU,
    "softmax": False,
    "batch_norm": True,
    "skip_connection": False,
    "head": True,
    "dropout_p": 0.0
}

q_net_optim_conf = (RMSprop, {"lr": 0.05})

dqn_conf = {
    "reward_discount": .0,
    "update_target_every": 5,
    "num_epochs": 10,
    "soft_update_coef": 0.1,
    "double": False,
    "exploration_strategy": (epsilon_greedy, {"epsilon": 0.4}),
    "exploration_scheduling_options": [(
        "epsilon", MultiLinearExplorationScheduler, {
            "splits": [(2, 0.32)],
            "initial_value": 0.4,
            "last_ep": 5,
            "final_value": 0.0,
        }
    )],
    "replay_memory_capacity": 10000,
    "random_overwrite": False,
    "warmup": 100,
    "rollout_batch_size": 128,
    "train_batch_size": 32,
    # "prioritized_replay_kwargs": {
    #     "alpha": 0.6,
    #     "beta": 0.4,
    #     "beta_step": 0.001,
    #     "max_priority": 1e8
    # }
}


# AC settings
actor_net_conf = {
    "input_dim": state_dim,
    "hidden_dims": [256, 128, 64],
    "output_dim": len(action_shaping_conf["action_space"]),
    "activation": torch.nn.Tanh,
    "softmax": True,
    "batch_norm": False,
    "head": True
}

critic_net_conf = {
    "input_dim": state_dim,
    "hidden_dims": [256, 128, 64],
    "output_dim": 1,
    "activation": torch.nn.LeakyReLU,
    "softmax": False,
    "batch_norm": True,
    "head": True
}

actor_optim_conf = (Adam, {"lr": 0.001})
critic_optim_conf = (RMSprop, {"lr": 0.001})

ac_conf = {
    "reward_discount": .0,
    "grad_iters": 10,
    "critic_loss_cls": torch.nn.SmoothL1Loss,
    "min_logp": None,
    "critic_loss_coef": 0.1,
    "entropy_coef": 0.01,
    # "clip_ratio": 0.8  # for PPO
    "lam": .0,
    "get_loss_on_rollout": False
}
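
A standalone sanity check of the ``state_dim`` arithmetic above, plus the presumed way the ``(optimizer class, kwargs)`` tuples are consumed; the actual wiring lives in ``policies.py``, which is not shown in this excerpt, so treat the second half as an assumption.

```python
# Check of the state_dim formula, with the values copied from config.py above.
look_back, max_ports_downstream = 7, 2
n_port_attrs, n_vessel_attrs = 7, 3   # len(port_attributes), len(vessel_attributes)

state_dim = (look_back + 1) * (max_ports_downstream + 1) * n_port_attrs + n_vessel_attrs
assert state_dim == 171   # 8 * 3 * 7 + 3

# Assumed consumption pattern for the (class, kwargs) optimizer tuples.
import torch
from torch.optim import RMSprop

q_net_optim_conf = (RMSprop, {"lr": 0.05})
net = torch.nn.Linear(state_dim, 21)            # stand-in for the real Q-net
optim_cls, optim_kwargs = q_net_optim_conf
optimizer = optim_cls(net.parameters(), **optim_kwargs)
```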

examples/rl/cim_v2/env_sampler.py

Lines changed: 116 additions & 0 deletions
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import sys

import numpy as np

from maro.rl.learning.env_sampler_v2 import AbsEnvSampler
from maro.simulator import Env
from maro.simulator.scenarios.cim.common import Action, ActionType

cim_path = os.path.dirname(os.path.realpath(__file__))
if cim_path not in sys.path:
    sys.path.insert(0, cim_path)

from config import (
    action_shaping_conf, algorithm, env_conf, port_attributes, reward_shaping_conf, state_shaping_conf,
    vessel_attributes
)
from policies import policy_func_dict


class CIMEnvSampler(AbsEnvSampler):
    def get_state(self, tick=None):
        """
        The state vector includes shortage and remaining vessel space over the past k days (where k is the
        "look_back" value in ``state_shaping_conf``), as well as all downstream port features.
        """
        if tick is None:
            tick = self._env.tick
        vessel_snapshots, port_snapshots = self._env.snapshot_list["vessels"], self._env.snapshot_list["ports"]
        port_idx, vessel_idx = self.event.port_idx, self.event.vessel_idx
        ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)]
        future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
        state = np.concatenate([
            port_snapshots[ticks : [port_idx] + list(future_port_list) : port_attributes],
            vessel_snapshots[tick : vessel_idx : vessel_attributes]
        ])
        return {port_idx: state}

    def get_env_actions(self, action_by_agent):
        """
        The policy output is an integer in [0, 20], interpreted as an index into ``action_space`` in
        ``action_shaping_conf``. For example, action 5 corresponds to -0.5, which means loading 50% of the
        containers available at the current port onto the vessel, while action 18 corresponds to 0.8, which means
        discharging 80% of the containers on the vessel to the port. Note that action 10 corresponds to 0.0, which
        means doing nothing.
        """
        action_space = action_shaping_conf["action_space"]
        finite_vsl_space = action_shaping_conf["finite_vessel_space"]
        has_early_discharge = action_shaping_conf["has_early_discharge"]

        port_idx, action = list(action_by_agent.items()).pop()
        vsl_idx, action_scope = self.event.vessel_idx, self.event.action_scope
        vsl_snapshots = self._env.snapshot_list["vessels"]
        vsl_space = vsl_snapshots[self._env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float("inf")

        model_action = action["action"] if isinstance(action, dict) else action
        percent = abs(action_space[model_action])
        zero_action_idx = len(action_space) / 2  # index corresponding to value zero.
        if model_action < zero_action_idx:
            action_type = ActionType.LOAD
            actual_action = min(round(percent * action_scope.load), vsl_space)
        elif model_action > zero_action_idx:
            action_type = ActionType.DISCHARGE
            early_discharge = vsl_snapshots[self._env.tick:vsl_idx:"early_discharge"][0] if has_early_discharge else 0
            plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge
            actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge)
        else:
            actual_action, action_type = 0, None

        return [Action(port_idx=port_idx, vessel_idx=vsl_idx, quantity=actual_action, action_type=action_type)]

    def get_reward(self, actions, tick):
        """
        The reward is defined as a linear combination of fulfillment and shortage measures. The fulfillment and
        shortage measures are the sums of fulfillment and shortage values over the next k days, respectively, each
        adjusted with exponential decay factors (using the "time_decay" value in ``reward_shaping_conf``) to put
        more emphasis on the near future. Here k is the "time_window" value in ``reward_shaping_conf``. The linear
        combination coefficients are given by "fulfillment_factor" and "shortage_factor" in ``reward_shaping_conf``.
        """
        start_tick = tick + 1
        ticks = list(range(start_tick, start_tick + reward_shaping_conf["time_window"]))

        # Get the ports that took actions at the given tick
        ports = [action.port_idx for action in actions]
        port_snapshots = self._env.snapshot_list["ports"]
        future_fulfillment = port_snapshots[ticks:ports:"fulfillment"].reshape(len(ticks), -1)
        future_shortage = port_snapshots[ticks:ports:"shortage"].reshape(len(ticks), -1)

        decay_list = [reward_shaping_conf["time_decay"] ** i for i in range(reward_shaping_conf["time_window"])]
        rewards = np.float32(
            reward_shaping_conf["fulfillment_factor"] * np.dot(future_fulfillment.T, decay_list)
            - reward_shaping_conf["shortage_factor"] * np.dot(future_shortage.T, decay_list)
        )
        return {agent_id: reward for agent_id, reward in zip(ports, rewards)}

    def post_step(self, state, action, env_action, reward, tick):
        """
        The environment sampler contains a "tracker" dict inherited from the ``AbsEnvSampler`` base class, which
        can be used to record any information one wishes to keep track of during a roll-out episode. Here we simply
        record the latest env metric, without keeping the history, for logging purposes.
        """
        self._tracker["env_metric"] = self._env.metrics


agent2policy = {agent: f"{algorithm}.{agent}" for agent in Env(**env_conf).agent_idx_list}


def get_env_sampler():
    return CIMEnvSampler(
        get_env=lambda: Env(**env_conf),
        get_policy_func_dict=policy_func_dict,
        agent2policy=agent2policy,
        reward_eval_delay=reward_shaping_conf["time_window"],
        parallel_inference=False
    )
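
To make the action and reward shaping concrete, here is a standalone sketch that uses only values from ``config.py``; all names are local to the sketch and nothing here touches the simulator.

```python
import numpy as np

# The 21-slot action space from action_shaping_conf: index 10 is the no-op,
# indices below 10 are load fractions, indices above 10 are discharge fractions.
action_space = [(i - 10) / 10 for i in range(21)]
assert action_space[5] == -0.5    # load 50% of the load scope
assert action_space[18] == 0.8    # discharge 80% of the discharge scope
assert action_space[10] == 0.0    # do nothing

# The exponential decay used in get_reward(): near-future fulfillment/shortage
# counts far more than values at the end of the 99-tick window.
time_window, time_decay = 99, 0.97
decay = np.array([time_decay ** i for i in range(time_window)])
assert decay[0] == 1.0 and decay[-1] < 0.06   # 0.97 ** 98 is roughly 0.05
```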
