Skip to content

Commit

Permalink
Updated pre-commit configuration and fixed some mypy errors. (faceboo…
Browse files Browse the repository at this point in the history
  • Loading branch information
luisenp authored Aug 5, 2022
1 parent 94b76e7 commit 87008d0
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 13 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: 22.1.0
rev: 22.6.0
hooks:
- id: black
files: 'mbrl'
Expand All @@ -13,11 +13,11 @@ repos:
files: 'mbrl'

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.931
rev: v0.971
hooks:
- id: mypy
files: 'mbrl'
additional_dependencies: [torch, tokenize-rt==3.2.0, types-PyYAML, types-termcolor]
additional_dependencies: [numpy, torch, tokenize-rt==3.2.0, types-PyYAML, types-termcolor]
args: [--no-strict-optional, --ignore-missing-imports]
exclude: setup.py

Expand Down
8 changes: 4 additions & 4 deletions mbrl/util/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def get_sequence_buffer_iterator(
if use_simple_sampler:
train_iterator: _SequenceIterType = SequenceTransitionSampler(
transitions,
train_trajectories,
train_trajectories, # type:ignore
batch_size,
sequence_length,
max_batches_per_loop_train,
Expand All @@ -322,7 +322,7 @@ def get_sequence_buffer_iterator(
else:
train_iterator = SequenceTransitionIterator(
transitions,
train_trajectories,
train_trajectories, # type: ignore
batch_size,
sequence_length,
ensemble_size,
Expand All @@ -337,7 +337,7 @@ def get_sequence_buffer_iterator(
if use_simple_sampler:
val_iterator = SequenceTransitionSampler(
transitions,
val_trajectories,
val_trajectories, # type: ignore
batch_size,
sequence_length,
max_batches_per_loop_val,
Expand All @@ -346,7 +346,7 @@ def get_sequence_buffer_iterator(
else:
val_iterator = SequenceTransitionIterator(
transitions,
val_trajectories,
val_trajectories, # type: ignore
batch_size,
sequence_length,
1,
Expand Down
10 changes: 5 additions & 5 deletions mbrl/util/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ class SequenceTransitionIterator(BootstrapIterator):
def __init__(
self,
transitions: TransitionBatch,
trajectory_indices: List[Tuple[int, int]],
trajectory_indices: Sequence[Tuple[int, int]],
batch_size: int,
sequence_length: int,
ensemble_size: int,
Expand Down Expand Up @@ -266,7 +266,7 @@ def __init__(

@staticmethod
def _get_indices_valid_starts(
trajectory_indices: List[Tuple[int, int]],
trajectory_indices: Sequence[Tuple[int, int]],
sequence_length: int,
) -> np.ndarray:
# This is memory and time inefficient but it's only done once when creating the
Expand Down Expand Up @@ -332,7 +332,7 @@ class SequenceTransitionSampler(TransitionIterator):
def __init__(
self,
transitions: TransitionBatch,
trajectory_indices: List[Tuple[int, int]],
trajectory_indices: Sequence[Tuple[int, int]],
batch_size: int,
sequence_length: int,
batches_per_loop: int,
Expand Down Expand Up @@ -361,7 +361,7 @@ def __init__(

@staticmethod
def _get_indices_valid_starts(
trajectory_indices: List[Tuple[int, int]],
trajectory_indices: Sequence[Tuple[int, int]],
sequence_length: int,
) -> np.ndarray:
# This is memory and time inefficient but it's only done once when creating the
Expand Down Expand Up @@ -618,7 +618,7 @@ def sample_trajectory(self) -> Optional[TransitionBatch]:
)
return self._batch_from_indices(indices)

def _batch_from_indices(self, indices: Sized) -> TransitionBatch:
def _batch_from_indices(self, indices) -> TransitionBatch:
obs = self.obs[indices]
next_obs = self.next_obs[indices]
action = self.action[indices]
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ exclude =
mbrl/third_party/*

[mypy]
python_version = 3.7
python_version = 3.9
ignore_missing_imports = True
show_error_codes = True
strict_optional = False
Expand Down

0 comments on commit 87008d0

Please sign in to comment.