@@ -57,7 +57,7 @@ def __init__(
             PPOSettings, self.trainer_settings.hyperparameters
         )
         self.seed = seed
-        self.policy: Policy = None  # type: ignore
+        self.policy: TorchPolicy = None  # type: ignore
         self.collected_group_rewards: Dict[str, int] = defaultdict(lambda: 0)

     def _process_trajectory(self, trajectory: Trajectory) -> None:
@@ -264,9 +264,7 @@ def create_torch_policy(
         return policy

     def create_coma_optimizer(self) -> TorchCOMAOptimizer:
-        return TorchCOMAOptimizer(  # type: ignore
-            cast(TorchPolicy, self.policy), self.trainer_settings  # type: ignore
-        )  # type: ignore
+        return TorchCOMAOptimizer(self.policy, self.trainer_settings)

     def add_policy(
         self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy
@@ -276,6 +274,8 @@ def add_policy(
         :param parsed_behavior_id: Behavior identifiers that the policy should belong to.
         :param policy: Policy to associate with name_behavior_id.
         """
+        if not isinstance(policy, TorchPolicy):
+            raise RuntimeError(f"policy {policy} must be an instance of TorchPolicy.")
         self.policy = policy
         self.policies[parsed_behavior_id.behavior_id] = policy
         self.optimizer = self.create_coma_optimizer()
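
Reviewer note: a minimal, self-contained sketch of the behavior the last hunk introduces (add_policy now rejects anything that is not a TorchPolicy before wiring it into the trainer and building the optimizer). The DummyPolicy, DummyTorchPolicy, and DummyTrainer names below are hypothetical stand-ins, not classes from this codebase.

class DummyPolicy:
    """Hypothetical stand-in for the abstract Policy base class."""

class DummyTorchPolicy(DummyPolicy):
    """Hypothetical stand-in for the Torch-backed TorchPolicy."""

class DummyTrainer:
    """Hypothetical stand-in for the trainer; only the new guard is modelled."""

    def add_policy(self, policy: DummyPolicy) -> None:
        # Fail fast when a non-Torch policy is registered, mirroring the hunk above.
        if not isinstance(policy, DummyTorchPolicy):
            raise RuntimeError(f"policy {policy} must be an instance of DummyTorchPolicy.")
        self.policy = policy

trainer = DummyTrainer()
trainer.add_policy(DummyTorchPolicy())  # accepted: it is a TorchPolicy-like subclass
try:
    trainer.add_policy(DummyPolicy())   # rejected at runtime by the new check
except RuntimeError as err:
    print(err)

With the attribute typed as TorchPolicy and the runtime check in place, the cast and the type: ignore comments in create_coma_optimizer become unnecessary, which is what the second hunk removes.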