Skip to content
This repository was archived by the owner on Dec 28, 2023. It is now read-only.

Commit 092d425

Browse files
author
Omegastick
committed
Add ObservationNormalizer
1 parent e9ebdf3 commit 092d425

File tree

6 files changed

+361
-12
lines changed

6 files changed

+361
-12
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#pragma once
2+
3+
#include <vector>
4+
5+
#include <torch/torch.h>
6+
7+
#include "cpprl/running_mean_std.h"
8+
9+
namespace SingularityTrainer
10+
{
11+
class ObservationNormalizer
12+
{
13+
private:
14+
float clip;
15+
RunningMeanStd rms;
16+
17+
public:
18+
explicit ObservationNormalizer(int size, float clip = 10.);
19+
ObservationNormalizer(const std::vector<float> &means,
20+
const std::vector<float> &variances,
21+
float clip = 10.);
22+
explicit ObservationNormalizer(const std::vector<ObservationNormalizer> &others);
23+
24+
torch::Tensor process_observation(torch::Tensor observation);
25+
std::vector<float> get_mean() const;
26+
std::vector<float> get_variance() const;
27+
void update(torch::Tensor observations);
28+
29+
inline float get_clip_value() const { return clip; }
30+
inline int get_step_count() const { return rms.get_count(); }
31+
};
32+
}

include/cpprl/running_mean_std.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#pragma once
2+
3+
#include <vector>
4+
5+
#include <torch/torch.h>
6+
7+
namespace SingularityTrainer
8+
{
9+
// https://github.com/openai/baselines/blob/master/baselines/common/running_mean_std.py
10+
class RunningMeanStd
11+
{
12+
private:
13+
float count;
14+
torch::Tensor mean, variance;
15+
16+
void update_from_moments(torch::Tensor batch_mean,
17+
torch::Tensor batch_var,
18+
int batch_count);
19+
20+
public:
21+
explicit RunningMeanStd(int size);
22+
RunningMeanStd(std::vector<float> means, std::vector<float> variances);
23+
24+
void update(torch::Tensor observation);
25+
26+
inline int get_count() const { return static_cast<int>(count); }
27+
inline torch::Tensor get_mean() const { return mean.clone(); }
28+
inline torch::Tensor get_variance() const { return variance.clone(); }
29+
inline void set_count(int count) { this->count = count + 1e-8; }
30+
};
31+
}

src/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
target_sources(cpprl
22
PRIVATE
33
${CMAKE_CURRENT_LIST_DIR}/storage.cpp
4+
${CMAKE_CURRENT_LIST_DIR}/observation_normalizer.cpp
5+
${CMAKE_CURRENT_LIST_DIR}/running_mean_std.cpp
46
)
57

68
if (CPPRL_BUILD_TESTS)
79
target_sources(cpprl_tests
810
PRIVATE
911
${CMAKE_CURRENT_LIST_DIR}/storage.cpp
12+
${CMAKE_CURRENT_LIST_DIR}/observation_normalizer.cpp
13+
${CMAKE_CURRENT_LIST_DIR}/running_mean_std.cpp
1014
)
1115
endif (CPPRL_BUILD_TESTS)
1216

