Skip to content

Commit

Permalink
migrated observation spaces to be unint8's and added attribute errors…
Browse files Browse the repository at this point in the history
… for when grid size is too large
  • Loading branch information
NullDefault committed Jun 18, 2021
1 parent 9b45e1d commit c99a90c
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 13 deletions.
4 changes: 2 additions & 2 deletions gym_stag_hunt/demos.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
'HARVEST': HarvestEnv,
'ESCALATION': EscalationEnv
}
ENV = 'HUNT'
ENV = 'ESCALATION'

if __name__ == "__main__":
env = ENVS[ENV](obs_type='coord', load_renderer=True)
env = ENVS[ENV](obs_type='image', load_renderer=True)
env.reset()
for i in range(10000):
actions = [env.action_space.sample(), env.action_space.sample()]
Expand Down
9 changes: 8 additions & 1 deletion gym_stag_hunt/envs/abstract_markov_staghunt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@


class AbstractMarkovStagHuntEnv(Env, ABC):
metadata = {'render.modes': ['human']}
metadata = {
'render.modes': ['human'],
'obs.types': ['image', 'coords']
}

def __init__(self,
grid_size=(5, 5),
Expand All @@ -19,6 +22,10 @@ def __init__(self,
total_cells = grid_size[0] * grid_size[1]
if total_cells < 3:
raise AttributeError('Grid is too small. Please specify a larger grid size.')
if obs_type not in self.metadata['obs.types']:
raise AttributeError('Invalid observation type provided. Please specify "image" or "coords"')
if grid_size[0] >= 255 or grid_size[1] >= 255:
raise AttributeError('Grid is too large. Please specify a smaller grid size.')

super(AbstractMarkovStagHuntEnv, self).__init__()

Expand Down
7 changes: 4 additions & 3 deletions gym_stag_hunt/envs/escalation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from gym.spaces import Discrete, Box
from numpy import int64, Inf
from numpy import Inf, uint8

from gym_stag_hunt.envs.abstract_markov_staghunt import AbstractMarkovStagHuntEnv
from gym_stag_hunt.src.entities import TILE_SIZE
from gym_stag_hunt.src.games.escalation_game import Escalation


Expand Down Expand Up @@ -38,8 +39,8 @@ def __init__(self,
# Environment Config
self.action_space = Discrete(4) # up, down, left, right on the grid
if obs_type == 'image': # Observation is the rgb pixel array
self.observation_space = Box(0, 255, shape=(screen_size[0], screen_size[1], 3), dtype=int64)
self.observation_space = Box(0, 255, shape=(grid_size[0]*TILE_SIZE, grid_size[1]*TILE_SIZE, 3), dtype=uint8)
elif obs_type == 'coords':
self.observation_space = Box(0, max(grid_size), shape=(3, 2), dtype=int)
self.observation_space = Box(0, max(grid_size), shape=(3, 2), dtype=uint8)

self.reward_range = (-Inf, Inf) # There is technically no limit on how high or low the reinforcement can be
9 changes: 5 additions & 4 deletions gym_stag_hunt/envs/harvest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from gym.spaces import Discrete, Box, Tuple
from numpy import int64
from numpy import uint8

from gym_stag_hunt.envs.abstract_markov_staghunt import AbstractMarkovStagHuntEnv
from gym_stag_hunt.src.entities import TILE_SIZE
from gym_stag_hunt.src.games.harvest_game import Harvest


Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(self,
self.action_space = Discrete(4) # up, down, left, right on the grid

if obs_type == 'image':
self.observation_space = Box(0, 255, shape=(screen_size[0], screen_size[1], 3), dtype=int64)
self.observation_space = Box(0, 255, shape=(grid_size[0]*TILE_SIZE, grid_size[1]*TILE_SIZE, 3), dtype=uint8)
elif obs_type == 'coords':
self.observation_space = Tuple((Box(0, max(grid_size), shape=(2, 2), dtype=int),
Box(0, max(grid_size), shape=(max_plants, 3), dtype=int)))
self.observation_space = Tuple((Box(0, max(grid_size), shape=(2, 2), dtype=uint8),
Box(0, max(grid_size), shape=(max_plants, 3), dtype=uint8)))
7 changes: 4 additions & 3 deletions gym_stag_hunt/envs/hunt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from gym.spaces import Discrete, Box
from numpy import int64
from numpy import uint8

from gym_stag_hunt.envs.abstract_markov_staghunt import AbstractMarkovStagHuntEnv
from gym_stag_hunt.src.entities import TILE_SIZE
from gym_stag_hunt.src.games.staghunt_game import StagHunt


Expand Down Expand Up @@ -70,6 +71,6 @@ def __init__(self,
self.action_space = Discrete(4) # up, down, left, right on the grid

if obs_type == 'image':
self.observation_space = Box(0, 255, shape=(screen_size[0], screen_size[1], 3), dtype=int64)
self.observation_space = Box(0, 255, shape=(grid_size[0]*TILE_SIZE, grid_size[1]*TILE_SIZE, 3), dtype=uint8)
elif obs_type == 'coords':
self.observation_space = Box(0, max(grid_size), shape=(3+forage_quantity, 2), dtype=int)
self.observation_space = Box(0, max(grid_size), shape=(3+forage_quantity, 2), dtype=uint8)

0 comments on commit c99a90c

Please sign in to comment.