Commit
Merge pull request #118 from ageron:cv2_to_tf
PiperOrigin-RevId: 259620896
Change-Id: I20fc0537bad343d8e8cc6852d67159d008236235
copybara-github committed Jul 23, 2019
2 parents 41527f2 + 944a10f commit de406cf
Showing 2 changed files with 24 additions and 51 deletions.
71 changes: 21 additions & 50 deletions tf_agents/environments/atari_preprocessing.py
@@ -31,13 +31,14 @@
from __future__ import print_function

import gin
from gym.spaces.box import Box
from gym import core as gym_core
from gym.spaces import box
import numpy as np
import cv2


@gin.configurable
class AtariPreprocessing(object):
class AtariPreprocessing(gym_core.Wrapper):
"""A class implementing image preprocessing for Atari 2600 agents.
Specifically, this provides the following subset from the JAIR paper
@@ -54,14 +55,14 @@ class AtariPreprocessing(object):
"""

def __init__(self,
environment,
env,
frame_skip=4,
terminal_on_life_loss=False,
screen_size=84):
"""Constructor for an Atari 2600 preprocessor.
Args:
environment: Gym environment whose observations are preprocessed.
env: Gym environment whose observations are preprocessed.
frame_skip: int, the frequency at which the agent experiences the game.
terminal_on_life_loss: bool, If True, the step() method returns
is_terminal=True whenever a life is lost. See Mnih et al. 2015.
@@ -70,19 +71,28 @@ def __init__(self,
Raises:
ValueError: if frame_skip or screen_size are not strictly positive.
"""
super(AtariPreprocessing, self).__init__(env)

# Return the observation space adjusted to match the shape of the processed
# observations.
self.observation_space = box.Box(
low=0,
high=255,
shape=(screen_size, screen_size, 1),
dtype=np.uint8)

if frame_skip <= 0:
raise ValueError(
'Frame skip should be strictly positive, got {}'.format(frame_skip))
if screen_size <= 0:
raise ValueError('Target screen size should be strictly positive, got {}'
.format(screen_size))

self.environment = environment
self.terminal_on_life_loss = terminal_on_life_loss
self.frame_skip = frame_skip
self.screen_size = screen_size

obs_dims = self.environment.observation_space
obs_dims = self.env.observation_space
# Stores temporary observations used for pooling over two successive
# frames.
self.screen_buffer = [
@@ -93,59 +103,20 @@ def __init__(self,
self.game_over = False
self.lives = 0 # Will need to be set by reset().

@property
def observation_space(self):
# Return the observation space adjusted to match the shape of the processed
# observations.
return Box(
low=0,
high=255,
shape=(self.screen_size, self.screen_size, 1),
dtype=np.uint8)

@property
def action_space(self):
return self.environment.action_space

@property
def reward_range(self):
return self.environment.reward_range

@property
def metadata(self):
return self.environment.metadata

def reset(self):
"""Resets the environment.
Returns:
observation: numpy array, the initial observation emitted by the
environment.
"""
self.environment.reset()
self.lives = self.environment.ale.lives()
super(AtariPreprocessing, self).reset()
self.lives = self.env.ale.lives()
self.game_over = False
self._fetch_grayscale_observation(self.screen_buffer[0])
self.screen_buffer[1].fill(0)
return self._pool_and_resize()

def render(self, mode):
"""Renders the current screen, before preprocessing.
This calls the Gym API's render() method.
Args:
mode: Mode argument for the environment's render() method.
Valid values (str) are:
'rgb_array': returns the raw ALE image.
'human': renders to display via the Gym renderer.
Returns:
if mode='rgb_array': numpy array, the most recent screen.
if mode='human': bool, whether the rendering was successful.
"""
return self.environment.render(mode)

def step(self, action):
"""Applies the given action in the environment.
@@ -172,11 +143,11 @@ def step(self, action):
for time_step in range(self.frame_skip):
# We bypass the Gym observation altogether and directly fetch the
# grayscale image from the ALE. This is a little faster.
_, reward, game_over, info = self.environment.step(action)
_, reward, game_over, info = super(AtariPreprocessing, self).step(action)
accumulated_reward += reward

if self.terminal_on_life_loss:
new_lives = self.environment.ale.lives()
new_lives = self.env.ale.lives()
is_terminal = game_over or new_lives < self.lives
self.lives = new_lives
else:
@@ -206,7 +177,7 @@ def _fetch_grayscale_observation(self, output):
Returns:
observation: numpy array, the current observation in grayscale.
"""
self.environment.ale.getScreenGrayscale(output)
self.env.ale.getScreenGrayscale(output)
return output

def _pool_and_resize(self):
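For context, a minimal usage sketch of the wrapper-style class (not part of this commit; the environment id and the choice of a NoFrameskip variant are assumptions). Because AtariPreprocessing now subclasses gym.core.Wrapper, action_space, reward_range and metadata come from the Wrapper base class instead of hand-written properties, and only observation_space is overridden with the resized grayscale Box:

import gym

from tf_agents.environments import atari_preprocessing

# A NoFrameskip ALE environment is assumed so that AtariPreprocessing performs
# the frame skipping itself; the id and keyword values here are illustrative.
env = atari_preprocessing.AtariPreprocessing(
    gym.make('PongNoFrameskip-v4'),
    frame_skip=4,
    terminal_on_life_loss=False,
    screen_size=84)

obs = env.reset()  # uint8 array of shape (84, 84, 1)
obs, reward, done, info = env.step(env.action_space.sample())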
4 changes: 3 additions & 1 deletion tf_agents/environments/atari_preprocessing_test.py
@@ -25,6 +25,7 @@
from __future__ import division
from __future__ import print_function

from gym import core as gym_core
import numpy as np
import tensorflow as tf
from tf_agents.environments import atari_preprocessing as preprocessing
@@ -43,14 +44,15 @@ def getScreenGrayscale(self, screen):  # pylint: disable=invalid-name
screen.fill(self.screen_value)


class MockEnvironment(object):
class MockEnvironment(gym_core.Env):
"""Mock environment for testing."""

def __init__(self, screen_size=10, max_steps=10):
self.max_steps = max_steps
self.screen_size = screen_size
self.ale = MockALE()
self.observation_space = np.empty((screen_size, screen_size))
self.action_space = np.empty((5,))
self.game_over = False

def reset(self):
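A rough sketch of the mock-environment pattern the updated test relies on (the stub names below are hypothetical; the real MockALE and MockEnvironment live in atari_preprocessing_test.py). Deriving the mock from gym_core.Env rather than object presumably lets gym's Wrapper machinery find the default reward_range and metadata attributes it expects on the wrapped environment:

import numpy as np
from gym import core as gym_core

from tf_agents.environments import atari_preprocessing


class _StubALE(object):
  """Stands in for the parts of the ALE interface the preprocessor touches."""

  def lives(self):
    return 1

  def getScreenGrayscale(self, screen):  # pylint: disable=invalid-name
    screen.fill(0)  # Fill the caller-provided buffer in place, as ALE does.


class _StubAtariEnv(gym_core.Env):
  """gym.Env subclass exposing only what AtariPreprocessing reads."""

  def __init__(self, screen_size=10):
    self.ale = _StubALE()
    self.observation_space = np.empty((screen_size, screen_size))
    self.action_space = np.empty((5,))
    self.game_over = False

  def reset(self):
    return np.zeros(self.observation_space.shape, dtype=np.uint8)

  def step(self, action):
    obs = np.zeros(self.observation_space.shape, dtype=np.uint8)
    return obs, 0.0, False, {}


# The preprocessor only needs env.ale, env.observation_space and the standard
# Gym reset()/step() protocol, so this stub is enough to exercise it.
env = atari_preprocessing.AtariPreprocessing(_StubAtariEnv(), frame_skip=2)
first_obs = env.reset()  # processed frame, shape (84, 84, 1) by default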
