This repository was archived by the owner on Dec 28, 2023. It is now read-only.

Commit ae03039

Author: Omegastick
Implement automatic observation normalization

1 parent 0615f76 commit ae03039

File tree

17 files changed: +442 additions, -251 deletions

example/gym_client.cpp

Lines changed: 60 additions & 18 deletions
@@ -16,26 +16,27 @@ using namespace cpprl;
 
 // Algorithm hyperparameters
 const std::string algorithm = "PPO";
-const int batch_size = 2048;
+const float actor_loss_coef = 1.0;
+const int batch_size = 40;
 const float clip_param = 0.2;
 const float discount_factor = 0.99;
-const float entropy_coef = 0.001;
+const float entropy_coef = 1e-3;
 const float gae = 0.95;
-const float learning_rate = 2.5e-4;
+const float kl_target = 0.05;
+const float learning_rate = 7e-4;
 const int log_interval = 1;
 const int max_frames = 10e+7;
-const int num_epoch = 10;
-const int num_mini_batch = 32;
+const int num_epoch = 3;
+const int num_mini_batch = 20;
 const int reward_average_window_size = 10;
+const float reward_clip_value = 10; // Post scaling
 const bool use_gae = true;
-const bool use_lr_decay = true;
-const float actor_loss_coef = 1.0;
+const bool use_lr_decay = false;
 const float value_loss_coef = 0.5;
 
 // Environment hyperparameters
-const float env_gamma = discount_factor; // Set to -1 to disable
-const std::string env_name = "BipedalWalkerHardcore-v2";
-const int num_envs = 16;
+const std::string env_name = "LunarLander-v2";
+const int num_envs = 8;
 const float render_reward_threshold = 160;
 
 // Model hyperparameters
@@ -80,7 +81,6 @@ int main(int argc, char *argv[])
     spdlog::info("Creating environment");
     auto make_param = std::make_shared<MakeParam>();
     make_param->env_name = env_name;
-    make_param->gamma = env_gamma;
     make_param->num_envs = num_envs;
     Request<MakeParam> make_request("make", make_param);
     communicator.send_request(make_request);
@@ -125,7 +125,17 @@ int main(int argc, char *argv[])
     }
     base->to(device);
     ActionSpace space{env_info->action_space_type, env_info->action_space_shape};
-    Policy policy(space, base);
+    Policy policy(nullptr);
+    if (env_info->observation_space_shape.size() == 1)
+    {
+        // With observation normalization
+        policy = Policy(space, base, true);
+    }
+    else
+    {
+        // Without observation normalization
+        policy = Policy(space, base, false);
+    }
     policy->to(device);
     RolloutStorage storage(batch_size, num_envs, env_info->observation_space_shape, space, hidden_size, device);
     std::unique_ptr<Algorithm> algo;
@@ -135,7 +145,17 @@ int main(int argc, char *argv[])
     }
     else if (algorithm == "PPO")
     {
-        algo = std::make_unique<PPO>(policy, clip_param, num_epoch, num_mini_batch, actor_loss_coef, value_loss_coef, entropy_coef, learning_rate);
+        algo = std::make_unique<PPO>(policy,
+                                     clip_param,
+                                     num_epoch,
+                                     num_mini_batch,
+                                     actor_loss_coef,
+                                     value_loss_coef,
+                                     entropy_coef,
+                                     learning_rate,
+                                     1e-8,
+                                     0.5,
+                                     kl_target);
     }
 
     storage.set_first_observation(observation);
@@ -144,6 +164,8 @@ int main(int argc, char *argv[])
     int episode_count = 0;
     bool render = false;
     std::vector<float> reward_history(reward_average_window_size);
+    RunningMeanStd returns_rms(1);
+    auto returns = torch::zeros({num_envs});
 
     auto start_time = std::chrono::high_resolution_clock::now();
 
@@ -159,14 +181,21 @@ int main(int argc, char *argv[])
                                          storage.get_hidden_states()[step],
                                          storage.get_masks()[step]);
             }
-            auto actions_tensor = act_result[1].cpu();
+            auto actions_tensor = act_result[1].cpu().to(torch::kFloat);
             float *actions_array = actions_tensor.data<float>();
             std::vector<std::vector<float>> actions(num_envs);
             for (int i = 0; i < num_envs; ++i)
             {
-                for (int j = 0; j < env_info->action_space_shape[0]; j++)
+                if (space.type == "Discrete")
+                {
+                    actions[i] = {actions_array[i]};
+                }
+                else
                 {
-                    actions[i].push_back(actions_array[i * env_info->action_space_shape[0] + j]);
+                    for (int j = 0; j < env_info->action_space_shape[0]; j++)
+                    {
+                        actions[i].push_back(actions_array[i * env_info->action_space_shape[0] + j]);
+                    }
                 }
             }
 
@@ -183,7 +212,13 @@ int main(int argc, char *argv[])
                 auto step_result = communicator.get_response<CnnStepResponse>();
                 observation_vec = flatten_vector(step_result->observation);
                 observation = torch::from_blob(observation_vec.data(), observation_shape).to(device);
-                rewards = flatten_vector(step_result->reward);
+                auto raw_reward_vec = flatten_vector(step_result->real_reward);
+                auto reward_tensor = torch::from_blob(raw_reward_vec.data(), {num_envs}, torch::kFloat);
+                returns = returns * discount_factor + reward_tensor;
+                returns_rms->update(returns);
+                reward_tensor = torch::clamp(reward_tensor / torch::sqrt(returns_rms->get_variance() + 1e-8),
+                                             -reward_clip_value, reward_clip_value);
+                rewards = std::vector<float>(reward_tensor.data<float>(), reward_tensor.data<float>() + reward_tensor.numel());
                 real_rewards = flatten_vector(step_result->real_reward);
                 dones_vec = step_result->done;
             }
@@ -192,7 +227,13 @@ int main(int argc, char *argv[])
                 auto step_result = communicator.get_response<MlpStepResponse>();
                 observation_vec = flatten_vector(step_result->observation);
                 observation = torch::from_blob(observation_vec.data(), observation_shape).to(device);
-                rewards = flatten_vector(step_result->reward);
+                auto raw_reward_vec = flatten_vector(step_result->real_reward);
+                auto reward_tensor = torch::from_blob(raw_reward_vec.data(), {num_envs}, torch::kFloat);
+                returns = returns * discount_factor + reward_tensor;
+                returns_rms->update(returns);
+                reward_tensor = torch::clamp(reward_tensor / torch::sqrt(returns_rms->get_variance() + 1e-8),
+                                             -reward_clip_value, reward_clip_value);
+                rewards = std::vector<float>(reward_tensor.data<float>(), reward_tensor.data<float>() + reward_tensor.numel());
                 real_rewards = flatten_vector(step_result->real_reward);
                 dones_vec = step_result->done;
             }
@@ -203,6 +244,7 @@ int main(int argc, char *argv[])
                 {
                     reward_history[episode_count % reward_average_window_size] = running_rewards[i];
                     running_rewards[i] = 0;
+                    returns[i] = 0;
                     episode_count++;
                 }
             }
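
The two step-response hunks above move reward normalization from the Python server into the C++ client: each environment's discounted return is tracked, a running estimate of that return's variance is maintained, and every raw reward is divided by the return's standard deviation before being clamped to ±reward_clip_value (the OpenAI baselines VecNormalize recipe). A minimal sketch of that idea, assuming a hypothetical Welford-style RunningVariance helper in place of the library's RunningMeanStd module:

// Minimal sketch of the reward scaling used above, assuming a hypothetical
// Welford-style RunningVariance helper instead of the library's RunningMeanStd
// module. Plain floats only, so it runs without cpprl or libtorch.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct RunningVariance
{
    double mean = 0.0;
    double m2 = 0.0;
    long count = 0;

    void update(double x) // Welford's online update
    {
        count++;
        double delta = x - mean;
        mean += delta / count;
        m2 += delta * (x - mean);
    }

    double variance() const { return count > 1 ? m2 / count : 1.0; }
};

int main()
{
    const float discount_factor = 0.99f;
    const float reward_clip_value = 10.0f;
    RunningVariance returns_rms;
    float ret = 0.0f; // running discounted return for one environment

    const std::vector<float> raw_rewards = {1.0f, 0.5f, -2.0f, 100.0f, 0.0f};
    for (float reward : raw_rewards)
    {
        // Track the discounted return, estimate its variance, then scale the
        // raw reward by the return's standard deviation and clip the result.
        ret = ret * discount_factor + reward;
        returns_rms.update(ret);
        float scaled = reward / std::sqrt(static_cast<float>(returns_rms.variance()) + 1e-8f);
        scaled = std::max(-reward_clip_value, std::min(reward_clip_value, scaled));
        std::printf("raw %8.2f -> scaled %8.4f\n", reward, scaled);
    }
}

Scaling by the standard deviation of the discounted return, rather than of the raw reward itself, keeps reward magnitudes roughly in line with the value targets the critic regresses onto, which is the rationale behind the VecNormalize scheme this change replaces.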

example/requests.h

Lines changed: 1 addition & 2 deletions
@@ -25,9 +25,8 @@ struct InfoParam
 struct MakeParam
 {
     std::string env_name;
-    float gamma;
     int num_envs;
-    MSGPACK_DEFINE_MAP(env_name, gamma, num_envs);
+    MSGPACK_DEFINE_MAP(env_name, num_envs);
 };
 
 struct ResetParam

gym_server/envs.py

Lines changed: 2 additions & 8 deletions
@@ -138,21 +138,15 @@ def _thunk():
     return _thunk
 
 
-def make_vec_envs(env_name, seed, num_processes, gamma, num_frame_stack=None):
+def make_vec_envs(env_name, seed, num_processes, num_frame_stack=None):
     envs = [make_env(env_name, seed, i) for i in range(num_processes)]
 
     if len(envs) > 1:
         envs = SubprocVecEnv(envs)
     else:
         envs = DummyVecEnv(envs)
 
-    if len(envs.observation_space.shape) == 1:
-        if gamma is None or gamma == -1:
-            envs = VecNormalize(envs, ret=False)
-        else:
-            envs = VecNormalize(envs, gamma=gamma)
-    else:
-        envs = VecRewardInfo(envs)
+    envs = VecRewardInfo(envs)
 
     if num_frame_stack is not None:
         envs = VecFrameStack(envs, num_frame_stack)

gym_server/server.py

Lines changed: 4 additions & 4 deletions
@@ -53,8 +53,7 @@ def _serve(self):
                                                  observation_space_shape))
 
             elif method == 'make':
-                self.__make(param['env_name'], param['num_envs'],
-                            param['gamma'])
+                self.__make(param['env_name'], param['num_envs'])
                 self.zmq_client.send(MakeMessage())
 
             elif method == 'reset':
@@ -86,12 +85,12 @@ def info(self):
         return (action_space_type, action_space_shape, observation_space_type,
                 observation_space_shape)
 
-    def make(self, env_name, num_envs, gamma):
+    def make(self, env_name, num_envs):
         """
        Makes a vectorized environment of the type and number specified.
        """
         logging.info("Making %d %ss", num_envs, env_name)
-        self.env = make_vec_envs(env_name, 0, num_envs, gamma)
+        self.env = make_vec_envs(env_name, 0, num_envs)
 
     def reset(self) -> np.ndarray:
         """
@@ -109,6 +108,7 @@ def step(self,
         """
         if isinstance(self.env.action_space, gym.spaces.Discrete):
             actions = actions.squeeze(-1)
+            actions = actions.astype(np.int)
         observation, reward, done, info = self.env.step(actions)
         reward = np.expand_dims(reward, -1)
         done = np.expand_dims(done, -1)

include/cpprl/cpprl.h

Lines changed: 1 addition & 0 deletions
@@ -11,5 +11,6 @@
 #include "cpprl/model/nn_base.h"
 #include "cpprl/model/output_layers.h"
 #include "cpprl/model/policy.h"
+#include "cpprl/observation_normalizer.h"
 #include "cpprl/spaces.h"
 #include "cpprl/storage.h"

include/cpprl/model/mlp_base.h

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,7 @@ class MlpBase : public NNBase
     nn::Sequential actor;
     nn::Sequential critic;
     nn::Linear critic_linear;
+    unsigned int num_inputs;
 
   public:
     MlpBase(unsigned int num_inputs,
@@ -25,5 +26,7 @@ class MlpBase : public NNBase
     std::vector<torch::Tensor> forward(torch::Tensor inputs,
                                        torch::Tensor hxs,
                                        torch::Tensor masks);
+
+    inline unsigned int get_num_inputs() const { return num_inputs; }
 };
 }

include/cpprl/model/nn_base.h

Lines changed: 2 additions & 2 deletions
@@ -11,9 +11,9 @@ namespace cpprl
 class NNBase : public nn::Module
 {
   private:
-    bool recurrent;
-    unsigned int hidden_size;
     nn::GRU gru;
+    unsigned int hidden_size;
+    bool recurrent;
 
   public:
     NNBase(bool recurrent,

include/cpprl/model/policy.h

Lines changed: 11 additions & 2 deletions
@@ -7,6 +7,7 @@
 
 #include "cpprl/model/nn_base.h"
 #include "cpprl/model/output_layers.h"
+#include "cpprl/observation_normalizer.h"
 #include "cpprl/spaces.h"
 
 using namespace torch;
@@ -16,16 +17,19 @@ namespace cpprl
 class PolicyImpl : public nn::Module
 {
   private:
+    ActionSpace action_space;
     std::shared_ptr<NNBase> base;
+    ObservationNormalizer observation_normalizer;
     std::shared_ptr<OutputLayer> output_layer;
-    ActionSpace action_space;
 
     std::vector<torch::Tensor> forward_gru(torch::Tensor x,
                                            torch::Tensor hxs,
                                            torch::Tensor masks);
 
   public:
-    PolicyImpl(ActionSpace action_space, std::shared_ptr<NNBase> base);
+    PolicyImpl(ActionSpace action_space,
+               std::shared_ptr<NNBase> base,
+               bool normalize_observations = false);
 
     std::vector<torch::Tensor> act(torch::Tensor inputs,
                                    torch::Tensor rnn_hxs,
@@ -40,12 +44,17 @@ class PolicyImpl : public nn::Module
     torch::Tensor get_values(torch::Tensor inputs,
                              torch::Tensor rnn_hxs,
                              torch::Tensor masks);
+    void update_observation_normalizer(torch::Tensor observations);
 
     inline bool is_recurrent() const { return base->is_recurrent(); }
     inline unsigned int get_hidden_size() const
     {
         return base->get_hidden_size();
     }
+    inline bool using_observation_normalizer() const
+    {
+        return !observation_normalizer.is_empty();
+    }
 };
 TORCH_MODULE(Policy);
 }
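
For reference, a rough usage sketch (not taken from this commit) of the extended interface: the third constructor argument enables the internal ObservationNormalizer, and update_observation_normalizer would presumably be fed each batch of raw observations during rollout collection. The MlpBase constructor arguments (num_inputs, recurrent, hidden_size) and the 4-input / 2-action sizes below are assumptions for illustration only:

// Rough usage sketch (not part of the commit) of the extended PolicyImpl
// interface. MlpBase constructor arguments and the observation/action sizes
// are assumptions.
#include <memory>

#include <torch/torch.h>

#include "cpprl/cpprl.h"

using namespace cpprl;

int main()
{
    auto base = std::make_shared<MlpBase>(4, /*recurrent=*/false, /*hidden_size=*/64);
    ActionSpace space{"Discrete", {2}};

    // The new third argument turns on the internal ObservationNormalizer.
    Policy policy(space, base, /*normalize_observations=*/true);

    // One plausible call site: feed each batch of raw observations to the
    // normalizer so its running mean and variance stay current.
    auto observations = torch::randn({8, 4});
    if (policy->using_observation_normalizer())
    {
        policy->update_observation_normalizer(observations);
    }
}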

include/cpprl/observation_normalizer.h

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 
 #include "cpprl/running_mean_std.h"
 
-namespace SingularityTrainer
+namespace cpprl
 {
 class ObservationNormalizer;
 
include/cpprl/running_mean_std.h

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 
 #include <torch/torch.h>
 
-namespace SingularityTrainer
+namespace cpprl
 {
 // https://github.com/openai/baselines/blob/master/baselines/common/running_mean_std.py
 class RunningMeanStdImpl : public torch::nn::Module

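The comment above links to baselines' running_mean_std.py, which folds each batch of samples into the running mean and variance with the parallel-moments (Chan et al.) formula. A standalone libtorch sketch of that update, shown here for illustration rather than as this commit's RunningMeanStdImpl:

// Standalone illustration of the batched parallel-moments update; a sketch,
// not the commit's RunningMeanStdImpl.
#include <iostream>

#include <torch/torch.h>

struct RunningStats
{
    torch::Tensor mean;
    torch::Tensor var;
    double count = 1e-4; // small prior count avoids division by zero

    explicit RunningStats(int64_t size)
        : mean(torch::zeros({size})), var(torch::ones({size})) {}

    void update(const torch::Tensor &batch) // batch shape: [N, size]
    {
        auto batch_mean = batch.mean(0);
        auto batch_var = batch.var(0, /*unbiased=*/false);
        double batch_count = static_cast<double>(batch.size(0));

        // Combine the stored moments with the batch moments.
        auto delta = batch_mean - mean;
        double total = count + batch_count;
        mean = mean + delta * (batch_count / total);
        auto m_a = var * count;
        auto m_b = batch_var * batch_count;
        auto m2 = m_a + m_b + delta.pow(2) * (count * batch_count / total);
        var = m2 / total;
        count = total;
    }
};

int main()
{
    RunningStats stats(3);
    for (int i = 0; i < 5; ++i)
    {
        stats.update(torch::randn({64, 3}) * 2.0 + 1.0);
    }
    std::cout << "mean:\n" << stats.mean << "\nvar:\n" << stats.var << std::endl;
}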