Skip to content
This repository was archived by the owner on Dec 28, 2023. It is now read-only.

Commit 0615f76

Browse files
author
Omegastick
committed
Convert ObservationNormalizer and RunningMeanStd to Torch modules
1 parent 092d425 commit 0615f76

File tree

4 files changed

+87
-84
lines changed

4 files changed

+87
-84
lines changed

include/cpprl/observation_normalizer.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,28 @@
88

99
namespace SingularityTrainer
1010
{
11-
class ObservationNormalizer
11+
class ObservationNormalizer;
12+
13+
class ObservationNormalizerImpl : public torch::nn::Module
1214
{
1315
private:
14-
float clip;
16+
torch::Tensor clip;
1517
RunningMeanStd rms;
1618

1719
public:
18-
explicit ObservationNormalizer(int size, float clip = 10.);
19-
ObservationNormalizer(const std::vector<float> &means,
20-
const std::vector<float> &variances,
21-
float clip = 10.);
22-
explicit ObservationNormalizer(const std::vector<ObservationNormalizer> &others);
20+
explicit ObservationNormalizerImpl(int size, float clip = 10.);
21+
ObservationNormalizerImpl(const std::vector<float> &means,
22+
const std::vector<float> &variances,
23+
float clip = 10.);
24+
explicit ObservationNormalizerImpl(const std::vector<ObservationNormalizer> &others);
2325

2426
torch::Tensor process_observation(torch::Tensor observation);
2527
std::vector<float> get_mean() const;
2628
std::vector<float> get_variance() const;
2729
void update(torch::Tensor observations);
2830

29-
inline float get_clip_value() const { return clip; }
30-
inline int get_step_count() const { return rms.get_count(); }
31+
inline float get_clip_value() const { return clip.item().toFloat(); }
32+
inline int get_step_count() const { return rms->get_count(); }
3133
};
34+
TORCH_MODULE(ObservationNormalizer);
3235
}

include/cpprl/running_mean_std.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,25 @@
77
namespace SingularityTrainer
88
{
99
// https://github.com/openai/baselines/blob/master/baselines/common/running_mean_std.py
10-
class RunningMeanStd
10+
class RunningMeanStdImpl : public torch::nn::Module
1111
{
1212
private:
13-
float count;
14-
torch::Tensor mean, variance;
13+
torch::Tensor count, mean, variance;
1514

1615
void update_from_moments(torch::Tensor batch_mean,
1716
torch::Tensor batch_var,
1817
int batch_count);
1918

2019
public:
21-
explicit RunningMeanStd(int size);
22-
RunningMeanStd(std::vector<float> means, std::vector<float> variances);
20+
explicit RunningMeanStdImpl(int size);
21+
RunningMeanStdImpl(std::vector<float> means, std::vector<float> variances);
2322

2423
void update(torch::Tensor observation);
2524

26-
inline int get_count() const { return static_cast<int>(count); }
25+
inline int get_count() const { return static_cast<int>(count.item().toFloat()); }
2726
inline torch::Tensor get_mean() const { return mean.clone(); }
2827
inline torch::Tensor get_variance() const { return variance.clone(); }
29-
inline void set_count(int count) { this->count = count + 1e-8; }
28+
inline void set_count(int count) { this->count[0] = count + 1e-8; }
3029
};
30+
TORCH_MODULE(RunningMeanStd);
3131
}

src/observation_normalizer.cpp

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,32 @@
66

