diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1a57593b..f3ac7830 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/psf/black - rev: 22.1.0 + rev: 22.6.0 hooks: - id: black files: 'mbrl' @@ -13,11 +13,11 @@ repos: files: 'mbrl' - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.931 + rev: v0.971 hooks: - id: mypy files: 'mbrl' - additional_dependencies: [torch, tokenize-rt==3.2.0, types-PyYAML, types-termcolor] + additional_dependencies: [numpy, torch, tokenize-rt==3.2.0, types-PyYAML, types-termcolor] args: [--no-strict-optional, --ignore-missing-imports] exclude: setup.py diff --git a/mbrl/util/common.py b/mbrl/util/common.py index 98c7013b..18689b23 100644 --- a/mbrl/util/common.py +++ b/mbrl/util/common.py @@ -313,7 +313,7 @@ def get_sequence_buffer_iterator( if use_simple_sampler: train_iterator: _SequenceIterType = SequenceTransitionSampler( transitions, - train_trajectories, + train_trajectories, # type:ignore batch_size, sequence_length, max_batches_per_loop_train, @@ -322,7 +322,7 @@ def get_sequence_buffer_iterator( else: train_iterator = SequenceTransitionIterator( transitions, - train_trajectories, + train_trajectories, # type: ignore batch_size, sequence_length, ensemble_size, @@ -337,7 +337,7 @@ def get_sequence_buffer_iterator( if use_simple_sampler: val_iterator = SequenceTransitionSampler( transitions, - val_trajectories, + val_trajectories, # type: ignore batch_size, sequence_length, max_batches_per_loop_val, @@ -346,7 +346,7 @@ def get_sequence_buffer_iterator( else: val_iterator = SequenceTransitionIterator( transitions, - val_trajectories, + val_trajectories, # type: ignore batch_size, sequence_length, 1, diff --git a/mbrl/util/replay_buffer.py b/mbrl/util/replay_buffer.py index 32c2d2bb..dd5c46de 100644 --- a/mbrl/util/replay_buffer.py +++ b/mbrl/util/replay_buffer.py @@ -233,7 +233,7 @@ class SequenceTransitionIterator(BootstrapIterator): def __init__( self, transitions: TransitionBatch, - trajectory_indices: List[Tuple[int, int]], + trajectory_indices: Sequence[Tuple[int, int]], batch_size: int, sequence_length: int, ensemble_size: int, @@ -266,7 +266,7 @@ def __init__( @staticmethod def _get_indices_valid_starts( - trajectory_indices: List[Tuple[int, int]], + trajectory_indices: Sequence[Tuple[int, int]], sequence_length: int, ) -> np.ndarray: # This is memory and time inefficient but it's only done once when creating the @@ -332,7 +332,7 @@ class SequenceTransitionSampler(TransitionIterator): def __init__( self, transitions: TransitionBatch, - trajectory_indices: List[Tuple[int, int]], + trajectory_indices: Sequence[Tuple[int, int]], batch_size: int, sequence_length: int, batches_per_loop: int, @@ -361,7 +361,7 @@ def __init__( @staticmethod def _get_indices_valid_starts( - trajectory_indices: List[Tuple[int, int]], + trajectory_indices: Sequence[Tuple[int, int]], sequence_length: int, ) -> np.ndarray: # This is memory and time inefficient but it's only done once when creating the @@ -618,7 +618,7 @@ def sample_trajectory(self) -> Optional[TransitionBatch]: ) return self._batch_from_indices(indices) - def _batch_from_indices(self, indices: Sized) -> TransitionBatch: + def _batch_from_indices(self, indices) -> TransitionBatch: obs = self.obs[indices] next_obs = self.next_obs[indices] action = self.action[indices] diff --git a/setup.cfg b/setup.cfg index 0704346f..49ae768c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,7 +11,7 @@ exclude = mbrl/third_party/* [mypy] -python_version = 3.7 +python_version = 3.9 ignore_missing_imports = True show_error_codes = True strict_optional = False