Skip to content

Commit

Permalink
Release v2.1.0 (#395)
Browse files Browse the repository at this point in the history
* Release v2.1.0

* Fix mypy
  • Loading branch information
araffin authored Aug 17, 2023
1 parent 660f2d3 commit 7f98df9
Show file tree
Hide file tree
Showing 10 changed files with 20 additions and 32 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
## Release 2.1.0a0 (WIP)
## Release 2.1.0 (2023-08-17)

### Breaking Changes
- Dropped python 3.7 support
- SB3 now requires PyTorch 1.13+
- Upgraded to SB3 >= 2.1.0
- Upgraded to Huggingface-SB3 >= 2.3
- Upgraded to Optuna >= 3.0
- Upgraded to cloudpickle >= 2.2.1

### New Features
- Added python 3.11 support
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
LINT_PATHS = *.py tests/ scripts/ rl_zoo3/ hyperparams/python/*.py
LINT_PATHS = *.py tests/ scripts/ rl_zoo3/ hyperparams/python/*.py docs/conf.py

# Run pytest and coverage report
pytest:
Expand Down
18 changes: 1 addition & 17 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
#
import os
import sys
from typing import Dict, List
from unittest.mock import MagicMock
from typing import Dict

# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
# PyEnchant.
Expand All @@ -37,21 +36,6 @@
sys.path.insert(0, os.path.abspath(".."))


class Mock(MagicMock):
__subclasses__ = [] # type: ignore

@classmethod
def __getattr__(cls, name):
return MagicMock()


# Mock modules that requires C modules
# Note: because of that we cannot test examples using CI
# 'torch', 'torch.nn', 'torch.nn.functional',
# DO not mock modules for now, we will need to do that for read the docs later
MOCK_MODULES: List[str] = []
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)

# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "../rl_zoo3", "version.txt")
with open(version_file) as file_handler:
Expand Down
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
gym==0.26.2
stable-baselines3[extra_no_roms,tests,docs]>=2.0.0
sb3-contrib>=2.0.0
stable-baselines3[extra_no_roms,tests,docs]>=2.1.0
sb3-contrib>=2.1.0
box2d-py==2.3.8
pybullet
# minigrid
# scikit-optimize
optuna~=3.0
pytablewriter~=0.64
pyyaml>=5.1
cloudpickle>=1.5.0
cloudpickle>=2.2.1
plotly
# need to upgrade to gymnasium:
# panda-gym~=3.0.1
rliable>=1.0.5
wandb
huggingface_sb3>=2.2.5
huggingface_sb3>=2.3
seaborn
tqdm
rich
Expand Down
2 changes: 1 addition & 1 deletion rl_zoo3/gym_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,4 @@ def step(self, action):
# Patch Gymnasium TimeLimit
gymnasium.wrappers.TimeLimit = PatchedTimeLimit # type: ignore[misc]
gymnasium.wrappers.time_limit.TimeLimit = PatchedTimeLimit # type: ignore[misc]
gymnasium.envs.registration.TimeLimit = PatchedTimeLimit # type: ignore[misc]
gymnasium.envs.registration.TimeLimit = PatchedTimeLimit # type: ignore[misc,attr-defined]
2 changes: 1 addition & 1 deletion rl_zoo3/plots/plot_from_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def plot_from_file(): # noqa: C901
for new_key in results_2[key].keys():
results[key][new_key] = results_2[key][new_key]

keys = [key for key in results[list(results.keys())[0]].keys() if key not in args.skip_keys]
keys = [key for key in results[next(iter(results.keys()))].keys() if key not in args.skip_keys]
print(f"keys: {keys}")
if len(args.keep_keys) > 0:
keys = [key for key in keys if key in args.keep_keys]
Expand Down
4 changes: 2 additions & 2 deletions rl_zoo3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_class_name(wrapper_name):
"You should check the indentation."
)
wrapper_dict = wrapper_name
wrapper_name = list(wrapper_dict.keys())[0]
wrapper_name = next(iter(wrapper_dict.keys()))
kwargs = wrapper_dict[wrapper_name]
else:
kwargs = {}
Expand Down Expand Up @@ -178,7 +178,7 @@ def get_callback_list(hyperparams: Dict[str, Any]) -> List[BaseCallback]:
"You should check the indentation."
)
callback_dict = callback_name
callback_name = list(callback_dict.keys())[0]
callback_name = next(iter(callback_dict.keys()))
kwargs = callback_dict[callback_name]
else:
kwargs = {}
Expand Down
2 changes: 1 addition & 1 deletion rl_zoo3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.1.0a0
2.1.0
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@
},
entry_points={"console_scripts": ["rl_zoo3=rl_zoo3.cli:main"]},
install_requires=[
"sb3_contrib>=2.0.0",
"sb3_contrib>=2.1.0",
"gym==0.26.2", # for patches to make gym backward compat
"huggingface_sb3>=2.2.5",
"huggingface_sb3>=2.3",
"tqdm",
"rich",
"optuna",
"optuna>=3.0",
"pyyaml>=5.1",
"pytablewriter~=0.64",
# TODO: add test dependencies
Expand Down
2 changes: 1 addition & 1 deletion tests/test_hyperparams_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_optimize_log_path(tmp_path):
assert os.path.isdir(os.path.join(optimization_log_path, "trial_1"))
assert os.path.isfile(os.path.join(optimization_log_path, "trial_1", "evaluations.npz"))

study_path = list(glob.glob(str(tmp_path / algo / "report_*.pkl")))[0]
study_path = next(iter(glob.glob(str(tmp_path / algo / "report_*.pkl"))))
print(study_path)
# Test reading best trials
args = [
Expand Down

0 comments on commit 7f98df9

Please sign in to comment.