Skip to content

Commit

Permalink
Update gym to gymnasium
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudhu committed Feb 21, 2024
1 parent 99cd11a commit 6436488
Show file tree
Hide file tree
Showing 17 changed files with 47 additions and 39 deletions.
5 changes: 3 additions & 2 deletions algorithms/utils/act.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import gymnasium as gym
import torch
import torch.nn as nn
import gym.spaces
from .mlp import MLPLayer

from .distributions import BetaShootBernoulli, Categorical, DiagGaussian, Bernoulli
from .mlp import MLPLayer


class ACTLayer(nn.Module):
Expand Down
19 changes: 10 additions & 9 deletions algorithms/utils/flatten.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import gym.spaces
import gymnasium.spaces
import numpy as np
import gymnasium as gym
from collections import OrderedDict


def build_flattener(space):
if isinstance(space, gym.spaces.Dict):
if isinstance(space, gymnasium.spaces.Dict):
return DictFlattener(space)
elif isinstance(space, gym.spaces.Box) \
or isinstance(space, gym.spaces.MultiDiscrete):
elif isinstance(space, gymnasium.spaces.Box) \
or isinstance(space, gymnasium.spaces.MultiDiscrete):
return BoxFlattener(space)
elif isinstance(space, gym.spaces.Discrete):
elif isinstance(space, gymnasium.spaces.Discrete):
return DiscreteFlattener(space)
else:
raise NotImplementedError
Expand All @@ -21,15 +22,15 @@ class DictFlattener():

def __init__(self, ori_space):
self.space = ori_space
assert isinstance(ori_space, gym.spaces.Dict)
assert isinstance(ori_space, gymnasium.spaces.Dict)
self.size = 0
self.flatteners = OrderedDict()
for name, space in self.space.spaces.items():
if isinstance(space, gym.spaces.Box):
if isinstance(space, gymnasium.spaces.Box):
flattener = BoxFlattener(space)
elif isinstance(space, gym.spaces.Discrete):
elif isinstance(space, gymnasium.spaces.Discrete):
flattener = DiscreteFlattener(space)
elif isinstance(space, gym.spaces.Dict):
elif isinstance(space, gymnasium.spaces.Dict):
flattener = DictFlattener(space)
self.flatteners[name] = flattener
self.size += flattener.size
Expand Down
3 changes: 2 additions & 1 deletion algorithms/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import math
import gym.spaces

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
Expand Down
7 changes: 5 additions & 2 deletions envs/JSBSim/core/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from enum import Enum
from collections import namedtuple
from numpy.linalg import norm
from gym.spaces import Box, Discrete
from gymnasium.spaces import Box, Discrete
from ..utils.utils import in_range_deg

"""
Expand Down Expand Up @@ -549,7 +549,10 @@ def add_jsbsim_props(self, jsbsim_props):
Args:
jsbsim_props (list): list of 'name_jsbsim (access)' of jsbsim properties
"""
for jsbsim_prop in jsbsim_props:
jsbsim_props_tmp=jsbsim_props.split("\n")
for jsbsim_prop in jsbsim_props_tmp:
if jsbsim_prop.strip() == "":
continue # skip empty line
[name_jsbsim, access] = jsbsim_prop.split(" ")
access = re.sub(r"[\(\)]", "", access) # remove parenthesis from the flag
name = re.sub(r"_$", "", re.sub(r"[\-/\]\[]+", "_", name_jsbsim)) # get property name from jsbsim name
Expand Down
10 changes: 5 additions & 5 deletions envs/JSBSim/envs/env_base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import gym
from gym.utils import seeding
import gymnasium
from gymnasium.utils import seeding
import numpy as np
from typing import Dict, Any, Tuple
from ..core.simulatior import AircraftSimulator, BaseSimulator
from ..tasks.task_base import BaseTask
from ..utils.utils import parse_config


