1212
1313from maro .rl .rollout import AbsEnvSampler , CacheElement
1414from maro .simulator import Env
15- from maro .simulator .scenarios .vm_scheduling import AllocateAction , DecisionPayload , PostponeAction
15+ from maro .simulator .scenarios .vm_scheduling import AllocateAction , DecisionEvent , PostponeAction
1616
1717from .config import (
1818 num_features ,
@@ -44,7 +44,7 @@ def __init__(self, learn_env: Env, test_env: Env) -> None:
4444
4545 def _get_global_and_agent_state_impl (
4646 self ,
47- event : DecisionPayload ,
47+ event : DecisionEvent ,
4848 tick : int = None ,
4949 ) -> Tuple [Union [None , np .ndarray , List [object ]], Dict [Any , Union [np .ndarray , List [object ]]]]:
5050 pm_state , vm_state = self ._get_pm_state (), self ._get_vm_state (event )
@@ -71,14 +71,14 @@ def _get_global_and_agent_state_impl(
7171 def _translate_to_env_action (
7272 self ,
7373 action_dict : Dict [Any , Union [np .ndarray , List [object ]]],
74- event : DecisionPayload ,
74+ event : DecisionEvent ,
7575 ) -> Dict [Any , object ]:
7676 if action_dict ["AGENT" ] == self .num_pms :
7777 return {"AGENT" : PostponeAction (vm_id = event .vm_id , postpone_step = 1 )}
7878 else :
7979 return {"AGENT" : AllocateAction (vm_id = event .vm_id , pm_id = action_dict ["AGENT" ][0 ])}
8080
81- def _get_reward (self , env_action_dict : Dict [Any , object ], event : DecisionPayload , tick : int ) -> Dict [Any , float ]:
81+ def _get_reward (self , env_action_dict : Dict [Any , object ], event : DecisionEvent , tick : int ) -> Dict [Any , float ]:
8282 action = env_action_dict ["AGENT" ]
8383 conf = reward_shaping_conf if self ._env == self ._learn_env else test_reward_shaping_conf
8484 if isinstance (action , PostponeAction ): # postponement
@@ -121,7 +121,7 @@ def _get_vm_state(self, event):
121121 ],
122122 )
123123
124- def _get_allocation_reward (self , event : DecisionPayload , alpha : float , beta : float ):
124+ def _get_allocation_reward (self , event : DecisionEvent , alpha : float , beta : float ):
125125 vm_unit_price = self ._env .business_engine ._get_unit_price (
126126 event .vm_cpu_cores_requirement ,
127127 event .vm_memory_requirement ,
0 commit comments