Skip to content

Commit

Permalink
added STAND move
Browse files Browse the repository at this point in the history
  • Loading branch information
NullDefault committed Jul 5, 2021
1 parent 9d67f31 commit 277381b
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 12 deletions.
18 changes: 11 additions & 7 deletions gym_stag_hunt/demos.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from time import sleep

from gym_stag_hunt.envs.simple import SimpleEnv
from gym_stag_hunt.envs.hunt import HuntEnv
from gym_stag_hunt.envs.escalation import EscalationEnv
from gym_stag_hunt.envs.harvest import HarvestEnv
from gym_stag_hunt.src.games.abstract_grid_game import UP, LEFT, DOWN, RIGHT
from gym_stag_hunt.envs.hunt import HuntEnv
from gym_stag_hunt.envs.simple import SimpleEnv
from gym_stag_hunt.src.games.abstract_grid_game import UP, LEFT, DOWN, RIGHT, STAND

ENVS = {
'CLASSIC': SimpleEnv,
'HUNT': HuntEnv,
'HARVEST': HarvestEnv,
'ESCALATION': EscalationEnv
}
ENV = 'ESCALATION'


def print_ep(obs, reward, done, info):
Expand All @@ -29,7 +28,8 @@ def dir_parse(key):
LEFT: "LEFT",
UP: "UP",
DOWN: "DOWN",
RIGHT: "RIGHT"
RIGHT: "RIGHT",
STAND: "STAND"
}
return d[key]

Expand All @@ -44,15 +44,19 @@ def manual_input():
i = DOWN
elif i in ['d', 'D']:
i = RIGHT
elif i in ['x', 'X']:
i = STAND

return i


ENV = 'HUNT'

if __name__ == "__main__":
env = ENVS[ENV](obs_type='image', opponent_policy='pursuit')
env = ENVS[ENV](obs_type='image', enable_multiagent=True)
obs = env.reset()
for i in range(10000):
actions = env.game._seek_entity(env.game.A_AGENT, env.game.MARK)
actions = [env.action_space.sample(), env.action_space.sample()]

obs, rewards, done, info = env.step(actions=actions)
# print_ep(obs, rewards, done, info)
Expand Down
2 changes: 1 addition & 1 deletion gym_stag_hunt/envs/escalation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self,
opponent_policy=opponent_policy)

# Environment Config
self.action_space = Discrete(4) # up, down, left, right on the grid
self.action_space = Discrete(5) # up, down, left, right or stand
if obs_type == 'image': # Observation is the rgb pixel array
self.observation_space = Box(0, 255, shape=(grid_size[0]*TILE_SIZE, grid_size[1]*TILE_SIZE, 3), dtype=uint8)
elif obs_type == 'coords':
Expand Down
2 changes: 1 addition & 1 deletion gym_stag_hunt/envs/harvest.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self,
young_reward=young_reward,
mature_reward=mature_reward)

self.action_space = Discrete(4) # up, down, left, right on the grid
self.action_space = Discrete(5) # up, down, left, right or stand

if obs_type == 'image':
self.observation_space = Box(0, 255, shape=(grid_size[0]*TILE_SIZE, grid_size[1]*TILE_SIZE, 3), dtype=uint8)
Expand Down
2 changes: 1 addition & 1 deletion gym_stag_hunt/envs/hunt.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self,
mauling_punishment=mauling_punishment,
opponent_policy=opponent_policy)

self.action_space = Discrete(4) # up, down, left, right on the grid
self.action_space = Discrete(5) # up, down, left, right or stand

if obs_type == 'image':
self.observation_space = Box(0, 255, shape=(grid_size[0]*TILE_SIZE, grid_size[1]*TILE_SIZE, 3), dtype=uint8)
Expand Down
9 changes: 7 additions & 2 deletions gym_stag_hunt/src/games/abstract_grid_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
DOWN = 1
RIGHT = 2
UP = 3
STAND = 4


class AbstractGridGame(ABC):
Expand Down Expand Up @@ -59,7 +60,8 @@ def _move_dispatcher(self):
LEFT: self._move_left,
DOWN: self._move_down,
RIGHT: self._move_right,
UP: self._move_up
UP: self._move_up,
STAND: self._stand
}

def _move_entity(self, entity_pos, action):
Expand Down Expand Up @@ -120,7 +122,7 @@ def _seek_entity(self, seeker, target):
options.append(DOWN)

if not options:
options = [LEFT, DOWN, RIGHT, UP]
options = [STAND]
shipback = choice(options)

return shipback
Expand Down Expand Up @@ -165,6 +167,9 @@ def _move_down(self, pos):
new_y = self.GRID_H - 1
return pos[0], new_y

def _stand(self, pos):
return pos

"""
Properties
"""
Expand Down

0 comments on commit 277381b

Please sign in to comment.