77
namespace SingularityTrainer
88
{
9-
ObservationNormalizer::ObservationNormalizer(int size, float clip)
10-
: clip(clip),
11-
rms(size) {}
12-
13-
ObservationNormalizer::ObservationNormalizer(const std::vector<float> &means,
14-
const std::vector<float> &variances,
15-
float clip)
16-
: clip(clip),
17-
rms(means, variances){}
18-
19-
ObservationNormalizer::ObservationNormalizer(const std::vector<ObservationNormalizer> &others)
20-
: clip(0),
21-
rms(1)
9+
ObservationNormalizerImpl::ObservationNormalizerImpl(int size, float clip)
10+
: clip(register_buffer("clip", torch::full({1}, clip, torch::kFloat))),
11+
rms(register_module("rms", RunningMeanStd(size))) {}
12+
13+
ObservationNormalizerImpl::ObservationNormalizerImpl(const std::vector<float> &means,
14+
const std::vector<float> &variances,
15+
float clip)
16+
: clip(register_buffer("clip", torch::full({1}, clip, torch::kFloat))),
17+
rms(register_module("rms", RunningMeanStd(means, variances))) {}
18+
19+
ObservationNormalizerImpl::ObservationNormalizerImpl(const std::vector<ObservationNormalizer> &others)
20+
: clip(register_buffer("clip", torch::zeros({1}, torch::kFloat))),
21+
rms(register_module("rms", RunningMeanStd(1)))
2222
{
2323
// Calculate mean clip
2424
for (const auto &other : others)
2525
{
26-
clip += other.get_clip_value();
26+
clip += other->get_clip_value();
2727
}
28-
clip /= others.size();
28+
clip[0] = clip[0] / static_cast<float>(others.size());
2929

3030
// Calculate mean mean
31-
std::vector<float> mean_means(others[0].get_mean().size(), 0);
31+
std::vector<float> mean_means(others[0]->get_mean().size(), 0);
3232
for (const auto &other : others)
3333
{
34-
auto other_mean = other.get_mean();
34+
auto other_mean = other->get_mean();
3535
for (unsigned int i = 0; i < mean_means.size(); ++i)
3636
{
3737
mean_means[i] += other_mean[i];
@@ -43,10 +43,10 @@ ObservationNormalizer::ObservationNormalizer(const std::vector<ObservationNormal
4343
}
4444

4545
// Calculate mean variances
46-
std::vector<float> mean_variances(others[0].get_variance().size(), 0);
46+
std::vector<float> mean_variances(others[0]->get_variance().size(), 0);
4747
for (const auto &other : others)
4848
{
49-
auto other_variances = other.get_variance();
49+
auto other_variances = other->get_variance();
5050
for (unsigned int i = 0; i < mean_variances.size(); ++i)
5151
{
5252
mean_variances[i] += other_variances[i];
@@ -61,33 +61,33 @@ ObservationNormalizer::ObservationNormalizer(const std::vector<ObservationNormal
6161

6262
int total_count = std::accumulate(others.begin(), others.end(), 0,
6363
[](int accumulator, const ObservationNormalizer &other) {
64-
return accumulator + other.get_step_count();
64+
return accumulator + other->get_step_count();
6565
});
66-
rms.set_count(total_count);
66+
rms->set_count(total_count);
6767
}
6868

69-
torch::Tensor ObservationNormalizer::process_observation(torch::Tensor observation)
69+
torch::Tensor ObservationNormalizerImpl::process_observation(torch::Tensor observation)
7070
{
71-
auto normalized_obs = (observation - rms.get_mean()) /
72-
torch::sqrt(rms.get_variance() + 1e-8);
73-
return torch::clamp(normalized_obs, -clip, clip);
71+
auto normalized_obs = (observation - rms->get_mean()) /
72+
torch::sqrt(rms->get_variance() + 1e-8);
73+
return torch::clamp(normalized_obs, -clip.item(), clip.item());
7474
}
7575

76-
std::vector<float> ObservationNormalizer::get_mean() const
76+
std::vector<float> ObservationNormalizerImpl::get_mean() const
7777
{
78-
auto mean = rms.get_mean();
78+
auto mean = rms->get_mean();
7979
return std::vector<float>(mean.data<float>(), mean.data<float>() + mean.numel());
8080
}
8181

82-
std::vector<float> ObservationNormalizer::get_variance() const
82+
std::vector<float> ObservationNormalizerImpl::get_variance() const
8383
{
84-
auto variance = rms.get_variance();
84+
auto variance = rms->get_variance();
8585
return std::vector<float>(variance.data<float>(), variance.data<float>() + variance.numel());
8686
}
8787

88-
void ObservationNormalizer::update(torch::Tensor observations)
88+
void ObservationNormalizerImpl::update(torch::Tensor observations)
8989
{
90-
rms.update(observations);
90+
rms->update(observations);
9191
}
9292

9393
TEST_CASE("ObservationNormalizer")
@@ -97,7 +97,7 @@ TEST_CASE("ObservationNormalizer")
9797
ObservationNormalizer normalizer(7, 1);
9898
float observation_array[] = {-1000, -100, -10, 0, 10, 100, 1000};
9999
auto observation = torch::from_blob(observation_array, {7});
100-
auto processed_observation = normalizer.process_observation(observation);
100+
auto processed_observation = normalizer->process_observation(observation);
101101

102102
auto has_too_large_values = (processed_observation > 1).any().item().toBool();
103103
auto has_too_small_values = (processed_observation < -1).any().item().toBool();
@@ -116,10 +116,10 @@ TEST_CASE("ObservationNormalizer")
116116
auto obs_2 = torch::from_blob(obs_2_array, {5});
117117
auto obs_3 = torch::from_blob(obs_3_array, {5});
118118

119-
normalizer.update(obs_1);
120-
normalizer.update(obs_2);
121-
normalizer.update(obs_3);
122-
auto processed_observation = normalizer.process_observation(obs_3);
119+
normalizer->update(obs_1);
120+
normalizer->update(obs_2);
121+
normalizer->update(obs_3);
122+
auto processed_observation = normalizer->process_observation(obs_3);
123123

124124
DOCTEST_CHECK(processed_observation[0].item().toFloat() == doctest::Approx(1.26008659));
125125
DOCTEST_CHECK(processed_observation[1].item().toFloat() == doctest::Approx(0.70712887));
@@ -130,10 +130,10 @@ TEST_CASE("ObservationNormalizer")
130130

131131
SUBCASE("Loads mean and variance from constructor correctly")
132132
{
133-
ObservationNormalizer normalizer({1, 2, 3}, {4, 5, 6});
133+
ObservationNormalizer normalizer(std::vector<float>({1, 2, 3}), std::vector<float>({4, 5, 6}));
134134

135-
auto mean = normalizer.get_mean();
136-
auto variance = normalizer.get_variance();
135+
auto mean = normalizer->get_mean();
136+
auto variance = normalizer->get_variance();
137137
DOCTEST_CHECK(mean[0] == doctest::Approx(1));
138138
DOCTEST_CHECK(mean[1] == doctest::Approx(2));
139139
DOCTEST_CHECK(mean[2] == doctest::Approx(3));
@@ -150,18 +150,18 @@ TEST_CASE("ObservationNormalizer")
150150
normalizers.push_back(ObservationNormalizer(3));
151151
for (int j = 0; j <= i; ++j)
152152
{
153-
normalizers[i].update(torch::rand({3}));
153+
normalizers[i]->update(torch::rand({3}));
154154
}
155155
}
156156

157157
ObservationNormalizer combined_normalizer(normalizers);
158158

159159
std::vector<std::vector<float>> means;
160160
std::transform(normalizers.begin(), normalizers.end(), std::back_inserter(means),
161-
[](const ObservationNormalizer &normalizer) { return normalizer.get_mean(); });
161+
[](const ObservationNormalizer &normalizer) { return normalizer->get_mean(); });
162162
std::vector<std::vector<float>> variances;
163163
std::transform(normalizers.begin(), normalizers.end(), std::back_inserter(variances),
164-
[](const ObservationNormalizer &normalizer) { return normalizer.get_variance(); });
164+
[](const ObservationNormalizer &normalizer) { return normalizer->get_variance(); });
165165

166166
std::vector<float> mean_means;
167167
for (int i = 0; i < 3; ++i)
@@ -174,15 +174,15 @@ TEST_CASE("ObservationNormalizer")
174174
mean_variances.push_back((variances[0][i] + variances[1][i] + variances[2][i]) / 3);
175175
}
176176

177-
auto actual_mean_means = combined_normalizer.get_mean();
178-
auto actual_mean_variances = combined_normalizer.get_variance();
177+
auto actual_mean_means = combined_normalizer->get_mean();
178+
auto actual_mean_variances = combined_normalizer->get_variance();
179179

180180
for (int i = 0; i < 3; ++i)
181181
{
182182
DOCTEST_CHECK(actual_mean_means[i] == doctest::Approx(mean_means[i]));
183183
DOCTEST_CHECK(actual_mean_variances[i] == doctest::Approx(actual_mean_variances[i]));
184184
}
185-
DOCTEST_CHECK(combined_normalizer.get_step_count() == 6);
185+
DOCTEST_CHECK(combined_normalizer->get_step_count() == 6);
186186
}
187187
}
188188
}

