Commit 47e421f: clean up

dgriff777 authored May 11, 2018
1 parent 9785a30 commit 47e421f
Showing 1 changed file with 36 additions and 30 deletions.
environment.py

@@ -19,6 +19,7 @@ def atari_env(env_id, env_conf, args):
     env = EpisodicLifeEnv(env)
     if 'FIRE' in env.unwrapped.get_action_meanings():
         env = FireResetEnv(env)
+    env._max_episode_steps = args.max_episode_length
     env = AtariRescale(env, env_conf)
     env = NormalizedEnv(env)
     return env
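
A minimal sketch of how this factory might be invoked. It assumes gym with the Atari extras is installed and that this repo's config.json carries a "Default" entry, as main.py suggests; the argument values are illustrative assumptions, not part of this commit:

    import json
    from types import SimpleNamespace

    from environment import atari_env

    env_conf = json.load(open("config.json"))["Default"]  # assumed config layout
    args = SimpleNamespace(max_episode_length=10000)      # assumed args field

    env = atari_env("PongNoFrameskip-v4", env_conf, args)
    obs = env.reset()  # runs the full wrapped reset chain (noops, FIRE, ...)
    obs, reward, done, info = env.step(env.action_space.sample())
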
@@ -73,38 +74,45 @@ def __init__(self, env, noop_max=30):
         """
         gym.Wrapper.__init__(self, env)
         self.noop_max = noop_max
+        self.override_num_noops = None
         self.noop_action = 0
         assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

-    def reset(self):
+    def reset(self, **kwargs):
         """ Do no-op action for a number of steps in [1, noop_max]."""
-        self.env.reset()
-        noops = random.randrange(1, self.noop_max + 1)  # pylint: disable=E1101
+        self.env.reset(**kwargs)
+        if self.override_num_noops is not None:
+            noops = self.override_num_noops
+        else:
+            noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)  # pylint: disable=E1101
         assert noops > 0
         obs = None
         for _ in range(noops):
             obs, _, done, _ = self.env.step(self.noop_action)
             if done:
-                obs = self.env.reset()
+                obs = self.env.reset(**kwargs)
         return obs

     def step(self, ac):
         return self.env.step(ac)
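
The practical effect of the rewritten reset: the no-op count is now drawn from the env's own seeded np_random stream instead of Python's global random module, so resets are reproducible per seed, and override_num_noops lets callers pin an exact count. A stand-alone sketch of that sampling logic, with numpy standing in for the env's RNG (the variable names here are mine):

    import numpy as np

    rng = np.random.RandomState(0)  # stands in for env.unwrapped.np_random
    noop_max = 30
    override_num_noops = None       # set an int to make resets deterministic

    if override_num_noops is not None:
        noops = override_num_noops
    else:
        noops = rng.randint(1, noop_max + 1)  # the same draw the wrapper makes
    assert noops > 0
    print(noops)  # fixed for a given seed
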
 class FireResetEnv(gym.Wrapper):
     def __init__(self, env):
         """Take action on reset for environments that are fixed until firing."""
         gym.Wrapper.__init__(self, env)
         assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
         assert len(env.unwrapped.get_action_meanings()) >= 3

-    def reset(self):
-        self.env.reset()
+    def reset(self, **kwargs):
+        self.env.reset(**kwargs)
         obs, _, done, _ = self.env.step(1)
         if done:
-            self.env.reset()
+            self.env.reset(**kwargs)
         obs, _, done, _ = self.env.step(2)
         if done:
-            self.env.reset()
+            self.env.reset(**kwargs)
         return obs

     def step(self, ac):
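
For reference, the two asserts in FireResetEnv encode assumptions about ALE's minimal action set; they can be checked directly, assuming gym with an Atari ROM available:

    import gym

    env = gym.make("BreakoutNoFrameskip-v4")
    print(env.unwrapped.get_action_meanings())
    # Breakout reports ['NOOP', 'FIRE', 'RIGHT', 'LEFT'], which is why the
    # wrapper can blindly press action 1 (FIRE) and then action 2 on reset.
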
@@ -118,62 +126,60 @@ def __init__(self, env):
         """
         gym.Wrapper.__init__(self, env)
         self.lives = 0
         self.was_real_done = True

     def step(self, action):
         obs, reward, done, info = self.env.step(action)
-        self.was_real_done = True
+        self.was_real_done = done
         # check current lives, make loss of life terminal,
         # then update lives to handle bonus lives
-        lives = info['ale.lives']
+        lives = self.env.unwrapped.ale.lives()
         if lives < self.lives and lives > 0:
             # for Qbert sometimes we stay in lives == 0 condition for a few frames
             # so it's important to keep lives > 0, so that we only reset once
             # the environment advertises done.
             done = True
-            self.was_real_done = False
         self.lives = lives
-        return obs, reward, done, info
+        return obs, reward, done, self.was_real_done

-    def reset(self):
+    def reset(self, **kwargs):
         """Reset only when lives are exhausted.
         This way all states are still reachable even though lives are episodic,
         and the learner need not know about any of this behind-the-scenes.
         """
         if self.was_real_done:
-            obs = self.env.reset()
-            self.lives = 0
+            obs = self.env.reset(**kwargs)
         else:
             # no-op step to advance from terminal/lost life state
-            obs, _, _, info = self.env.step(0)
-            self.lives = info['ale.lives']
+            obs, _, _, _ = self.env.step(0)
+        self.lives = self.env.unwrapped.ale.lives()
         return obs
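
Net effect of the EpisodicLifeEnv changes: lives are read straight from the ALE rather than the info dict, done is forced on any life loss, and step now returns was_real_done in the info slot so callers can tell a lost life from a true game over. A pure-Python sketch of that bookkeeping (the function and its names are mine, not part of the file):

    def episodic_life_update(done, lives, prev_lives):
        """Mirror the wrapper's logic: terminal on life loss, but only
        'really' done when the underlying env says so."""
        was_real_done = done
        if lives < prev_lives and lives > 0:
            done = True  # end the learning episode on a lost life
        return done, was_real_done, lives

    # losing a life ends the episode for the learner, not for the game
    print(episodic_life_update(done=False, lives=2, prev_lives=3))  # (True, False, 2)
    # true game over sets both flags
    print(episodic_life_update(done=True, lives=0, prev_lives=1))   # (True, True, 0)
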


 class MaxAndSkipEnv(gym.Wrapper):
-    def __init__(self, env=None, skip=4):
+    def __init__(self, env, skip=4):
         """Return only every `skip`-th frame"""
         gym.Wrapper.__init__(self, env)
         # most recent raw observations (for max pooling across time steps)
-        self._obs_buffer = deque(maxlen=2)
-        self._skip = skip
+        self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
+        self._skip = skip

     def step(self, action):
+        """Repeat action, sum reward, and max over last observations."""
         total_reward = 0.0
         done = None
-        for _ in range(self._skip):
+        for i in range(self._skip):
             obs, reward, done, info = self.env.step(action)
-            self._obs_buffer.append(obs)
+            if i == self._skip - 2: self._obs_buffer[0] = obs
+            if i == self._skip - 1: self._obs_buffer[1] = obs
             total_reward += reward
             if done:
                 break
+        # Note that the observation on the done=True frame
+        # doesn't matter
+        max_frame = self._obs_buffer.max(axis=0)

-        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
         return max_frame, total_reward, done, info

-    def reset(self):
-        """Clear past frame buffer and init. to first obs. from inner env."""
-        self._obs_buffer.clear()
-        obs = self.env.reset()
-        self._obs_buffer.append(obs)
-        return obs
+    def reset(self, **kwargs):
+        return self.env.reset(**kwargs)
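
The MaxAndSkipEnv rewrite swaps the deque for a fixed two-slot array that keeps only the last two raw frames and max-pools them, the standard fix for Atari sprite flicker (some objects are drawn only on alternate frames). A numpy-only illustration of the pooling step:

    import numpy as np

    # two consecutive "frames" in which different sprites are visible
    frame_a = np.zeros((4, 4), dtype=np.uint8); frame_a[1, 1] = 255
    frame_b = np.zeros((4, 4), dtype=np.uint8); frame_b[2, 3] = 200

    buf = np.zeros((2, 4, 4), dtype=np.uint8)  # same shape trick as _obs_buffer
    buf[0], buf[1] = frame_a, frame_b
    max_frame = buf.max(axis=0)  # both sprites survive the merge
    assert max_frame[1, 1] == 255 and max_frame[2, 3] == 200
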
