@@ -40,7 +40,10 @@ def submit(
         self, worker_id_list: List, batch_list: List, tensor_dict_list: List, policy_state: Dict, policy_name: str,
         scope: str = None
     ) -> Dict[str, List[Dict[str, Dict[int, Dict[str, torch.Tensor]]]]]:
43- """Learn a batch of data on several grad workers."""
43+ """Learn a batch of data on several grad workers.
44+ For each policy, send a list of batch and state to grad workers, and receive a list of gradients.
45+ The results is actually from train worker's `get_batch_grad()` method, with type:
46+ Dict[str, Dict[int, Dict[str, torch.Tensor]]], which means {scope: {worker_id: {param_name: grad_value}}}"""
         msg_dict = defaultdict(lambda: defaultdict(dict))
         loss_info_by_policy = {policy_name: []}
         for worker_id, batch, tensor_dict in zip(worker_id_list, batch_list, tensor_dict_list):
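
As a reading aid for this hunk, here is a minimal sketch of the nested return shape the new docstring describes. The variable and key names (`grad_result`, `"all"`, `"fc.weight"`) are illustrative placeholders, not values from this PR:

```python
import torch
from typing import Dict

# Shape described in the docstring: {scope: {worker_id: {param_name: grad_value}}}
grad_result: Dict[str, Dict[int, Dict[str, torch.Tensor]]] = {
    "all": {                                 # scope
        0: {                                 # worker_id
            "fc.weight": torch.zeros(4, 4),  # param_name -> gradient tensor
        },
    },
}
```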
@@ -80,6 +83,21 @@ def task_queue(
     proxy_kwargs: dict = {},
     logger: Logger = DummyLogger()
 ):
86+ """The queue to manage data parallel tasks. Task queue communicates with gradient workers,
87+ maintaing the busy/idle status of workers. Clients send requests to task queue, and task queue
88+ will assign available workers to the requests. Task queue follows the `producer-consumer` model,
89+ consisting of two queues: task_pending, task_assigned. Besides, task queue supports task priority,
90+ adding/deleting workers.
91+
92+ Args:
93+ worker_ids (List[int]): Worker ids to initialize.
94+ num_hosts (int): The number of policy hosts. Will be renamed in RL v3.
95+ num_policies (int): The number of policies.
96+ single_task_limit (float): The limit resource proportion for a single task to assign. Defaults to 0.5
97+ group (str): Group name to initialize proxy. Defaults to DEFAULT_POLICY_GROUP.
98+ proxy_kwargs (dict): Keyword arguments for proxy. Defaults to empty dict.
99+ logger (Logger): Defaults to DummyLogger().
100+ """
     num_workers = len(worker_ids)
     if num_hosts == 0:
         # for multi-process mode
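
For a rough picture of the producer-consumer bookkeeping the new docstring describes (this is a hypothetical sketch, not the actual implementation: `idle_workers` and `assign_next_task` are invented names, while `task_pending`/`task_assigned` mirror the docstring):

```python
from collections import deque

task_pending = deque()    # produced: requests waiting for a worker
task_assigned = {}        # consumed: task_id -> worker_id
idle_workers = {0, 1, 2}  # busy/idle status tracked per worker

def assign_next_task():
    """Move one pending task to an idle worker, if both exist."""
    if task_pending and idle_workers:
        task_id = task_pending.popleft()
        worker_id = idle_workers.pop()
        task_assigned[task_id] = worker_id
        return task_id, worker_id
    return None

task_pending.append("task-0")
print(assign_next_task())  # e.g. ('task-0', 0)
```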