class BaseEnv(gym.Env):
class BaseEnv(gymnasium.Env):
"""
A class wrapping the JSBSim flight dynamics module (FDM) for simulating
aircraft as an RL environment conforming to the OpenAI Gym Env
Expand Down Expand Up @@ -35,11 +35,11 @@ def num_agents(self) -> int:
return self.task.num_agents

@property
def observation_space(self) -> gym.Space:
def observation_space(self) -> gymnasium.Space:
return self.task.observation_space

@property
def action_space(self) -> gym.Space:
def action_space(self) -> gymnasium.Space:
return self.task.action_space

@property
Expand Down
2 changes: 1 addition & 1 deletion envs/JSBSim/tasks/heading_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from gym import spaces
from gymnasium import spaces
from .task_base import BaseTask
from ..core.catalog import Catalog as c
from ..reward_functions import AltitudeReward, HeadingReward
Expand Down
2 changes: 1 addition & 1 deletion envs/JSBSim/tasks/multiplecombat_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from gym import spaces
from gymnasium import spaces
from typing import Tuple
import torch

Expand Down
2 changes: 1 addition & 1 deletion envs/JSBSim/tasks/singlecombat_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import numpy as np
from gym import spaces
from gymnasium import spaces
from typing import Literal
from .task_base import BaseTask
from ..core.simulatior import AircraftSimulator
Expand Down
2 changes: 1 addition & 1 deletion envs/JSBSim/tasks/singlecombat_with_missle_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from gym import spaces
from gymnasium import spaces
from collections import deque

from .singlecombat_task import SingleCombatTask, HierarchicalSingleCombatTask
Expand Down
2 changes: 1 addition & 1 deletion envs/JSBSim/tasks/task_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from gym import spaces
from gymnasium import spaces
from typing import List, Tuple
from abc import ABC, abstractmethod
from ..core.catalog import Catalog as c
Expand Down
2 changes: 1 addition & 1 deletion envs/JSBSim/test/test_baseline_use_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_maneuver():
print(info)
break
step += 1
# plt.plot(reward_list)
plt.plot(reward_list)
# plt.savefig('rewards.png')


Expand Down
12 changes: 6 additions & 6 deletions runner/share_jsbsim_runner.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import time
from matplotlib.pyplot import axis
import torch
import logging
import numpy as np
from gym import spaces
import time
from typing import List
from .base_runner import Runner

import numpy as np
import torch

from algorithms.utils.buffer import SharedReplayBuffer
from .base_runner import Runner


def _t2n(x):
Expand Down
4 changes: 2 additions & 2 deletions scripts/train/train_gym.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
import sys
import os
import gym
import gymnasium as gym
import wandb
import socket
import torch
Expand Down Expand Up @@ -150,7 +150,7 @@ def main(args):
if all_args.use_wandb:
run = wandb.init(config=all_args,
project=all_args.env_name,
entity=all_args.wandb_name,
# entity=all_args.wandb_name,
notes=socket.gethostname(),
name=f"{all_args.experiment_name}_seed{all_args.seed}",
group=all_args.scenario_name,
Expand Down
1 change: 0 additions & 1 deletion scripts/train/train_jsbsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def main(args):
if all_args.use_wandb:
run = wandb.init(config=all_args,
project=all_args.env_name,
entity=all_args.wandb_name,
notes=socket.gethostname(),
name=f"{all_args.experiment_name}_seed{all_args.seed}",
group=all_args.scenario_name,
Expand Down
1 change: 1 addition & 0 deletions scripts/train_selfplay_shoot.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/bin/sh

env="SingleCombat"
scenario="1v1/ShootMissile/HierarchySelfplay"
algo="ppo"
Expand Down
1 change: 1 addition & 0 deletions scripts/train_share_selfplay.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/bin/sh

env="MultipleCombat"
scenario="2v2/NoWeapon/HierarchySelfplay"
algo="mappo"
Expand Down
11 changes: 6 additions & 5 deletions tests/test_ppo.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import sys
import os
import torch
import pytest
import numpy as np
import sys
from itertools import product
import gym.spaces

import gymnasium as gym
import numpy as np
import pytest
import torch

sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

Expand Down

0 comments on commit 6436488

Please sign in to comment.