From 859db943bd05636c0c5a9bee8351d849a942c1f7 Mon Sep 17 00:00:00 2001 From: Mark Stephenson Date: Mon, 11 Nov 2024 15:52:49 -0700 Subject: [PATCH] Issue #207: Support Gymnasium 1.0.0 --- pyproject.toml | 4 ++-- src/bsk_rl/obs/observations.py | 8 +++----- tests/integration/test_int_gym_env.py | 9 +++++++-- tests/unittest/obs/test_observations.py | 8 ++++---- tests/unittest/test_gym_env.py | 5 ++++- 5 files changed, 20 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8d4b2989..b7cc8361 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ requires-python = ">=3.10.0" license = { text = "MIT" } dependencies = [ "Deprecated", - "gymnasium<1.0.0", + "gymnasium", "numpy", "pandas", "pettingzoo", @@ -29,7 +29,7 @@ dependencies = [ [project.optional-dependencies] docs = ["ipykernel", "ipywidgets", "nbdime", "nbsphinx", "sphinx-rtd-theme"] -rllib = ["dm_tree", "pyarrow", "ray[rllib]", "scikit-image", "torch", "typer"] +rllib = ["dm_tree", "pyarrow", "ray[rllib]==2.35.0", "scikit-image", "torch", "typer"] [project.scripts] finish_install = "bsk_rl.finish_install:pck_install" diff --git a/src/bsk_rl/obs/observations.py b/src/bsk_rl/obs/observations.py index 43a48256..324ef3da 100644 --- a/src/bsk_rl/obs/observations.py +++ b/src/bsk_rl/obs/observations.py @@ -37,18 +37,17 @@ def nested_obs_to_space(obs_dict): ) elif isinstance(obs_dict, list): return spaces.Box( - low=-1e16, high=1e16, shape=(len(obs_dict),), dtype=np.float64 + low=-1e16, high=1e16, shape=(len(obs_dict),), dtype=np.float32 ) elif isinstance(obs_dict, (float, int)): - return spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float64) + return spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float32) elif isinstance(obs_dict, np.ndarray): - return spaces.Box(low=-1e16, high=1e16, shape=obs_dict.shape, dtype=np.float64) + return spaces.Box(low=-1e16, high=1e16, shape=obs_dict.shape, dtype=np.float32) else: raise TypeError(f"Cannot convert {obs_dict} to gym space.") class ObservationBuilder: - def __init__(self, satellite: "Satellite", obs_type: type = np.ndarray) -> None: """Satellite subclass for composing observations. @@ -312,7 +311,6 @@ def _r_LB_H(sat, opp): class OpportunityProperties(Observation): - _fn_map = { "priority": lambda sat, opp: opp["object"].priority, "r_LP_P": lambda sat, opp: opp["r_LP_P"], diff --git a/tests/integration/test_int_gym_env.py b/tests/integration/test_int_gym_env.py index 4d9d50b5..8decd41d 100644 --- a/tests/integration/test_int_gym_env.py +++ b/tests/integration/test_int_gym_env.py @@ -35,7 +35,9 @@ def test_action_space(self): assert self.env.action_space == spaces.Discrete(1) def test_observation_space(self): - assert self.env.observation_space == spaces.Box(-1e16, 1e16, (1,)) + assert self.env.observation_space == spaces.Box( + -1e16, 1e16, (1,), dtype=np.float32 + ) def test_step(self): observation, reward, terminated, truncated, info = self.env.step(0) @@ -124,7 +126,10 @@ def test_action_space(self): def test_observation_space(self): assert self.env.observation_space == spaces.Tuple( - (spaces.Box(-1e16, 1e16, (1,)), spaces.Box(-1e16, 1e16, (1,))) + ( + spaces.Box(-1e16, 1e16, (1,), dtype=np.float32), + spaces.Box(-1e16, 1e16, (1,), dtype=np.float32), + ) ) def test_step(self): diff --git a/tests/unittest/obs/test_observations.py b/tests/unittest/obs/test_observations.py index a689101a..721ee431 100644 --- a/tests/unittest/obs/test_observations.py +++ b/tests/unittest/obs/test_observations.py @@ -69,23 +69,23 @@ def test_obs_cache(self): [ ( np.array([1]), - spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float64), + spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float32), ), ( np.array([1, 2]), - spaces.Box(low=-1e16, high=1e16, shape=(2,), dtype=np.float64), + spaces.Box(low=-1e16, high=1e16, shape=(2,), dtype=np.float32), ), ( {"a": 1, "b": {"c": 1}}, spaces.Dict( { "a": spaces.Box( - low=-1e16, high=1e16, shape=(1,), dtype=np.float64 + low=-1e16, high=1e16, shape=(1,), dtype=np.float32 ), "b": spaces.Dict( { "c": spaces.Box( - low=-1e16, high=1e16, shape=(1,), dtype=np.float64 + low=-1e16, high=1e16, shape=(1,), dtype=np.float32 ) } ), diff --git a/tests/unittest/test_gym_env.py b/tests/unittest/test_gym_env.py index 91744701..eb943bb2 100644 --- a/tests/unittest/test_gym_env.py +++ b/tests/unittest/test_gym_env.py @@ -1,5 +1,6 @@ from unittest.mock import MagicMock, patch +import numpy as np import pytest from gymnasium import spaces @@ -132,7 +133,9 @@ def test_get_obs_retasking_only(self): satellites=[ MagicMock( get_obs=MagicMock(return_value=[i + 1]), - observation_space=spaces.Box(-1e9, 1e9, shape=(1,)), + observation_space=spaces.Box( + -1e9, 1e9, shape=(1,), dtype=np.float32 + ), requires_retasking=(i == 1), ) for i in range(3)