-
Notifications
You must be signed in to change notification settings - Fork 0
/
environments.py
169 lines (155 loc) · 5.91 KB
/
environments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import numpy as np
import gym
import os
from mlagents_envs.environment import UnityEnvironment
from gym_unity.envs import UnityToGymWrapper
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel
from mlagents_envs.side_channel.environment_parameters_channel import EnvironmentParametersChannel
import time
class SeqEnv:
    """
    Thin wrapper exposing `reset`, `step` and `action_space` for an
    environment that is either a registered gym environment or a Unity
    game. Construction first tries `gym.make(env_name)`; if that (or
    seeding the gym env) raises, it falls back to `UnityGymEnv`.
    """

    def __init__(self, env_name, seed=None,
                                 worker_id=None,
                                 float_params=None,
                                 **kwargs):
        """
        env_name: str
            the name of the environment
        seed: int or None
            the random seed for the environment. if None, the current
            unix time (in whole seconds) is used
        worker_id: int
            must specify a unique worker id for each unity process
            on this machine
        float_params: dict or None
            this should be a dict of argument settings for the unity
            environment. None is treated as an empty dict
            keys: varies by environment
        kwargs:
            accepted so callers can pass a larger settings dict;
            not used by this class
        """
        # Defaults are resolved here, at call time. As signature
        # defaults they would be evaluated once at definition time:
        # every instance would share one seed and one dict object.
        if seed is None:
            seed = int(time.time())
        if float_params is None:
            float_params = dict()

        self.env_name = env_name
        self.seed = seed
        self.worker_id = worker_id
        self.float_params = float_params
        try:
            self.env = gym.make(env_name)
            self.env.seed(seed)
            self.is_gym = True
        except Exception:
            # Deliberate best-effort: anything gym cannot build is
            # treated as a path to a Unity game.
            self.env = UnityGymEnv(env_name=self.env_name,
                                   seed=self.seed,
                                   worker_id=self.worker_id,
                                   float_params=self.float_params)
            self.is_gym = False
        self.action_space = self.env.action_space

    def reset(self):
        """
        Resets the wrapped environment and returns whatever its
        `reset` returns.
        """
        return self.env.reset()

    def step(self, action):
        """
        action: varies by environment
            passed unchanged to the wrapped environment's `step`

        Returns whatever the wrapped environment's `step` returns.
        """
        return self.env.step(action)
