-
Notifications
You must be signed in to change notification settings - Fork 38
/
predict_env.py
118 lines (95 loc) · 4.28 KB
/
predict_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import numpy as np
class PredictEnv:
    """Wrap a learned dynamics ensemble as a batched, gym-like ``step`` function.

    The wrapped model returns, per ensemble member, Gaussian means/variances
    laid out as ``[reward, obs_delta...]`` for the concatenated ``(obs, act)``
    input. ``step`` picks one member per row, turns the predicted delta into an
    absolute next observation and applies a hand-written termination rule.
    """

    def __init__(self, model, env_name, model_type):
        """Store the dynamics model and configuration.

        Args:
            model: Ensemble model. When ``model_type == 'pytorch'`` it must
                provide ``predict(inputs)`` and ``elite_model_idxes``;
                otherwise ``predict(inputs, factored=True)`` and
                ``random_inds(n)``.
            env_name: Environment identifier that selects the termination rule.
            model_type: ``'pytorch'`` or any other value (other backend).
        """
        self.model = model
        self.env_name = env_name
        self.model_type = model_type

    def _termination_fn(self, env_name, obs, act, next_obs):
        """Return a ``(batch, 1)`` boolean array flagging terminal transitions.

        Args:
            env_name: ``"Hopper-v2"``, ``"Walker2d-v2"`` or any name containing
                ``"walker_"``.
            obs: Batched current observations (2-D); only its rank is checked,
                and only for the two gym environments.
            act: Batched actions (2-D); only its rank is checked, likewise.
            next_obs: Batched predicted next observations (2-D).

        Raises:
            NotImplementedError: If no termination rule exists for ``env_name``.
        """
        if env_name == "Hopper-v2":
            assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2
            height = next_obs[:, 0]
            angle = next_obs[:, 1]
            # abs() must be applied before the comparison so that large
            # negative state values also terminate the episode.
            not_done = (
                np.isfinite(next_obs).all(axis=-1)
                & (np.abs(next_obs[:, 1:]) < 100).all(axis=-1)
                & (height > 0.7)
                & (np.abs(angle) < 0.2)
            )
            return (~not_done)[:, None]

        if env_name == "Walker2d-v2":
            assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2
            height = next_obs[:, 0]
            angle = next_obs[:, 1]
            not_done = (
                (height > 0.8)
                & (height < 2.0)
                & (angle > -1.0)
                & (angle < 1.0)
            )
            return (~not_done)[:, None]

        if 'walker_' in env_name:
            # Custom walker variants keep torso height/angle in the last two
            # observation columns.
            torso_height = next_obs[:, -2]
            torso_ang = next_obs[:, -1]
            if 'walker_7' in env_name or 'walker_5' in env_name:
                offset = 0.
            else:
                offset = 0.26
            not_done = (
                (torso_height > 0.8 - offset)
                & (torso_height < 2.0 - offset)
                & (torso_ang > -1.0)
                & (torso_ang < 1.0)
            )
            return (~not_done)[:, None]

        # Previously this fell through and returned None, which made step()
        # fail later with an unrelated-looking concatenate/indexing error.
        raise NotImplementedError(
            f"No termination function defined for env_name={env_name!r}"
        )

    def _get_logprob(self, x, means, variances):
        """Log-density of ``x`` under the sum of per-member diagonal Gaussians.

        Args:
            x: Samples, ``[batch_size, dim]``.
            means: Per-member means, ``[num_networks, batch_size, dim]``.
            variances: Per-member variances, same shape as ``means``.

        Returns:
            ``(log_prob, stds)``, both ``[batch_size]``. ``log_prob`` is
            ``log(sum_i N(x | means[i], variances[i]))``. ``stds`` is the std
            of ``means`` across members, averaged over ``dim`` (ensemble
            disagreement).
        """
        k = x.shape[-1]
        # [num_networks, batch_size]
        member_log_prob = -0.5 * (
            k * np.log(2 * np.pi)
            + np.log(variances).sum(-1)
            + (np.power(x - means, 2) / variances).sum(-1)
        )
        # Stable log-sum-exp over the ensemble axis: exponentiating the raw
        # log-densities underflows to 0 (log -> -inf) for large dims.
        peak = member_log_prob.max(0)
        peak = np.where(np.isfinite(peak), peak, 0.0)
        # [batch_size]
        log_prob = peak + np.log(np.exp(member_log_prob - peak).sum(0))
        stds = np.std(means, 0).mean(-1)
        return log_prob, stds

    def step(self, obs, act, deterministic=False):
        """Predict one model transition for a batch (or a single) ``(obs, act)``.

        Args:
            obs: Observations, ``[batch, obs_dim]`` or a single ``[obs_dim]``.
            act: Actions, ``[batch, act_dim]`` or a single ``[act_dim]``.
            deterministic: Use the member means instead of sampling from the
                predicted Gaussians. The ensemble member is still drawn at
                random.

        Returns:
            ``(next_obs, rewards, terminals, info)``. ``info['mean']`` and
            ``info['std']`` hold the chosen member's prediction laid out as
            ``[reward, terminal, obs...]`` (the terminal std is 0).
            ``info['log_prob']`` and ``info['dev']`` come from
            ``_get_logprob``. For 1-D inputs, the batch axis is dropped from
            everything except ``log_prob`` and ``dev``.
        """
        if len(obs.shape) == 1:
            obs = obs[None]
            act = act[None]
            return_single = True
        else:
            return_single = False

        inputs = np.concatenate((obs, act), axis=-1)
        if self.model_type == 'pytorch':
            ensemble_model_means, ensemble_model_vars = self.model.predict(inputs)
        else:
            ensemble_model_means, ensemble_model_vars = self.model.predict(inputs, factored=True)
        # Columns 1: are observation deltas; make them absolute next observations.
        ensemble_model_means[:, :, 1:] += obs
        ensemble_model_stds = np.sqrt(ensemble_model_vars)

        if deterministic:
            ensemble_samples = ensemble_model_means
        else:
            ensemble_samples = (
                ensemble_model_means
                + np.random.normal(size=ensemble_model_means.shape) * ensemble_model_stds
            )

        # Choose one ensemble member per batch row.
        batch_size = ensemble_model_means.shape[1]
        if self.model_type == 'pytorch':
            model_idxes = np.random.choice(self.model.elite_model_idxes, size=batch_size)
        else:
            model_idxes = self.model.random_inds(batch_size)
        batch_idxes = np.arange(0, batch_size)
        samples = ensemble_samples[model_idxes, batch_idxes]
        model_means = ensemble_model_means[model_idxes, batch_idxes]
        model_stds = ensemble_model_stds[model_idxes, batch_idxes]

        log_prob, dev = self._get_logprob(samples, ensemble_model_means, ensemble_model_vars)

        rewards, next_obs = samples[:, :1], samples[:, 1:]
        terminals = self._termination_fn(self.env_name, obs, act, next_obs)

        return_means = np.concatenate(
            (model_means[:, :1], terminals, model_means[:, 1:]), axis=-1)
        return_stds = np.concatenate(
            (model_stds[:, :1], np.zeros((batch_size, 1)), model_stds[:, 1:]), axis=-1)

        if return_single:
            next_obs = next_obs[0]
            return_means = return_means[0]
            return_stds = return_stds[0]
            rewards = rewards[0]
            terminals = terminals[0]

        info = {'mean': return_means, 'std': return_stds, 'log_prob': log_prob, 'dev': dev}
        return next_obs, rewards, terminals, info