-
Notifications
You must be signed in to change notification settings - Fork 0
/
wrappers.py
35 lines (30 loc) · 1.1 KB
/
wrappers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import cv2
import gym
import numpy as np
# Target spatial size of every transformed observation. It is passed
# unchanged as the `dsize` argument of `cv2.resize`, which OpenCV reads as
# (width, height); the value is square, so the order does not matter here.
IMAGE_SIZE = (32, 32)
class InputTransformation(gym.ObservationWrapper):
    """
    Observation wrapper that turns environment image observations into
    the tensors required for training of our GAN.

    Each observation is resized to `IMAGE_SIZE`, rearranged from
    `(height, width, channels)` to `(channels, height, width)` and
    converted to `float32` after division by 255.
    """

    def __init__(self, *args):
        """
        Wrap the environment and rebuild its observation space.

        The new observation space is a `gym.spaces.Box` whose bounds are
        obtained by pushing the original `low` / `high` arrays through
        `observation()`, so the declared space always matches what the
        wrapper actually emits.

        Args:
            *args: forwarded unchanged to `gym.ObservationWrapper`
                (normally just the environment to wrap).

        Raises:
            AssertionError: if the wrapped observation space is not a
                `gym.spaces.Box`.
        """
        super(InputTransformation, self).__init__(*args)
        # The bounds below are read from `.low` / `.high`, which only a
        # bounded Box space provides (a discrete space has neither).
        assert isinstance(self.observation_space, gym.spaces.Box), \
            "InputTransformation requires a gym.spaces.Box observation space"
        original_space = self.observation_space
        self.observation_space = gym.spaces.Box(
            self.observation(original_space.low),
            self.observation(original_space.high),
            dtype=np.float32)

    def observation(self, observation):
        """
        Transform a single image observation.

        1. Resize the image to `IMAGE_SIZE`.
        2. Transform the dimensions of the observation from
           `(height, width, channels)` to `(channels, height, width)`.
        3. Cast to `float32` and divide by 255.

        Args:
            observation: image array laid out as `(height, width,
                channels)` or, for single-channel images,
                `(height, width)`.
                NOTE(review): values are presumably in [0, 255] so the
                result lands in [0, 1] -- confirm against the wrapped
                environment.

        Returns:
            `float32` array of shape `(channels, height, width)`;
            single-channel input yields one leading channel.
        """
        result = cv2.resize(observation, IMAGE_SIZE)
        if result.ndim == 2:
            # cv2.resize returns a 2-D array for single-channel images
            # (it drops a trailing singleton channel axis), so there is
            # no axis 2 to move; restore the channel axis at the front.
            result = result[np.newaxis, ...]
        else:
            result = np.moveaxis(result, 2, 0)
        return result.astype(np.float32) / 255.0