class UnityGymEnv:
    """
    Gym-style wrapper around a Unity game. The game is launched with
    `UnityEnvironment` and wrapped in `UnityToGymWrapper` with
    `allow_multiple_obs=True`. Observation lists are split into a
    primary observation (the first entry) and the remaining entries.
    """

    def __init__(self, env_name, seed=None,
                                 worker_id=None,
                                 float_params=None,
                                 **kwargs):
        """
        env_name: str
            the name of the environment
        seed: int or None
            the random seed for the environment. if None, the current
            unix time (in whole seconds) is used
        worker_id: int
            must specify a unique worker id for each unity process
            on this machine
        float_params: dict or None
            this should be a dict of argument settings for the unity
            environment. None is treated as an empty dict
            keys: varies by environment
        kwargs:
            forwarded to `make_unity_env` (e.g. time_scale)
        """
        # Defaults are resolved at call time; as signature defaults
        # they would be evaluated once and shared by all instances.
        if seed is None:
            seed = int(time.time())
        if float_params is None:
            float_params = dict()

        self.env_name = env_name
        self.seed = seed
        self.worker_id = worker_id
        self.float_params = float_params
        self.env = self.make_unity_env(env_name,
                                       seed=self.seed,
                                       worker_id=self.worker_id,
                                       float_params=float_params,
                                       **kwargs)
        obs = self.reset()
        self.shape = obs.shape
        self.is_discrete = False
        self.action_space = np.zeros((2,))

    def prep_obs(self, obs):
        """
        Splits an environment observation into the primary
        observation and the remaining entries.

        obs: list, tuple or ndarray
            the observation returned by the environment

        Returns:
            obs: ndarray
                the first entry of the sequence, or the input itself
                when it is not a list/tuple
            info: list
                the remaining entries of the sequence (empty when the
                input is not a list/tuple)
        """
        if not isinstance(obs, (list, tuple)):
            # Always return a pair so `reset` and `step` can unpack
            # the result regardless of the observation type.
            return np.asarray(obs), []
        # Slice the incoming sequence, not the converted first entry,
        # so `info` holds the other observations rather than rows of
        # the first one.
        primary = np.asarray(obs[0])
        info = [*obs[1:]]
        return primary, info

    def reset(self):
        """
        Resets the wrapped environment and returns only the primary
        observation.
        """
        obs = self.env.reset()
        obs, _ = self.prep_obs(obs)
        return obs

    def step(self, action):
        """
        action: ndarray (SHAPE = self.action_space.shape)
            the action to take in this step. type can vary depending
            on the environment type

        Returns:
            obs: ndarray
                the primary observation
            rew: the reward returned by the wrapped environment
            done: the done flag returned by the wrapped environment
            targ: list
                the remaining observation entries, the first two of
                which are clipped elementwise to [-1, 1]. note that
                the wrapped environment's own info is discarded
        """
        obs, rew, done, info = self.env.step(np.asarray(action).squeeze())
        obs, targ = self.prep_obs(obs)
        # Clip entry by entry so extra observations of different
        # shapes do not have to be stacked into one array.
        # NOTE(review): this clips the first two *entries* of the
        # list, as the original expression did; if the intent was the
        # first two values of a single target vector, confirm and
        # adjust.
        targ[:2] = [np.clip(t, -1, 1) for t in targ[:2]]
        return obs, rew, done, targ

    def render(self):
        """
        Rendering is not implemented; always returns None.
        """
        return None

    def close(self):
        """
        Closes the wrapped environment.
        """
        self.env.close()

    def make_unity_env(self, env_name, float_params=None, time_scale=1,
                                                  seed=None,
                                                  worker_id=None,
                                                  **kwargs):
        """
        creates a gym environment from a unity game

        env_name: str
            the path to the game
        float_params: dict or None
            this should be a dict of argument settings for the unity
            environment
            keys: varies by environment
        time_scale: float
            value handed to the engine configuration channel as
            Unity's time scale
        seed: int or None
            the seed for randomness. if None, the current unix time
            (in whole seconds) is used
        worker_id: int or None
            must specify a unique worker id for each unity process
            on this machine. if None, it is derived from the seed
        kwargs:
            accepted and ignored

        Returns the `UnityToGymWrapper` around the launched game.
        Raises RuntimeError if the game could not be launched after
        all attempts.
        """
        if float_params is None: float_params = dict()
        if seed is None: seed = time.time()
        seed = int(seed)
        if worker_id is None: worker_id = seed%500+1
        worker_id = int(worker_id)

        path = os.path.expanduser(env_name)
        channel = EngineConfigurationChannel()
        env_channel = EnvironmentParametersChannel()
        channel.set_configuration_parameters(time_scale=time_scale)
        for k,v in float_params.items():
            if k=="validation" and v>=1:
                print("Game in validation mode")
            env_channel.set_float_parameter(k, float(v))

        # Launching can fail (e.g. the worker id is already taken by
        # another process), so retry with a pseudo-random new id.
        max_attempts = 50
        env = None
        last_err = None
        for _ in range(max_attempts):
            try:
                env = UnityEnvironment(file_name=path,
                                       side_channels=[channel,env_channel],
                                       worker_id=worker_id,
                                       seed=seed)
                break
            except Exception as err:
                # `Exception` (not a bare except) so KeyboardInterrupt
                # and SystemExit still stop the retry loop.
                last_err = err
                s = "Error encountered making environment, "
                s += "trying new worker_id"
                print(s)
                worker_id =(worker_id+1+int(np.random.random()*100))%500
        if env is None:
            raise RuntimeError(
                "Could not create the Unity environment at "
                + repr(path) + " after " + str(max_attempts) + " attempts"
            ) from last_err
        env = UnityToGymWrapper(env, allow_multiple_obs=True)
        return env