8 changes: 4 additions & 4 deletions maro/rl_v3/learning/env_sampler.py
@@ -36,7 +36,7 @@ def set_policy_states(self, policy_state_dict: Dict[str, object]) -> None:
"""
for policy_name, policy_state in policy_state_dict.items():
policy = self._policy_dict[policy_name]
-            policy.set_policy_state(policy_state)
+            policy.set_state(policy_state)

def choose_actions(self, state_by_agent: Dict[Any, np.ndarray]) -> Dict[Any, np.ndarray]:
"""
@@ -151,7 +151,7 @@ def __init__(
self,
get_env_func: Callable[[], Env],
#
-        get_policy_func_dict: Dict[str, Callable[[str], RLPolicy]],
+        policy_creator: Dict[str, Callable[[str], RLPolicy]],
agent2policy: Dict[Any, str], # {agent_name: policy_name}
agent_wrapper_cls: Type[AbsAgentWrapper],
reward_eval_delay: int = 0,
@@ -161,7 +161,7 @@ def __init__(
"""
Args:
get_env_func (Dict[str, Callable[[str], RLPolicy]]): Dict of functions used to create the learning Env.
-            get_policy_func_dict (Dict[str, Callable[[str], RLPolicy]]): Dict of functions used to create policies.
+            policy_creator (Dict[str, Callable[[str], RLPolicy]]): Dict of functions used to create policies.
agent2policy (Dict[Any, str]): Agent name to policy name mapping.
agent_wrapper_cls (Type[AbsAgentWrapper]): Concrete AbsAgentWrapper type.
reward_eval_delay (int): Number of ticks required after a decision event to evaluate the reward
@@ -178,7 +178,7 @@ def __init__(
else torch.device("cuda" if torch.cuda.is_available() else "cpu")

self._policy_dict: Dict[str, RLPolicy] = {
-            policy_name: func(policy_name) for policy_name, func in get_policy_func_dict.items()
+            policy_name: func(policy_name) for policy_name, func in policy_creator.items()
}
self._agent_wrapper = agent_wrapper_cls(self._policy_dict, agent2policy)
self._agent2policy = agent2policy
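For orientation, here is a minimal runnable sketch of the pattern the renamed policy_creator argument follows, assuming only what the hunks above show: a dict mapping policy names to factory callables, which the sampler calls with each name to build its policy dict, and whose policies are later refreshed via the renamed set_state. DummyPolicy and make_policy are stand-ins invented for this example, not MARO classes.

from typing import Any, Callable, Dict

class DummyPolicy:
    """Stand-in for RLPolicy; only the pieces exercised below."""
    def __init__(self, name: str) -> None:
        self.name = name
        self.state: Any = None

    def set_state(self, policy_state: Any) -> None:
        self.state = policy_state

def make_policy(name: str) -> DummyPolicy:
    return DummyPolicy(name)

# policy_creator maps each policy name to a factory that is called with that name.
policy_creator: Dict[str, Callable[[str], DummyPolicy]] = {
    f"dqn.{i}": make_policy for i in range(2)
}

# Mirrors AbsEnvSampler.__init__ above: build the policy dict from the creators.
policy_dict: Dict[str, DummyPolicy] = {
    policy_name: func(policy_name) for policy_name, func in policy_creator.items()
}

# Mirrors set_policy_states above: forward each entry to policy.set_state(...).
for policy_name, policy_state in {"dqn.0": {"theta": 0.1}}.items():
    policy_dict[policy_name].set_state(policy_state)

assert policy_dict["dqn.0"].state == {"theta": 0.1}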
4 changes: 2 additions & 2 deletions maro/rl_v3/policy/abs_policy.py
@@ -214,14 +214,14 @@ def train(self) -> None:
raise NotImplementedError

@abstractmethod
-    def get_policy_state(self) -> object:
+    def get_state(self) -> object:
"""
Get the state of the policy.
"""
raise NotImplementedError

@abstractmethod
-    def set_policy_state(self, policy_state: object) -> None:
+    def set_state(self, policy_state: object) -> None:
"""
Set the state of the policy.
"""
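To illustrate the renamed contract, here is a minimal sketch with a stand-in abstract base and a hypothetical TabularPolicy subclass (neither is part of MARO); the point is only that get_state / set_state now form the save/restore pair that get_policy_state / set_policy_state used to.

from abc import ABC, abstractmethod

class AbsPolicyLike(ABC):
    """Stand-in for the abstract policy interface, with the renamed methods."""

    @abstractmethod
    def get_state(self) -> object:
        """Get the state of the policy."""
        raise NotImplementedError

    @abstractmethod
    def set_state(self, policy_state: object) -> None:
        """Set the state of the policy."""
        raise NotImplementedError

class TabularPolicy(AbsPolicyLike):
    """Hypothetical concrete policy, used only for this example."""

    def __init__(self) -> None:
        self._table: dict = {}

    def get_state(self) -> object:
        return dict(self._table)

    def set_state(self, policy_state: object) -> None:
        self._table = dict(policy_state)

src, dst = TabularPolicy(), TabularPolicy()
src.set_state({"s0": 1})
dst.set_state(src.get_state())  # state round-trips through the renamed pair
assert dst.get_state() == {"s0": 1}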
4 changes: 2 additions & 2 deletions maro/rl_v3/policy/continuous_rl_policy.py
@@ -93,10 +93,10 @@ def eval(self) -> None:
def train(self) -> None:
self._policy_net.train()

-    def get_policy_state(self) -> object:
+    def get_state(self) -> object:
return self._policy_net.get_net_state()

-    def set_policy_state(self, policy_state: object) -> None:
+    def set_state(self, policy_state: object) -> None:
self._policy_net.set_net_state(policy_state)

def soft_update(self, other_policy: RLPolicy, tau: float) -> None:
8 changes: 4 additions & 4 deletions maro/rl_v3/policy/discrete_rl_policy.py
@@ -124,10 +124,10 @@ def eval(self) -> None:
def train(self) -> None:
self._q_net.train()

-    def get_policy_state(self) -> object:
+    def get_state(self) -> object:
return self._q_net.get_net_state()

-    def set_policy_state(self, policy_state: object) -> None:
+    def set_state(self, policy_state: object) -> None:
self._q_net.set_net_state(policy_state)

def soft_update(self, other_policy: RLPolicy, tau: float) -> None:
@@ -182,10 +182,10 @@ def eval(self) -> None:
def train(self) -> None:
self._policy_net.train()

-    def get_policy_state(self) -> object:
+    def get_state(self) -> object:
return self._policy_net.get_net_state()

-    def set_policy_state(self, policy_state: object) -> None:
+    def set_state(self, policy_state: object) -> None:
self._policy_net.set_net_state(policy_state)

def soft_update(self, other_policy: RLPolicy, tau: float) -> None:
4 changes: 2 additions & 2 deletions maro/rl_v3/tmp_example_multi/env_sampler.py
@@ -9,7 +9,7 @@
from .config import (
action_shaping_conf, env_conf, port_attributes, reward_shaping_conf, state_shaping_conf, vessel_attributes
)
-from .policies import get_policy_func_dict
+from .policies import policy_creator


class CIMEnvSampler(AbsEnvSampler):
@@ -78,7 +78,7 @@ def _post_step(self, cache_element: CacheElement, reward: Dict[Any, float]) -> N
algorithm = "dqn"
env_sampler = CIMEnvSampler(
get_env_func=lambda: Env(**env_conf),
-    get_policy_func_dict=get_policy_func_dict,
+    policy_creator=policy_creator,
agent2policy={agent: f"{algorithm}.{agent}" for agent in Env(**env_conf).agent_idx_list},
agent_wrapper_cls=SimpleAgentWrapper,
)
9 changes: 5 additions & 4 deletions maro/rl_v3/tmp_example_multi/main.py
@@ -6,21 +6,22 @@
from .callbacks import cim_post_collect, cim_post_evaluate
from .config import algorithm, env_conf, running_mode
from .env_sampler import CIMEnvSampler
-from .policies import get_policy_func_dict, get_trainer_func_dict
+from .policies import policy_creator, trainer_creator


if __name__ == "__main__":
run_workflow_centralized_mode(
get_env_sampler_func=lambda: CIMEnvSampler(
get_env_func=lambda: Env(**env_conf),
-            get_policy_func_dict=get_policy_func_dict,
+            policy_creator=policy_creator,
agent2policy={agent: f"{algorithm}_{agent}.{agent}" for agent in Env(**env_conf).agent_idx_list},
agent_wrapper_cls=SimpleAgentWrapper,
device="cpu"
),
get_trainer_manager_func=lambda: SimpleTrainerManager(
-            get_policy_func_dict=get_policy_func_dict,
+            policy_creator=policy_creator,
+            agent2policy={agent: f"{algorithm}_{agent}.{agent}" for agent in Env(**env_conf).agent_idx_list},
-            get_trainer_func_dict=get_trainer_func_dict
+            trainer_creator=trainer_creator
),
num_episodes=30,
post_collect=cim_post_collect,
13 changes: 6 additions & 7 deletions maro/rl_v3/tmp_example_multi/policies.py
@@ -1,6 +1,6 @@
from maro.rl_v3.policy import DiscretePolicyGradient
from maro.rl_v3.training.algorithms import DiscreteMADDPG, DiscreteMADDPGParams
-from maro.rl_v3.workflow import preprocess_get_policy_func_dict
+from maro.rl_v3.workflow import preprocess_policy_creator

from .config import algorithm, running_mode
from .nets import MyActorNet, MyMultiCriticNet
@@ -29,20 +29,19 @@ def get_maddpg(name: str) -> DiscreteMADDPG:


if algorithm == "discrete_maddpg":
-    get_policy_func = get_discrete_policy_gradient
-    get_policy_func_dict = {
-        f"{algorithm}_{i}.{i}": get_policy_func
+    policy_creator = {
+        f"{algorithm}_{i}.{i}": get_discrete_policy_gradient
for i in range(4)
}

-    get_trainer_func_dict = {
+    trainer_creator = {
f"{algorithm}_{i}": get_maddpg
for i in range(4)
}
else:
raise ValueError
# #####################################################################################################################

-get_policy_func_dict = preprocess_get_policy_func_dict(
-    get_policy_func_dict, running_mode
+policy_creator = preprocess_policy_creator(
+    policy_creator, running_mode
)
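A small self-contained sketch of the naming scheme these dicts follow (the factories here are stand-ins; the real ones build DiscretePolicyGradient and DiscreteMADDPG objects): policy keys take the form "<trainer_name>.<suffix>" while trainer keys drop the suffix, four of each.

# Stand-in factories for illustration only.
def get_discrete_policy_gradient(name: str) -> str:
    return f"policy<{name}>"

def get_maddpg(name: str) -> str:
    return f"trainer<{name}>"

algorithm = "discrete_maddpg"

# Same comprehensions as above: one policy per agent index, one trainer per policy.
policy_creator = {f"{algorithm}_{i}.{i}": get_discrete_policy_gradient for i in range(4)}
trainer_creator = {f"{algorithm}_{i}": get_maddpg for i in range(4)}

print(sorted(policy_creator))   # ['discrete_maddpg_0.0', ..., 'discrete_maddpg_3.3']
print(sorted(trainer_creator))  # ['discrete_maddpg_0', ..., 'discrete_maddpg_3']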
2 changes: 1 addition & 1 deletion maro/rl_v3/tmp_example_single/config.py
@@ -35,4 +35,4 @@
)

algorithm = "ac"
-running_mode = "centralized"
+running_mode = "decentralized"
4 changes: 2 additions & 2 deletions maro/rl_v3/tmp_example_single/env_sampler.py
@@ -9,7 +9,7 @@
from .config import (
action_shaping_conf, env_conf, port_attributes, reward_shaping_conf, state_shaping_conf, vessel_attributes
)
-from .policies import get_policy_func_dict
+from .policies import policy_creator


class CIMEnvSampler(AbsEnvSampler):
@@ -78,7 +78,7 @@ def _post_step(self, cache_element: CacheElement, reward: Dict[Any, float]) -> N
algorithm = "dqn"
env_sampler = CIMEnvSampler(
get_env_func=lambda: Env(**env_conf),
-    get_policy_func_dict=get_policy_func_dict,
+    policy_creator=policy_creator,
agent2policy={agent: f"{algorithm}.{agent}" for agent in Env(**env_conf).agent_idx_list},
agent_wrapper_cls=SimpleAgentWrapper,
)
8 changes: 4 additions & 4 deletions maro/rl_v3/tmp_example_single/main.py
@@ -6,21 +6,21 @@
from .callbacks import cim_post_collect, cim_post_evaluate
from .config import algorithm, env_conf, running_mode
from .env_sampler import CIMEnvSampler
-from .policies import get_policy_func_dict, get_trainer_func_dict
+from .policies import policy_creator, trainer_creator

if __name__ == "__main__":
run_workflow_centralized_mode(
get_env_sampler_func=lambda: CIMEnvSampler(
get_env_func=lambda: Env(**env_conf),
-            get_policy_func_dict=get_policy_func_dict,
+            policy_creator=policy_creator,
agent2policy={agent: f"{algorithm}_{agent}.{agent}" for agent in Env(**env_conf).agent_idx_list},
agent_wrapper_cls=SimpleAgentWrapper,
device="cpu"
),
get_trainer_manager_func=lambda: SimpleTrainerManager(
-            get_policy_func_dict=get_policy_func_dict,
+            policy_creator=policy_creator,
agent2policy={agent: f"{algorithm}_{agent}.{agent}" for agent in Env(**env_conf).agent_idx_list},
-            get_trainer_func_dict=get_trainer_func_dict
+            trainer_creator=trainer_creator
),
num_episodes=30,
post_collect=cim_post_collect,
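In both main.py files the agent2policy values use the same f"{algorithm}_{agent}.{agent}" pattern as the policy_creator keys, so each agent resolves to a policy that exists in the creator dict; reading the prefix before the dot as the name of the owning trainer is an assumption here, not something this diff states. A tiny bookkeeping sketch under that assumption:

algorithm = "ac"
agent_idx_list = [0, 1, 2, 3]  # stand-in for Env(**env_conf).agent_idx_list

# Same expression as in main.py above.
agent2policy = {agent: f"{algorithm}_{agent}.{agent}" for agent in agent_idx_list}

policy_names = {f"{algorithm}_{i}.{i}" for i in range(4)}   # keys of policy_creator
trainer_names = {f"{algorithm}_{i}" for i in range(4)}      # keys of trainer_creator

# Every agent's policy name exists in policy_creator...
assert set(agent2policy.values()) <= policy_names
# ...and (assumption) the prefix before the dot names the owning trainer.
assert {p.split(".")[0] for p in policy_names} == trainer_names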
20 changes: 9 additions & 11 deletions maro/rl_v3/tmp_example_single/policies.py
@@ -3,7 +3,7 @@
from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy
from maro.rl_v3.policy import DiscretePolicyGradient, ValueBasedPolicy
from maro.rl_v3.training.algorithms import DQN, DiscreteActorCritic, DiscreteActorCriticParams, DQNParams
-from maro.rl_v3.workflow import preprocess_get_policy_func_dict
+from maro.rl_v3.workflow import preprocess_policy_creator

from .config import algorithm, running_mode
from .nets import MyActorNet, MyCriticNet, MyQNet
@@ -60,32 +60,30 @@ def get_ac(name: str) -> DiscreteActorCritic:


if algorithm == "dqn":
-    get_policy_func = get_value_based_policy
-    get_policy_func_dict = {
-        f"{algorithm}_{i}.{i}": get_policy_func
+    policy_creator = {
+        f"{algorithm}_{i}.{i}": get_value_based_policy
for i in range(4)
}

-    get_trainer_func_dict = {
+    trainer_creator = {
f"{algorithm}_{i}": get_dqn
for i in range(4)
}

elif algorithm == "ac":
-    get_policy_func = get_discrete_policy_gradient
-    get_policy_func_dict = {
-        f"{algorithm}_{i}.{i}": get_policy_func
+    policy_creator = {
+        f"{algorithm}_{i}.{i}": get_discrete_policy_gradient
for i in range(4)
}

-    get_trainer_func_dict = {
+    trainer_creator = {
f"{algorithm}_{i}": get_ac
for i in range(4)
}
else:
raise ValueError
# #####################################################################################################################

-get_policy_func_dict = preprocess_get_policy_func_dict(
-    get_policy_func_dict, running_mode
+policy_creator = preprocess_policy_creator(
+    policy_creator, running_mode
)
13 changes: 5 additions & 8 deletions maro/rl_v3/training/algorithms/ac.py
@@ -174,14 +174,14 @@ def update(self, grad_iters: int) -> None:
def get_state_dict(self, scope: str = "all") -> dict:
ret_dict = {}
if scope in ("all", "actor"):
-            ret_dict["policy_state"] = self._policy.get_policy_state()
+            ret_dict["policy_state"] = self._policy.get_state()
if scope in ("all", "critic"):
ret_dict["critic_state"] = self._v_critic_net.get_net_state()
return ret_dict

def set_state_dict(self, ops_state_dict: dict, scope: str = "all") -> None:
if scope in ("all", "actor"):
-            self._policy.set_policy_state(ops_state_dict["policy_state"])
+            self._policy.set_state(ops_state_dict["policy_state"])
if scope in ("all", "critic"):
self._v_critic_net.set_net_state(ops_state_dict["critic_state"])

@@ -195,27 +195,24 @@ class DiscreteActorCritic(SingleTrainer):
"""
def __init__(self, name: str, params: DiscreteActorCriticParams) -> None:
        super(DiscreteActorCritic, self).__init__(name, params)
-
        self._params = params
+        self._ops_name = f"{self._name}.ops"

def build(self) -> None:
self._ops_params = {
"get_policy_func": self._get_policy_func,
**self._params.extract_ops_params(),
}

-        self._ops = self.get_ops(f"{self.name}_ops")
+        self._ops = self.get_ops(self._ops_name)
self._replay_memory = FIFOReplayMemory(
capacity=self._params.replay_memory_capacity,
state_dim=self._ops.policy_state_dim,
action_dim=self._ops.policy_action_dim
)

def _get_local_ops_by_name(self, ops_name: str) -> AbsTrainOps:
-        if ops_name == f"{self.name}_ops":
-            return DiscreteActorCriticOps(**self._ops_params)
-        else:
-            raise ValueError(f"Unknown ops name {ops_name}")
+        return DiscreteActorCriticOps(**self._ops_params)

async def train_step(self):
await asyncio.gather(self._ops.set_batch(self._get_batch()))
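To make the scoped state-dict logic above concrete, here is a minimal runnable sketch: get_state_dict and set_state_dict are copied from the hunk, while StubPolicy, StubCriticNet, and StubActorCriticOps are stand-ins assumed for this example rather than the real DiscreteActorCriticOps collaborators.

class StubPolicy:
    """Stand-in exposing the renamed get_state / set_state pair."""
    def __init__(self) -> None:
        self._state = {"actor_weights": [0.0]}
    def get_state(self) -> object:
        return dict(self._state)
    def set_state(self, policy_state: object) -> None:
        self._state = dict(policy_state)

class StubCriticNet:
    """Stand-in for the V-critic net's get_net_state / set_net_state."""
    def __init__(self) -> None:
        self._state = {"critic_weights": [0.0]}
    def get_net_state(self) -> object:
        return dict(self._state)
    def set_net_state(self, net_state: object) -> None:
        self._state = dict(net_state)

class StubActorCriticOps:
    """Only the two scoped methods shown in the diff above."""
    def __init__(self) -> None:
        self._policy = StubPolicy()
        self._v_critic_net = StubCriticNet()

    def get_state_dict(self, scope: str = "all") -> dict:
        ret_dict = {}
        if scope in ("all", "actor"):
            ret_dict["policy_state"] = self._policy.get_state()
        if scope in ("all", "critic"):
            ret_dict["critic_state"] = self._v_critic_net.get_net_state()
        return ret_dict

    def set_state_dict(self, ops_state_dict: dict, scope: str = "all") -> None:
        if scope in ("all", "actor"):
            self._policy.set_state(ops_state_dict["policy_state"])
        if scope in ("all", "critic"):
            self._v_critic_net.set_net_state(ops_state_dict["critic_state"])

src, dst = StubActorCriticOps(), StubActorCriticOps()
dst.set_state_dict(src.get_state_dict(scope="actor"), scope="actor")  # actor-only sync
assert "critic_state" not in src.get_state_dict(scope="actor")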