# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import sys

import numpy as np

from maro.rl.learning.env_sampler_v2 import AbsEnvSampler
from maro.simulator import Env
from maro.simulator.scenarios.cim.common import Action, ActionType

cim_path = os.path.dirname(os.path.realpath(__file__))
if cim_path not in sys.path:
    sys.path.insert(0, cim_path)

from config import (
    action_shaping_conf, algorithm, env_conf, port_attributes, reward_shaping_conf, state_shaping_conf,
    vessel_attributes
)
from policies import policy_func_dict


class CIMEnvSampler(AbsEnvSampler):
    def get_state(self, tick=None):
        """
        The state vector includes shortage and remaining vessel space over the past k days (where k is the "look_back"
        value in ``state_shaping_conf``), as well as all downstream port features.
        """
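        # Shape note (assuming a config in the spirit of the CIM RL example, e.g. "look_back" = 7 and 7 port
        # attributes): the port query below covers (look_back - 1) past ticks x (1 + number of future stops)
        # ports x len(port_attributes) attributes, flattened into a single vector, to which the current
        # vessel attributes are appended.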
        if tick is None:
            tick = self._env.tick
        vessel_snapshots, port_snapshots = self._env.snapshot_list["vessels"], self._env.snapshot_list["ports"]
        port_idx, vessel_idx = self.event.port_idx, self.event.vessel_idx
        # Ticks to look back on, clipped at 0 so that early ticks never index before the start of the episode.
        ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)]
        # Snapshot queries use the [ticks : indices : attributes] slicing syntax and return flat numpy arrays.
        future_port_list = vessel_snapshots[tick : vessel_idx : "future_stop_list"].astype("int")
        state = np.concatenate([
            port_snapshots[ticks : [port_idx] + list(future_port_list) : port_attributes],
            vessel_snapshots[tick : vessel_idx : vessel_attributes]
        ])
        return {port_idx: state}

    def get_env_actions(self, action_by_agent):
        """
        The policy output is an integer in [0, 20] that is interpreted as an index into the ``action_space`` list in
        ``action_shaping_conf``. For example, action 5 corresponds to -0.5, which means loading 50% of the containers
        available at the current port onto the vessel, while action 18 corresponds to 0.8, which means discharging 80%
        of the containers on the vessel to the port. Note that action 10 corresponds to 0.0, which means doing nothing.
        """
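        # Worked example (illustrative, assuming an ``action_space`` of 21 evenly spaced values from -1.0 to 1.0):
        # a model action of 5 maps to action_space[5] = -0.5, so percent = 0.5; since 5 < zero_action_idx = 10,
        # the resulting env action is a LOAD of min(round(0.5 * action_scope.load), vessel space) containers.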
        action_space = action_shaping_conf["action_space"]
        finite_vsl_space = action_shaping_conf["finite_vessel_space"]
        has_early_discharge = action_shaping_conf["has_early_discharge"]

        port_idx, action = list(action_by_agent.items()).pop()
        vsl_idx, action_scope = self.event.vessel_idx, self.event.action_scope
        vsl_snapshots = self._env.snapshot_list["vessels"]
        # With finite vessel space, the third vessel attribute (remaining space in the example config) caps the load.
        vsl_space = vsl_snapshots[self._env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float("inf")

        model_action = action["action"] if isinstance(action, dict) else action
        percent = abs(action_space[model_action])
        zero_action_idx = len(action_space) // 2  # index corresponding to value zero.
        if model_action < zero_action_idx:
            action_type = ActionType.LOAD
            actual_action = min(round(percent * action_scope.load), vsl_space)
        elif model_action > zero_action_idx:
            action_type = ActionType.DISCHARGE
            early_discharge = vsl_snapshots[self._env.tick:vsl_idx:"early_discharge"][0] if has_early_discharge else 0
            # Account for containers that have already been discharged early when computing the planned quantity.
            plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge
            actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge)
        else:
            actual_action, action_type = 0, None

        return [Action(port_idx=port_idx, vessel_idx=vsl_idx, quantity=actual_action, action_type=action_type)]

    def get_reward(self, actions, tick):
        """
        The reward is defined as a linear combination of fulfillment and shortage measures. The fulfillment and
        shortage measures are the sums of fulfillment and shortage values over the next k days, respectively, each
        adjusted with exponential decay factors (using the "time_decay" value in ``reward_shaping_conf``) to put more
        emphasis on the near future. Here k is the "time_window" value in ``reward_shaping_conf``. The linear
        combination coefficients are given by "fulfillment_factor" and "shortage_factor" in ``reward_shaping_conf``.
        """
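        # Worked example (illustrative, assuming e.g. time_window = 2, time_decay = 0.9 and both factors = 1.0):
        # for a port with fulfillments [10, 8] and shortages [2, 4] over the next two ticks, the reward is
        # 1.0 * (10 + 0.9 * 8) - 1.0 * (2 + 0.9 * 4) = 17.2 - 5.6 = 11.6.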
        start_tick = tick + 1
        ticks = list(range(start_tick, start_tick + reward_shaping_conf["time_window"]))

        # Get the ports that took actions at the given tick
        ports = [action.port_idx for action in actions]
        port_snapshots = self._env.snapshot_list["ports"]
        future_fulfillment = port_snapshots[ticks:ports:"fulfillment"].reshape(len(ticks), -1)
        future_shortage = port_snapshots[ticks:ports:"shortage"].reshape(len(ticks), -1)

        decay_list = [reward_shaping_conf["time_decay"] ** i for i in range(reward_shaping_conf["time_window"])]
        rewards = np.float32(
            reward_shaping_conf["fulfillment_factor"] * np.dot(future_fulfillment.T, decay_list)
            - reward_shaping_conf["shortage_factor"] * np.dot(future_shortage.T, decay_list)
        )
        return {agent_id: reward for agent_id, reward in zip(ports, rewards)}

    def post_step(self, state, action, env_action, reward, tick):
        """
        The environment sampler contains a "tracker" dict inherited from the ``AbsEnvSampler`` base class, which can
        be used to record any information one wishes to keep track of during a roll-out episode. Here we simply record
        the latest env metric, without keeping the history, for logging purposes.
        """
        self._tracker["env_metric"] = self._env.metrics


agent2policy = {agent: f"{algorithm}.{agent}" for agent in Env(**env_conf).agent_idx_list}


def get_env_sampler():
    return CIMEnvSampler(
        get_env=lambda: Env(**env_conf),
        get_policy_func_dict=policy_func_dict,
        agent2policy=agent2policy,
        reward_eval_delay=reward_shaping_conf["time_window"],
        parallel_inference=False
    )
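

# Hypothetical usage sketch (the training workflow in this example is normally driven by MARO's workflow scripts
# rather than by calling the sampler directly):
#
#     env_sampler = get_env_sampler()
#     result = env_sampler.sample()  # assumes the base class exposes a ``sample`` roll-out method
#     print(result["tracker"]["env_metric"])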