Skip to content

Commit b6e70ce

Browse files
author
Chris Elion
authored
cleanup some mypy types (#5072)
1 parent 0a5092b commit b6e70ce

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

ml-agents/mlagents/trainers/coma/trainer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(
5757
PPOSettings, self.trainer_settings.hyperparameters
5858
)
5959
self.seed = seed
60-
self.policy: Policy = None # type: ignore
60+
self.policy: TorchPolicy = None # type: ignore
6161
self.collected_group_rewards: Dict[str, int] = defaultdict(lambda: 0)
6262

6363
def _process_trajectory(self, trajectory: Trajectory) -> None:
@@ -264,9 +264,7 @@ def create_torch_policy(
264264
return policy
265265

266266
def create_coma_optimizer(self) -> TorchCOMAOptimizer:
267-
return TorchCOMAOptimizer( # type: ignore
268-
cast(TorchPolicy, self.policy), self.trainer_settings # type: ignore
269-
) # type: ignore
267+
return TorchCOMAOptimizer(self.policy, self.trainer_settings)
270268

271269
def add_policy(
272270
self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy
@@ -276,6 +274,8 @@ def add_policy(
276274
:param parsed_behavior_id: Behavior identifiers that the policy should belong to.
277275
:param policy: Policy to associate with name_behavior_id.
278276
"""
277+
if not isinstance(policy, TorchPolicy):
278+
raise RuntimeError(f"policy {policy} must be an instance of TorchPolicy.")
279279
self.policy = policy
280280
self.policies[parsed_behavior_id.behavior_id] = policy
281281
self.optimizer = self.create_coma_optimizer()

0 commit comments

Comments
 (0)