src/running_mean_std.cpp

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,19 @@
55

66
namespace SingularityTrainer
77
{
8-
RunningMeanStd::RunningMeanStd(int size)
9-
: count(1e-4),
10-
mean(torch::zeros({size})),
11-
variance(torch::ones({size})) {}
8+
RunningMeanStdImpl::RunningMeanStdImpl(int size)
9+
: count(register_buffer("count", torch::full({1}, 1e-4, torch::kFloat))),
10+
mean(register_buffer("mean", torch::zeros({size}))),
11+
variance(register_buffer("variance", torch::ones({size}))) {}
1212

13-
RunningMeanStd::RunningMeanStd(std::vector<float> means, std::vector<float> variances)
14-
: count(1e-4),
15-
mean(torch::from_blob(means.data(), {static_cast<long>(means.size())})
16-
.clone()),
17-
variance(torch::from_blob(variances.data(), {static_cast<long>(variances.size())})
18-
.clone()) {}
13+
RunningMeanStdImpl::RunningMeanStdImpl(std::vector<float> means, std::vector<float> variances)
14+
: count(register_buffer("count", torch::full({1}, 1e-4, torch::kFloat))),
15+
mean(register_buffer("mean", torch::from_blob(means.data(), {static_cast<long>(means.size())})
16+
.clone())),
17+
variance(register_buffer("variance", torch::from_blob(variances.data(), {static_cast<long>(variances.size())})
18+
.clone())) {}
1919

20-
void RunningMeanStd::update(torch::Tensor observation)
20+
void RunningMeanStdImpl::update(torch::Tensor observation)
2121
{
2222
observation = observation.reshape({-1, mean.size(0)});
2323
auto batch_mean = observation.mean(0);
@@ -27,12 +27,12 @@ void RunningMeanStd::update(torch::Tensor observation)
2727
update_from_moments(batch_mean, batch_var, batch_count);
2828
}
2929

30-
void RunningMeanStd::update_from_moments(torch::Tensor batch_mean,
31-
torch::Tensor batch_var,
32-
int batch_count)
30+
void RunningMeanStdImpl::update_from_moments(torch::Tensor batch_mean,
31+
torch::Tensor batch_var,
32+
int batch_count)
3333
{
3434
auto delta = batch_mean - mean;
35-
float total_count = count + batch_count;
35+
auto total_count = count + batch_count;
3636

3737
mean = mean + delta * batch_count / total_count;
3838
auto m_a = variance * count;
@@ -48,15 +48,15 @@ TEST_CASE("RunningMeanStd")
4848
{
4949
RunningMeanStd rms(5);
5050
auto observations = torch::rand({3, 5});
51-
rms.update(observations[0]);
52-
rms.update(observations[1]);
53-
rms.update(observations[2]);
51+
rms->update(observations[0]);
52+
rms->update(observations[1]);
53+
rms->update(observations[2]);
5454

5555
auto expected_mean = observations.mean(0);
5656
auto expected_variance = observations.var(0, false, false);
5757

58-
auto actual_mean = rms.get_mean();
59-
auto actual_variance = rms.get_variance();
58+
auto actual_mean = rms->get_mean();
59+
auto actual_variance = rms->get_variance();
6060

6161
for (int i = 0; i < 5; ++i)
6262
{
@@ -71,10 +71,10 @@ TEST_CASE("RunningMeanStd")
7171

7272
SUBCASE("Loads mean and variance from constructor correctly")
7373
{
74-
RunningMeanStd rms({1, 2, 3}, {4, 5, 6});
74+
RunningMeanStd rms(std::vector<float>{1, 2, 3}, std::vector<float>{4, 5, 6});
7575

76-
auto mean = rms.get_mean();
77-
auto variance = rms.get_variance();
76+
auto mean = rms->get_mean();
77+
auto variance = rms->get_variance();
7878
DOCTEST_CHECK(mean[0].item().toFloat() == doctest::Approx(1));
7979
DOCTEST_CHECK(mean[1].item().toFloat() == doctest::Approx(2));
8080
DOCTEST_CHECK(mean[2].item().toFloat() == doctest::Approx(3));

0 commit comments

Comments
 (0)