Commit

Fix lint and mypy errors.
luisenp committed Dec 4, 2021
1 parent 044e4c1 commit 242c353
Showing 10 changed files with 68 additions and 35 deletions.
11 changes: 8 additions & 3 deletions .pre-commit-config.yaml
@@ -1,20 +1,23 @@
 repos:
   - repo: https://github.com/psf/black
-    rev: 20.8b1
+    rev: 21.9b0
     hooks:
       - id: black
         files: 'mbrl'
+        language_version: python3.7

   - repo: https://gitlab.com/pycqa/flake8
     rev: 3.7.9
     hooks:
       - id: flake8
         files: 'mbrl'
+        additional_dependencies: [-e, "git+git://github.com/pycqa/pyflakes.git@1911c20#egg=pyflakes"]

   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: v0.812
     hooks:
       - id: mypy
         files: 'mbrl'
         additional_dependencies: [torch, tokenize-rt==3.2.0]
         args: [--no-strict-optional, --ignore-missing-imports]
+        exclude: setup.py
@@ -23,6 +26,8 @@ repos:
     rev: 5.5.2
     hooks:
       - id: isort
-        files: 'mbrl/.*'
+        args: ["--profile", "black"]
+        files: 'mbrl'
       - id: isort
-        files: 'tests/.*'
+        files: 'tests'
+        args: [ "--profile", "black" ]
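The config changes above drive the rest of this commit: black moves from 20.8b1 to 21.9b0, flake8 pins a pyflakes build from git, mypy now skips setup.py, and both isort hooks adopt the "black" profile. A minimal sketch of the import style that profile produces, and that every reflowed import below follows (names taken from this repo; black's default 88-column limit assumed):

# isort --profile black writes long imports in "vertical hanging indent"
# mode with a magic trailing comma, so black keeps one name per line and
# future additions or removals touch a single line of the diff:
from mbrl.util.common import (
    create_replay_buffer,
    get_sequence_buffer_iterator,
    rollout_agent_trajectories,
)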
8 changes: 5 additions & 3 deletions mbrl/algorithms/planet.py
@@ -17,9 +17,11 @@
 from mbrl.models import ModelEnv, ModelTrainer
 from mbrl.planning import RandomAgent, create_trajectory_optim_agent_for_model
 from mbrl.util import Logger
-from mbrl.util.common import (create_replay_buffer,
-                              get_sequence_buffer_iterator,
-                              rollout_agent_trajectories)
+from mbrl.util.common import (
+    create_replay_buffer,
+    get_sequence_buffer_iterator,
+    rollout_agent_trajectories,
+)

 METRICS_LOG_FORMAT = [
     ("observations_loss", "OL", "float"),
18 changes: 13 additions & 5 deletions mbrl/diagnostics/training_browser.py
@@ -10,13 +10,21 @@

 import pandas as pd
 import yaml
-from matplotlib.backends.backend_qt5agg import \
-    FigureCanvasQTAgg as FigureCanvas
+from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
 from matplotlib.figure import Figure
 from PyQt5.QtCore import QAbstractTableModel, QDir, Qt
-from PyQt5.QtWidgets import (QAbstractItemView, QApplication, QCheckBox,
-                             QDockWidget, QFileDialog, QHeaderView,
-                             QMainWindow, QPushButton, QTableView, QToolBar)
+from PyQt5.QtWidgets import (
+    QAbstractItemView,
+    QApplication,
+    QCheckBox,
+    QDockWidget,
+    QFileDialog,
+    QHeaderView,
+    QMainWindow,
+    QPushButton,
+    QTableView,
+    QToolBar,
+)

 MULTI_ROOT = "multirun.yaml"
 SOURCE = "results.csv"
8 changes: 6 additions & 2 deletions mbrl/models/__init__.py
@@ -9,5 +9,9 @@
 from .model_trainer import ModelTrainer
 from .one_dim_tr_model import OneDTransitionRewardModel
 from .planet import PlaNetModel
-from .util import (Conv2dDecoder, Conv2dEncoder, EnsembleLinearLayer,
-                   truncated_normal_init)
+from .util import (
+    Conv2dDecoder,
+    Conv2dEncoder,
+    EnsembleLinearLayer,
+    truncated_normal_init,
+)
11 changes: 8 additions & 3 deletions mbrl/planning/__init__.py
@@ -3,6 +3,11 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 from .core import Agent, RandomAgent, complete_agent_cfg, load_agent
-from .trajectory_opt import (CEMOptimizer, ICEMOptimizer, MPPIOptimizer,
-                             TrajectoryOptimizer, TrajectoryOptimizerAgent,
-                             create_trajectory_optim_agent_for_model)
+from .trajectory_opt import (
+    CEMOptimizer,
+    ICEMOptimizer,
+    MPPIOptimizer,
+    TrajectoryOptimizer,
+    TrajectoryOptimizerAgent,
+    create_trajectory_optim_agent_for_model,
+)
8 changes: 6 additions & 2 deletions mbrl/util/__init__.py
@@ -7,8 +7,12 @@
 import omegaconf

 from .logger import Logger
-from .replay_buffer import (ReplayBuffer, SequenceTransitionIterator,
-                            SequenceTransitionSampler, TransitionIterator)
+from .replay_buffer import (
+    ReplayBuffer,
+    SequenceTransitionIterator,
+    SequenceTransitionSampler,
+    TransitionIterator,
+)


 def create_handler(cfg: Union[Dict, omegaconf.ListConfig, omegaconf.DictConfig]):
10 changes: 7 additions & 3 deletions mbrl/util/common.py
@@ -14,9 +14,13 @@
 import mbrl.planning
 import mbrl.types

-from .replay_buffer import (BootstrapIterator, ReplayBuffer,
-                            SequenceTransitionIterator,
-                            SequenceTransitionSampler, TransitionIterator)
+from .replay_buffer import (
+    BootstrapIterator,
+    ReplayBuffer,
+    SequenceTransitionIterator,
+    SequenceTransitionSampler,
+    TransitionIterator,
+)


 def create_one_dim_tr_model(
2 changes: 1 addition & 1 deletion mbrl/util/dmcontrol.py
@@ -51,7 +51,7 @@ def __exit__(self, *_args):


 class DmcontrolEnvHandler(EnvHandler):
-    """ Env handler for Dmcontrol-backed gym envs """
+    """Env handler for Dmcontrol-backed gym envs"""

     freeze = FreezeDmcontrol

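The docstring edits in this file and the two below are the same mechanical fix: judging by these hunks, the updated black hook rejects the padding spaces just inside docstring quotes. A small before/after sketch (hypothetical functions; the rule applies to any docstring):

# Flagged by the updated black (spaces inside the quotes):
def handler_doc_before():
    """ Env handler for Dmcontrol-backed gym envs """

# Accepted after reformatting:
def handler_doc_after():
    """Env handler for Dmcontrol-backed gym envs"""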
4 changes: 2 additions & 2 deletions mbrl/util/env.py
@@ -100,7 +100,7 @@ def _legacy_make_env(


 class Freeze(ABC):
-    """ Abstract base class for freezing various gym backends """
+    """Abstract base class for freezing various gym backends"""

     def __enter__(self, env):
         raise NotImplementedError
@@ -121,7 +121,7 @@ class EnvHandler(ABC):
     @staticmethod
     @abstractmethod
     def is_correct_env_type(env: gym.wrappers.TimeLimit) -> bool:
-        """ Checks that the env being handled is of the correct type """
+        """Checks that the env being handled is of the correct type"""
         raise NotImplementedError

     @staticmethod
23 changes: 12 additions & 11 deletions mbrl/util/pybullet.py
@@ -3,22 +3,23 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import tempfile
-from typing import Tuple
+from typing import Callable, List, Tuple

 import gym
 import gym.wrappers
 import numpy as np

 # Need to import pybulletgym to register pybullet envs.
 # Ignore the flake8 error generated
 import pybulletgym  # noqa
-from pybulletgym.envs.mujoco.envs.env_bases import \
-    BaseBulletEnv as MJBaseBulletEnv
-from pybulletgym.envs.mujoco.robots.locomotors.walker_base import \
-    WalkerBase as MJWalkerBase
-from pybulletgym.envs.roboschool.envs.env_bases import \
-    BaseBulletEnv as RSBaseBulletEnv
-from pybulletgym.envs.roboschool.robots.locomotors.walker_base import \
-    WalkerBase as RSWalkerBase
+from pybulletgym.envs.mujoco.envs.env_bases import BaseBulletEnv as MJBaseBulletEnv
+from pybulletgym.envs.mujoco.robots.locomotors.walker_base import (
+    WalkerBase as MJWalkerBase,
+)
+from pybulletgym.envs.roboschool.envs.env_bases import BaseBulletEnv as RSBaseBulletEnv
+from pybulletgym.envs.roboschool.robots.locomotors.walker_base import (
+    WalkerBase as RSWalkerBase,
+)

 from mbrl.util.env import EnvHandler, Freeze

@@ -66,7 +67,7 @@ def __exit__(self, *_args):


 class PybulletEnvHandler(EnvHandler):
-    """ Env handler for PyBullet-backed gym envs """
+    """Env handler for PyBullet-backed gym envs"""

     freeze = FreezePybullet

@@ -160,7 +161,7 @@ def _get_current_state_locomotion(env: gym.wrappers.TimeLimit) -> Tuple:
         ground_ids = env.ground_ids
         potential = env.potential
         reward = float(env.reward)
-        robot_keys = [
+        robot_keys: List[Tuple[str, Callable]] = [
             ("body_rpy", tuple),
             ("body_xyz", tuple),
             ("feet_contact", np.copy),