src/observation_normalizer.cpp

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
#include <torch/torch.h>
2+
3+
#include "cpprl/observation_normalizer.h"
4+
#include "cpprl/running_mean_std.h"
5+
#include "third_party/doctest.h"
6+
7+
namespace SingularityTrainer
8+
{
9+
ObservationNormalizer::ObservationNormalizer(int size, float clip)
10+
: clip(clip),
11+
rms(size) {}
12+
13+
ObservationNormalizer::ObservationNormalizer(const std::vector<float> &means,
14+
const std::vector<float> &variances,
15+
float clip)
16+
: clip(clip),
17+
rms(means, variances){}
18+
19+
ObservationNormalizer::ObservationNormalizer(const std::vector<ObservationNormalizer> &others)
20+
: clip(0),
21+
rms(1)
22+
{
23+
// Calculate mean clip
24+
for (const auto &other : others)
25+
{
26+
clip += other.get_clip_value();
27+
}
28+
clip /= others.size();
29+
30+
// Calculate mean mean
31+
std::vector<float> mean_means(others[0].get_mean().size(), 0);
32+
for (const auto &other : others)
33+
{
34+
auto other_mean = other.get_mean();
35+
for (unsigned int i = 0; i < mean_means.size(); ++i)
36+
{
37+
mean_means[i] += other_mean[i];
38+
}
39+
}
40+
for (auto &mean : mean_means)
41+
{
42+
mean /= others.size();
43+
}
44+
45+
// Calculate mean variances
46+
std::vector<float> mean_variances(others[0].get_variance().size(), 0);
47+
for (const auto &other : others)
48+
{
49+
auto other_variances = other.get_variance();
50+
for (unsigned int i = 0; i < mean_variances.size(); ++i)
51+
{
52+
mean_variances[i] += other_variances[i];
53+
}
54+
}
55+
for (auto &variance : mean_variances)
56+
{
57+
variance /= others.size();
58+
}
59+
60+
rms = RunningMeanStd(mean_means, mean_variances);
61+
62+
int total_count = std::accumulate(others.begin(), others.end(), 0,
63+
[](int accumulator, const ObservationNormalizer &other) {
64+
return accumulator + other.get_step_count();
65+
});
66+
rms.set_count(total_count);
67+
}
68+
69+
torch::Tensor ObservationNormalizer::process_observation(torch::Tensor observation)
70+
{
71+
auto normalized_obs = (observation - rms.get_mean()) /
72+
torch::sqrt(rms.get_variance() + 1e-8);
73+
return torch::clamp(normalized_obs, -clip, clip);
74+
}
75+
76+
std::vector<float> ObservationNormalizer::get_mean() const
77+
{
78+
auto mean = rms.get_mean();
79+
return std::vector<float>(mean.data<float>(), mean.data<float>() + mean.numel());
80+
}
81+
82+
std::vector<float> ObservationNormalizer::get_variance() const
83+
{
84+
auto variance = rms.get_variance();
85+
return std::vector<float>(variance.data<float>(), variance.data<float>() + variance.numel());
86+
}
87+
88+
void ObservationNormalizer::update(torch::Tensor observations)
89+
{
90+
rms.update(observations);
91+
}
92+
93+
TEST_CASE("ObservationNormalizer")
94+
{
95+
SUBCASE("Clips values correctly")
96+
{
97+
ObservationNormalizer normalizer(7, 1);
98+
float observation_array[] = {-1000, -100, -10, 0, 10, 100, 1000};
99+
auto observation = torch::from_blob(observation_array, {7});
100+
auto processed_observation = normalizer.process_observation(observation);
101+
102+
auto has_too_large_values = (processed_observation > 1).any().item().toBool();
103+
auto has_too_small_values = (processed_observation < -1).any().item().toBool();
104+
DOCTEST_CHECK(!has_too_large_values);
105+
DOCTEST_CHECK(!has_too_small_values);
106+
}
107+
108+
SUBCASE("Normalizes values correctly")
109+
{
110+
ObservationNormalizer normalizer(5);
111+
112+
float obs_1_array[] = {-10., 0., 5., 3.2, 0.};
113+
float obs_2_array[] = {-5., 2., 4., 3.7, -3.};
114+
float obs_3_array[] = {1, 2, 3, 4, 5};
115+
auto obs_1 = torch::from_blob(obs_1_array, {5});
116+
auto obs_2 = torch::from_blob(obs_2_array, {5});
117+
auto obs_3 = torch::from_blob(obs_3_array, {5});
118+
119+
normalizer.update(obs_1);
120+
normalizer.update(obs_2);
121+
normalizer.update(obs_3);
122+
auto processed_observation = normalizer.process_observation(obs_3);
123+
124+
DOCTEST_CHECK(processed_observation[0].item().toFloat() == doctest::Approx(1.26008659));
125+
DOCTEST_CHECK(processed_observation[1].item().toFloat() == doctest::Approx(0.70712887));
126+
DOCTEST_CHECK(processed_observation[2].item().toFloat() == doctest::Approx(-1.2240818));
127+
DOCTEST_CHECK(processed_observation[3].item().toFloat() == doctest::Approx(1.10914509));
128+
DOCTEST_CHECK(processed_observation[4].item().toFloat() == doctest::Approx(1.31322402));
129+
}
130+
131+
SUBCASE("Loads mean and variance from constructor correctly")
132+
{
133+
ObservationNormalizer normalizer({1, 2, 3}, {4, 5, 6});
134+
135+
auto mean = normalizer.get_mean();
136+
auto variance = normalizer.get_variance();
137+
DOCTEST_CHECK(mean[0] == doctest::Approx(1));
138+
DOCTEST_CHECK(mean[1] == doctest::Approx(2));
139+
DOCTEST_CHECK(mean[2] == doctest::Approx(3));
140+
DOCTEST_CHECK(variance[0] == doctest::Approx(4));
141+
DOCTEST_CHECK(variance[1] == doctest::Approx(5));
142+
DOCTEST_CHECK(variance[2] == doctest::Approx(6));
143+
}
144+
145+
SUBCASE("Is constructed from other normalizers correctly")
146+
{
147+
std::vector<ObservationNormalizer> normalizers;
148+
for (int i = 0; i < 3; ++i)
149+
{
150+
normalizers.push_back(ObservationNormalizer(3));
151+
for (int j = 0; j <= i; ++j)
152+
{
153+
normalizers[i].update(torch::rand({3}));
154+
}
155+
}
156+
157+
ObservationNormalizer combined_normalizer(normalizers);
158+
159+
std::vector<std::vector<float>> means;
160+
std::transform(normalizers.begin(), normalizers.end(), std::back_inserter(means),
161+
[](const ObservationNormalizer &normalizer) { return normalizer.get_mean(); });
162+
std::vector<std::vector<float>> variances;
163+
std::transform(normalizers.begin(), normalizers.end(), std::back_inserter(variances),
164+
[](const ObservationNormalizer &normalizer) { return normalizer.get_variance(); });
165+
166+
std::vector<float> mean_means;
167+
for (int i = 0; i < 3; ++i)
168+
{
169+
mean_means.push_back((means[0][i] + means[1][i] + means[2][i]) / 3);
170+
}
171+
std::vector<float> mean_variances;
172+
for (int i = 0; i < 3; ++i)
173+
{
174+
mean_variances.push_back((variances[0][i] + variances[1][i] + variances[2][i]) / 3);
175+
}
176+
177+
auto actual_mean_means = combined_normalizer.get_mean();
178+
auto actual_mean_variances = combined_normalizer.get_variance();
179+
180+
for (int i = 0; i < 3; ++i)
181+
{
182+
DOCTEST_CHECK(actual_mean_means[i] == doctest::Approx(mean_means[i]));
183+
DOCTEST_CHECK(actual_mean_variances[i] == doctest::Approx(actual_mean_variances[i]));
184+
}
185+
DOCTEST_CHECK(combined_normalizer.get_step_count() == 6);
186+
}
187+
}
188+
}

