
Commit 273b164

committed
Complement docstrings for task queue and trainer.
1 parent c036a45 commit 273b164

4 files changed: +52 additions, -3 deletions

maro/rl/data_parallelism/task_queue.py

Lines changed: 19 additions & 1 deletion
@@ -40,7 +40,10 @@ def submit(
         self, worker_id_list: List, batch_list: List, tensor_dict_list: List, policy_state: Dict, policy_name: str,
         scope: str = None
     ) -> Dict[str, List[Dict[str, Dict[int, Dict[str, torch.Tensor]]]]]:
-        """Learn a batch of data on several grad workers."""
+        """Learn a batch of data on several grad workers.
+        For each policy, send a list of batches and policy states to the grad workers and receive a list of gradients.
+        The results actually come from the train workers' `get_batch_grad()` method, whose type
+        Dict[str, Dict[int, Dict[str, torch.Tensor]]] means {scope: {worker_id: {param_name: grad_value}}}."""
         msg_dict = defaultdict(lambda: defaultdict(dict))
         loss_info_by_policy = {policy_name: []}
         for worker_id, batch, tensor_dict in zip(worker_id_list, batch_list, tensor_dict_list):
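
The nested structure returned by the grad workers can be reduced to a single gradient per parameter by averaging across workers. The sketch below only illustrates consuming {scope: {worker_id: {param_name: grad_value}}}; the helper name and the plain mean are assumptions, not code from this commit.

from typing import Dict
import torch

def average_worker_grads(
    grads_by_scope: Dict[str, Dict[int, Dict[str, torch.Tensor]]]
) -> Dict[str, Dict[str, torch.Tensor]]:
    # Hypothetical helper: {scope: {worker_id: {param_name: grad}}} -> {scope: {param_name: mean grad}}.
    averaged = {}
    for scope, grads_by_worker in grads_by_scope.items():
        param_sums: Dict[str, torch.Tensor] = {}
        for worker_grads in grads_by_worker.values():
            for param_name, grad in worker_grads.items():
                param_sums[param_name] = param_sums.get(param_name, torch.zeros_like(grad)) + grad
        averaged[scope] = {name: total / len(grads_by_worker) for name, total in param_sums.items()}
    return averaged
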
@@ -80,6 +83,21 @@ def task_queue(
     proxy_kwargs: dict = {},
     logger: Logger = DummyLogger()
 ):
+    """The queue that manages data-parallel tasks. The task queue communicates with the gradient workers
+    and maintains each worker's busy/idle status. Clients send requests to the task queue, which assigns
+    available workers to them. The task queue follows the producer-consumer model, consisting of two
+    queues: task_pending and task_assigned. It also supports task priorities and adding/deleting workers.
+
+    Args:
+        worker_ids (List[int]): Worker ids to initialize.
+        num_hosts (int): The number of policy hosts. Will be renamed in RL v3.
+        num_policies (int): The number of policies.
+        single_task_limit (float): The maximum proportion of resources that a single task may be assigned.
+            Defaults to 0.5.
+        group (str): Group name used to initialize the proxy. Defaults to DEFAULT_POLICY_GROUP.
+        proxy_kwargs (dict): Keyword arguments for the proxy. Defaults to an empty dict.
+        logger (Logger): Defaults to DummyLogger().
+    """
     num_workers = len(worker_ids)
     if num_hosts == 0:
         # for multi-process mode
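
For intuition, the producer-consumer layout described in the docstring (a task_pending queue feeding a task_assigned queue from a pool of idle workers) could be sketched as below. This is a toy illustration under assumed names, not MARO's implementation, and it omits task priorities and remote messaging.

from collections import deque

class MiniTaskQueue:
    def __init__(self, worker_ids):
        self.idle_workers = set(worker_ids)   # workers currently free
        self.task_pending = deque()           # tasks waiting for a worker
        self.task_assigned = deque()          # (task, worker_id) pairs in flight

    def submit_task(self, task):
        self.task_pending.append(task)
        self._assign()

    def _assign(self):
        # Move tasks from pending to assigned while idle workers remain.
        while self.task_pending and self.idle_workers:
            self.task_assigned.append((self.task_pending.popleft(), self.idle_workers.pop()))

    def complete(self, worker_id):
        # A worker finished its task and becomes idle again.
        self.task_assigned = deque((t, w) for t, w in self.task_assigned if w != worker_id)
        self.idle_workers.add(worker_id)
        self._assign()
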

maro/rl_v3/policy_trainer/abs_trainer.py

Lines changed: 2 additions & 2 deletions
@@ -10,8 +10,8 @@
 
 
 class AbsTrainer(object, metaclass=ABCMeta):
-    """
-    Policy trainer used to train policies.
+    """Policy trainer used to train policies. A trainer maintains several train workers and
+    controls their training logic, while the train workers take charge of the actual policy updates.
     """
     def __init__(
         self,
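
As a toy picture of that division of labour (illustrative names only, not the MARO API), the trainer drives the training loop while each train worker updates the single policy it hosts:

class ToyWorker:
    def __init__(self, name):
        self.name = name

    def update(self, batch):
        print(f"{self.name} updates its policy with {len(batch)} samples")

class ToyTrainer:
    def __init__(self, workers):
        self._workers = {w.name: w for w in workers}

    def train_step(self, batch_by_worker):
        # The trainer controls the training logic; workers do the per-policy updates.
        for name, worker in self._workers.items():
            worker.update(batch_by_worker[name])

trainer = ToyTrainer([ToyWorker("actor_1"), ToyWorker("actor_2")])
trainer.train_step({"actor_1": [0, 1, 2], "actor_2": [3, 4]})
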

maro/rl_v3/policy_trainer/distributed_discrete_maddpg.py

Lines changed: 19 additions & 0 deletions
@@ -13,6 +13,25 @@
 
 
 class DiscreteMADDPGWorker(MultiTrainWorker):
+    """The discrete variant of the MADDPG algorithm.
+    Args:
+        name (str): Name of the worker.
+        device (torch.device): Which device to use.
+        reward_discount (float): The discount factor for future rewards.
+        get_q_critic_net_func (Callable[[], MultiQNet]): Function that returns the Q critic net.
+        shared_critic (bool): Whether the actors share a critic. Defaults to False.
+        critic_loss_coef (float): Coefficient of the critic loss in the total loss. Defaults to 1.0.
+        soft_update_coef (float): Soft update coefficient, e.g., target_model = soft_update_coef * eval_model +
+            (1 - soft_update_coef) * target_model. Defaults to 1.0.
+        update_target_every (int): Number of training rounds between policy target model updates. Defaults to 5.
+        q_value_loss_func (Callable): The loss function, provided by torch.nn or a custom loss class, for the
+            Q-value loss. Defaults to None.
+        enable_data_parallelism (bool): Whether to enable data parallelism in this trainer. Defaults to False.
+
+    Reference:
+        Paper: http://papers.nips.cc/paper/by-source-2017-3193
+        Code: https://github.com/openai/maddpg
+    """
     def __init__(
         self,
         name: str,
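
The soft update rule quoted for soft_update_coef is the usual Polyak-style averaging of target parameters. A minimal sketch with plain torch modules, assuming nothing beyond the formula itself:

import torch

def soft_update(target_net: torch.nn.Module, eval_net: torch.nn.Module, coef: float) -> None:
    # target = coef * eval + (1 - coef) * target, applied parameter by parameter.
    with torch.no_grad():
        for target_param, eval_param in zip(target_net.parameters(), eval_net.parameters()):
            target_param.mul_(1.0 - coef).add_(coef * eval_param)
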

maro/rl_v3/policy_trainer/train_worker.py

Lines changed: 12 additions & 0 deletions
@@ -11,6 +11,10 @@
 
 
 class AbsTrainWorker(object, metaclass=ABCMeta):
+    """The basic component for training a policy, which mainly takes charge of gradient computation and policy updates.
+    Within a trainer, each train worker hosts one policy, and a trainer hosts several train workers. To the gradient
+    workers, a train worker is the atomic representation of a policy, used to perform parallel gradient computation.
+    """
     def __init__(
         self,
         name: str,
@@ -45,6 +49,11 @@ def _remote_learn(
         tensor_dict: Dict[str, object] = None,
         scope: str = "all"
     ) -> List[Dict[str, Dict[int, Dict[str, torch.Tensor]]]]:
+        """Learn a batch of experience data from remote gradient workers.
+        The task queue client first requests available gradient workers from the task queue. If all workers are
+        busy, it keeps waiting until at least one worker is available. The task queue client then submits the
+        batch and policy state to the assigned workers to compute gradients.
+        """
         assert self._task_queue_client is not None
         worker_id_list = self._task_queue_client.request_workers()
         batch_list = self._dispatch_batch(batch, len(worker_id_list))
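
The "keep waiting" behaviour described above could look roughly like the polling loop below. Whether request_workers() blocks internally or returns an empty list is not shown in this diff, so treat the retry logic as an assumption for illustration only.

import time

def request_workers_blocking(task_queue_client, retry_interval: float = 0.5):
    # Poll the task queue until it hands back at least one idle worker id.
    while True:
        worker_ids = task_queue_client.request_workers()
        if worker_ids:
            return worker_ids
        time.sleep(retry_interval)
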
@@ -67,6 +76,9 @@ def get_batch_grad(
 
     @abstractmethod
     def _dispatch_batch(self, batch: MultiTransitionBatch, num_workers: int) -> List[MultiTransitionBatch]:
+        """Split an experience data batch into several parts.
+        For on-policy algorithms, like PG, the batch is split into several complete trajectories.
+        For off-policy algorithms, like DQN, the batch is treated as independent data points and split evenly."""
         raise NotImplementedError
 
     @abstractmethod
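
For the off-policy case, "split evenly" amounts to chunking independent samples into near-equal parts. A rough numpy illustration (MARO's MultiTransitionBatch is not used here, so this is only a stand-in for the idea):

import numpy as np

def split_evenly(samples: np.ndarray, num_workers: int) -> list:
    # Treat samples as independent data points and split them into near-equal chunks.
    return np.array_split(samples, num_workers)

# e.g. split_evenly(np.arange(10), 3) yields arrays of sizes 4, 3 and 3.
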
