|
1 | 1 | # # Unity ML-Agents Toolkit
|
2 | 2 | # ## ML-Agent Learning (Ghost Trainer)
|
3 | 3 |
|
4 |
| -from typing import Deque, Dict, List, cast |
| 4 | +from collections import defaultdict |
| 5 | +from typing import Deque, Dict, DefaultDict, List, cast |
5 | 6 |
|
6 | 7 | import numpy as np
|
7 | 8 |
|
@@ -68,9 +69,9 @@ def __init__(
|
68 | 69 | self._internal_trajectory_queues: Dict[str, AgentManagerQueue[Trajectory]] = {}
|
69 | 70 | self._internal_policy_queues: Dict[str, AgentManagerQueue[Policy]] = {}
|
70 | 71 |
|
71 |
| - self._team_to_name_to_policy_queue: Dict[ |
| 72 | + self._team_to_name_to_policy_queue: DefaultDict[ |
72 | 73 | int, Dict[str, AgentManagerQueue[Policy]]
|
73 |
| - ] = {} |
| 74 | + ] = defaultdict(dict) |
74 | 75 |
|
75 | 76 | self._name_to_parsed_behavior_id: Dict[str, BehaviorIdentifiers] = {}
|
76 | 77 |
|
@@ -413,14 +414,9 @@ def publish_policy_queue(self, policy_queue: AgentManagerQueue[Policy]) -> None:
|
413 | 414 | """
|
414 | 415 | super().publish_policy_queue(policy_queue)
|
415 | 416 | parsed_behavior_id = self._name_to_parsed_behavior_id[policy_queue.behavior_id]
|
416 |
| - try: |
417 |
| - self._team_to_name_to_policy_queue[parsed_behavior_id.team_id][ |
418 |
| - parsed_behavior_id.brain_name |
419 |
| - ] = policy_queue |
420 |
| - except KeyError: |
421 |
| - self._team_to_name_to_policy_queue[parsed_behavior_id.team_id] = { |
422 |
| - parsed_behavior_id.brain_name: policy_queue |
423 |
| - } |
| 417 | + self._team_to_name_to_policy_queue[parsed_behavior_id.team_id][ |
| 418 | + parsed_behavior_id.brain_name |
| 419 | + ] = policy_queue |
424 | 420 | if parsed_behavior_id.team_id == self.wrapped_trainer_team:
|
425 | 421 | # With a future multiagent trainer, this will be indexed by 'role'
|
426 | 422 | internal_policy_queue: AgentManagerQueue[Policy] = AgentManagerQueue(
|
|
0 commit comments