@@ -57,9 +57,10 @@ def fill_invalid(obj: int | float | bool | np.ndarray | dict | list | tuple) ->
5757
5858def is_invalid (arr : int | float | bool | np .ndarray | dict | list | tuple ) -> bool :
5959 if hasattr (arr , "dtype" ):
60- if np .issubdtype (arr .dtype , np .floating ):
60+ dtype = getattr (arr , "dtype" )
61+ if np .issubdtype (dtype , np .floating ):
6162 return np .isnan (arr ).all ()
62- return (np .iinfo (arr . dtype ).max == arr ).all ()
63+ return (np .iinfo (dtype ).max == arr ).all ()
6364 if isinstance (arr , dict ):
6465 return all (is_invalid (o ) for o in arr .values ())
6566 if isinstance (arr , (list , tuple )):
@@ -209,7 +210,7 @@ def collector_guard(self) -> Generator[FiniteVectorEnv, None, None]:
209210
210211 def reset (
211212 self ,
212- id : int | List [int ] | np .ndarray = None ,
213+ id : int | List [int ] | np .ndarray | None = None ,
213214 ) -> np .ndarray :
214215 assert not self ._zombie
215216
@@ -222,23 +223,23 @@ def reset(
222223 RuntimeWarning ,
223224 )
224225
225- id = self ._wrap_id (id )
226+ wrapped_id = self ._wrap_id (id )
226227 self ._reset_alive_envs ()
227228
228229 # ask super to reset alive envs and remap to current index
229- request_id = list ( filter ( lambda i : i in self ._alive_env_ids , id ))
230- obs = [None ] * len (id )
231- id2idx = {i : k for k , i in enumerate (id )}
230+ request_id = [ i for i in wrapped_id if i in self ._alive_env_ids ]
231+ obs = [None ] * len (wrapped_id )
232+ id2idx = {i : k for k , i in enumerate (wrapped_id )}
232233 if request_id :
233234 for i , o in zip (request_id , super ().reset (request_id )):
234235 obs [id2idx [i ]] = self ._postproc_env_obs (o )
235236
236- for i , o in zip (id , obs ):
237+ for i , o in zip (wrapped_id , obs ):
237238 if o is None and i in self ._alive_env_ids :
238239 self ._alive_env_ids .remove (i )
239240
240241 # logging
241- for i , o in zip (id , obs ):
242+ for i , o in zip (wrapped_id , obs ):
242243 if i in self ._alive_env_ids :
243244 for logger in self ._logger :
244245 logger .on_env_reset (i , obs )
@@ -251,7 +252,7 @@ def reset(
251252 obs [i ] = self ._get_default_obs ()
252253
253254 if not self ._alive_env_ids :
254- # comment this line so that the env becomes indisposable
255+ # comment this line so that the env becomes indispensable
255256 # self.reset()
256257 self ._zombie = True
257258 raise StopIteration
@@ -261,13 +262,13 @@ def reset(
261262 def step (
262263 self ,
263264 action : np .ndarray ,
264- id : int | List [int ] | np .ndarray = None ,
265+ id : int | List [int ] | np .ndarray | None = None ,
265266 ) -> Tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray ]:
266267 assert not self ._zombie
267- id = self ._wrap_id (id )
268- id2idx = {i : k for k , i in enumerate (id )}
269- request_id = list (filter (lambda i : i in self ._alive_env_ids , id ))
270- result = [[None , None , False , None ] for _ in range (len (id ))]
268+ wrapped_id = self ._wrap_id (id )
269+ id2idx = {i : k for k , i in enumerate (wrapped_id )}
270+ request_id = list (filter (lambda i : i in self ._alive_env_ids , wrapped_id ))
271+ result = [[None , None , False , None ] for _ in range (len (wrapped_id ))]
271272
272273 # ask super to step alive envs and remap to current index
273274 if request_id :
@@ -277,7 +278,7 @@ def step(
277278 result [id2idx [i ]][0 ] = self ._postproc_env_obs (result [id2idx [i ]][0 ])
278279
279280 # logging
280- for i , r in zip (id , result ):
281+ for i , r in zip (wrapped_id , result ):
281282 if i in self ._alive_env_ids :
282283 for logger in self ._logger :
283284 logger .on_env_step (i , * r )
0 commit comments