Skip to content

Commit

Permalink
Issue #207: Support Gymnasium 1.0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark2000 committed Nov 15, 2024
1 parent 45b941e commit 859db94
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 14 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ requires-python = ">=3.10.0"
license = { text = "MIT" }
dependencies = [
"Deprecated",
"gymnasium<1.0.0",
"gymnasium",
"numpy",
"pandas",
"pettingzoo",
Expand All @@ -29,7 +29,7 @@ dependencies = [

[project.optional-dependencies]
docs = ["ipykernel", "ipywidgets", "nbdime", "nbsphinx", "sphinx-rtd-theme"]
rllib = ["dm_tree", "pyarrow", "ray[rllib]", "scikit-image", "torch", "typer"]
rllib = ["dm_tree", "pyarrow", "ray[rllib]==2.35.0", "scikit-image", "torch", "typer"]

[project.scripts]
finish_install = "bsk_rl.finish_install:pck_install"
8 changes: 3 additions & 5 deletions src/bsk_rl/obs/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,17 @@ def nested_obs_to_space(obs_dict):
)
elif isinstance(obs_dict, list):
return spaces.Box(
low=-1e16, high=1e16, shape=(len(obs_dict),), dtype=np.float64
low=-1e16, high=1e16, shape=(len(obs_dict),), dtype=np.float32
)
elif isinstance(obs_dict, (float, int)):
return spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float64)
return spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float32)
elif isinstance(obs_dict, np.ndarray):
return spaces.Box(low=-1e16, high=1e16, shape=obs_dict.shape, dtype=np.float64)
return spaces.Box(low=-1e16, high=1e16, shape=obs_dict.shape, dtype=np.float32)
else:
raise TypeError(f"Cannot convert {obs_dict} to gym space.")


class ObservationBuilder:

def __init__(self, satellite: "Satellite", obs_type: type = np.ndarray) -> None:
"""Satellite subclass for composing observations.
Expand Down Expand Up @@ -312,7 +311,6 @@ def _r_LB_H(sat, opp):


class OpportunityProperties(Observation):

_fn_map = {
"priority": lambda sat, opp: opp["object"].priority,
"r_LP_P": lambda sat, opp: opp["r_LP_P"],
Expand Down
9 changes: 7 additions & 2 deletions tests/integration/test_int_gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def test_action_space(self):
assert self.env.action_space == spaces.Discrete(1)

def test_observation_space(self):
assert self.env.observation_space == spaces.Box(-1e16, 1e16, (1,))
assert self.env.observation_space == spaces.Box(
-1e16, 1e16, (1,), dtype=np.float32
)

def test_step(self):
observation, reward, terminated, truncated, info = self.env.step(0)
Expand Down Expand Up @@ -124,7 +126,10 @@ def test_action_space(self):

def test_observation_space(self):
assert self.env.observation_space == spaces.Tuple(
(spaces.Box(-1e16, 1e16, (1,)), spaces.Box(-1e16, 1e16, (1,)))
(
spaces.Box(-1e16, 1e16, (1,), dtype=np.float32),
spaces.Box(-1e16, 1e16, (1,), dtype=np.float32),
)
)

def test_step(self):
Expand Down
8 changes: 4 additions & 4 deletions tests/unittest/obs/test_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,23 @@ def test_obs_cache(self):
[
(
np.array([1]),
spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float64),
spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float32),
),
(
np.array([1, 2]),
spaces.Box(low=-1e16, high=1e16, shape=(2,), dtype=np.float64),
spaces.Box(low=-1e16, high=1e16, shape=(2,), dtype=np.float32),
),
(
{"a": 1, "b": {"c": 1}},
spaces.Dict(
{
"a": spaces.Box(
low=-1e16, high=1e16, shape=(1,), dtype=np.float64
low=-1e16, high=1e16, shape=(1,), dtype=np.float32
),
"b": spaces.Dict(
{
"c": spaces.Box(
low=-1e16, high=1e16, shape=(1,), dtype=np.float64
low=-1e16, high=1e16, shape=(1,), dtype=np.float32
)
}
),
Expand Down
5 changes: 4 additions & 1 deletion tests/unittest/test_gym_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
from gymnasium import spaces

Expand Down Expand Up @@ -132,7 +133,9 @@ def test_get_obs_retasking_only(self):
satellites=[
MagicMock(
get_obs=MagicMock(return_value=[i + 1]),
observation_space=spaces.Box(-1e9, 1e9, shape=(1,)),
observation_space=spaces.Box(
-1e9, 1e9, shape=(1,), dtype=np.float32
),
requires_retasking=(i == 1),
)
for i in range(3)
Expand Down

0 comments on commit 859db94

Please sign in to comment.