src/running_mean_std.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#include <torch/torch.h>
2+
3+
#include "cpprl/running_mean_std.h"
4+
#include "third_party/doctest.h"
5+
6+
namespace SingularityTrainer
7+
{
8+
RunningMeanStd::RunningMeanStd(int size)
9+
: count(1e-4),
10+
mean(torch::zeros({size})),
11+
variance(torch::ones({size})) {}
12+
13+
RunningMeanStd::RunningMeanStd(std::vector<float> means, std::vector<float> variances)
14+
: count(1e-4),
15+
mean(torch::from_blob(means.data(), {static_cast<long>(means.size())})
16+
.clone()),
17+
variance(torch::from_blob(variances.data(), {static_cast<long>(variances.size())})
18+
.clone()) {}
19+
20+
void RunningMeanStd::update(torch::Tensor observation)
21+
{
22+
observation = observation.reshape({-1, mean.size(0)});
23+
auto batch_mean = observation.mean(0);
24+
auto batch_var = observation.var(0, false, false);
25+
auto batch_count = observation.size(0);
26+
27+
update_from_moments(batch_mean, batch_var, batch_count);
28+
}
29+
30+
void RunningMeanStd::update_from_moments(torch::Tensor batch_mean,
31+
torch::Tensor batch_var,
32+
int batch_count)
33+
{
34+
auto delta = batch_mean - mean;
35+
float total_count = count + batch_count;
36+
37+
mean = mean + delta * batch_count / total_count;
38+
auto m_a = variance * count;
39+
auto m_b = batch_var * batch_count;
40+
auto m2 = m_a + m_b + torch::pow(delta, 2) * count * batch_count / total_count;
41+
variance = m2 / total_count;
42+
count = total_count;
43+
}
44+
45+
TEST_CASE("RunningMeanStd")
46+
{
47+
SUBCASE("Calculates mean and variance correctly")
48+
{
49+
RunningMeanStd rms(5);
50+
auto observations = torch::rand({3, 5});
51+
rms.update(observations[0]);
52+
rms.update(observations[1]);
53+
rms.update(observations[2]);
54+
55+
auto expected_mean = observations.mean(0);
56+
auto expected_variance = observations.var(0, false, false);
57+
58+
auto actual_mean = rms.get_mean();
59+
auto actual_variance = rms.get_variance();
60+
61+
for (int i = 0; i < 5; ++i)
62+
{
63+
DOCTEST_CHECK(expected_mean[i].item().toFloat() ==
64+
doctest::Approx(actual_mean[i].item().toFloat())
65+
.epsilon(0.001));
66+
DOCTEST_CHECK(expected_variance[i].item().toFloat() ==
67+
doctest::Approx(actual_variance[i].item().toFloat())
68+
.epsilon(0.001));
69+
}
70+
}
71+
72+
SUBCASE("Loads mean and variance from constructor correctly")
73+
{
74+
RunningMeanStd rms({1, 2, 3}, {4, 5, 6});
75+
76+
auto mean = rms.get_mean();
77+
auto variance = rms.get_variance();
78+
DOCTEST_CHECK(mean[0].item().toFloat() == doctest::Approx(1));
79+
DOCTEST_CHECK(mean[1].item().toFloat() == doctest::Approx(2));
80+
DOCTEST_CHECK(mean[2].item().toFloat() == doctest::Approx(3));
81+
DOCTEST_CHECK(variance[0].item().toFloat() == doctest::Approx(4));
82+
DOCTEST_CHECK(variance[1].item().toFloat() == doctest::Approx(5));
83+
DOCTEST_CHECK(variance[2].item().toFloat() == doctest::Approx(6));
84+
}
85+
}
86+
}

0 commit comments

Comments
 (0)