1+ #include < torch/torch.h>
2+
3+ #include " cpprl/observation_normalizer.h"
4+ #include " cpprl/running_mean_std.h"
5+ #include " third_party/doctest.h"
6+
7+ namespace SingularityTrainer
8+ {
9+ ObservationNormalizer::ObservationNormalizer (int size, float clip)
10+ : clip(clip),
11+ rms (size) {}
12+
13+ ObservationNormalizer::ObservationNormalizer (const std::vector<float > &means,
14+ const std::vector<float > &variances,
15+ float clip)
16+ : clip(clip),
17+ rms(means, variances){}
18+
19+ ObservationNormalizer::ObservationNormalizer (const std::vector<ObservationNormalizer> &others)
20+ : clip(0 ),
21+ rms(1 )
22+ {
23+ // Calculate mean clip
24+ for (const auto &other : others)
25+ {
26+ clip += other.get_clip_value ();
27+ }
28+ clip /= others.size ();
29+
30+ // Calculate mean mean
31+ std::vector<float > mean_means (others[0 ].get_mean ().size (), 0 );
32+ for (const auto &other : others)
33+ {
34+ auto other_mean = other.get_mean ();
35+ for (unsigned int i = 0 ; i < mean_means.size (); ++i)
36+ {
37+ mean_means[i] += other_mean[i];
38+ }
39+ }
40+ for (auto &mean : mean_means)
41+ {
42+ mean /= others.size ();
43+ }
44+
45+ // Calculate mean variances
46+ std::vector<float > mean_variances (others[0 ].get_variance ().size (), 0 );
47+ for (const auto &other : others)
48+ {
49+ auto other_variances = other.get_variance ();
50+ for (unsigned int i = 0 ; i < mean_variances.size (); ++i)
51+ {
52+ mean_variances[i] += other_variances[i];
53+ }
54+ }
55+ for (auto &variance : mean_variances)
56+ {
57+ variance /= others.size ();
58+ }
59+
60+ rms = RunningMeanStd (mean_means, mean_variances);
61+
62+ int total_count = std::accumulate (others.begin (), others.end (), 0 ,
63+ [](int accumulator, const ObservationNormalizer &other) {
64+ return accumulator + other.get_step_count ();
65+ });
66+ rms.set_count (total_count);
67+ }
68+
69+ torch::Tensor ObservationNormalizer::process_observation (torch::Tensor observation)
70+ {
71+ auto normalized_obs = (observation - rms.get_mean ()) /
72+ torch::sqrt (rms.get_variance () + 1e-8 );
73+ return torch::clamp (normalized_obs, -clip, clip);
74+ }
75+
76+ std::vector<float > ObservationNormalizer::get_mean () const
77+ {
78+ auto mean = rms.get_mean ();
79+ return std::vector<float >(mean.data <float >(), mean.data <float >() + mean.numel ());
80+ }
81+
82+ std::vector<float > ObservationNormalizer::get_variance () const
83+ {
84+ auto variance = rms.get_variance ();
85+ return std::vector<float >(variance.data <float >(), variance.data <float >() + variance.numel ());
86+ }
87+
88+ void ObservationNormalizer::update (torch::Tensor observations)
89+ {
90+ rms.update (observations);
91+ }
92+
93+ TEST_CASE (" ObservationNormalizer" )
94+ {
95+ SUBCASE (" Clips values correctly" )
96+ {
97+ ObservationNormalizer normalizer (7 , 1 );
98+ float observation_array[] = {-1000 , -100 , -10 , 0 , 10 , 100 , 1000 };
99+ auto observation = torch::from_blob (observation_array, {7 });
100+ auto processed_observation = normalizer.process_observation (observation);
101+
102+ auto has_too_large_values = (processed_observation > 1 ).any ().item ().toBool ();
103+ auto has_too_small_values = (processed_observation < -1 ).any ().item ().toBool ();
104+ DOCTEST_CHECK (!has_too_large_values);
105+ DOCTEST_CHECK (!has_too_small_values);
106+ }
107+
108+ SUBCASE (" Normalizes values correctly" )
109+ {
110+ ObservationNormalizer normalizer (5 );
111+
112+ float obs_1_array[] = {-10 ., 0 ., 5 ., 3.2 , 0 .};
113+ float obs_2_array[] = {-5 ., 2 ., 4 ., 3.7 , -3 .};
114+ float obs_3_array[] = {1 , 2 , 3 , 4 , 5 };
115+ auto obs_1 = torch::from_blob (obs_1_array, {5 });
116+ auto obs_2 = torch::from_blob (obs_2_array, {5 });
117+ auto obs_3 = torch::from_blob (obs_3_array, {5 });
118+
119+ normalizer.update (obs_1);
120+ normalizer.update (obs_2);
121+ normalizer.update (obs_3);
122+ auto processed_observation = normalizer.process_observation (obs_3);
123+
124+ DOCTEST_CHECK (processed_observation[0 ].item ().toFloat () == doctest::Approx (1.26008659 ));
125+ DOCTEST_CHECK (processed_observation[1 ].item ().toFloat () == doctest::Approx (0.70712887 ));
126+ DOCTEST_CHECK (processed_observation[2 ].item ().toFloat () == doctest::Approx (-1.2240818 ));
127+ DOCTEST_CHECK (processed_observation[3 ].item ().toFloat () == doctest::Approx (1.10914509 ));
128+ DOCTEST_CHECK (processed_observation[4 ].item ().toFloat () == doctest::Approx (1.31322402 ));
129+ }
130+
131+ SUBCASE (" Loads mean and variance from constructor correctly" )
132+ {
133+ ObservationNormalizer normalizer ({1 , 2 , 3 }, {4 , 5 , 6 });
134+
135+ auto mean = normalizer.get_mean ();
136+ auto variance = normalizer.get_variance ();
137+ DOCTEST_CHECK (mean[0 ] == doctest::Approx (1 ));
138+ DOCTEST_CHECK (mean[1 ] == doctest::Approx (2 ));
139+ DOCTEST_CHECK (mean[2 ] == doctest::Approx (3 ));
140+ DOCTEST_CHECK (variance[0 ] == doctest::Approx (4 ));
141+ DOCTEST_CHECK (variance[1 ] == doctest::Approx (5 ));
142+ DOCTEST_CHECK (variance[2 ] == doctest::Approx (6 ));
143+ }
144+
145+ SUBCASE (" Is constructed from other normalizers correctly" )
146+ {
147+ std::vector<ObservationNormalizer> normalizers;
148+ for (int i = 0 ; i < 3 ; ++i)
149+ {
150+ normalizers.push_back (ObservationNormalizer (3 ));
151+ for (int j = 0 ; j <= i; ++j)
152+ {
153+ normalizers[i].update (torch::rand ({3 }));
154+ }
155+ }
156+
157+ ObservationNormalizer combined_normalizer (normalizers);
158+
159+ std::vector<std::vector<float >> means;
160+ std::transform (normalizers.begin (), normalizers.end (), std::back_inserter (means),
161+ [](const ObservationNormalizer &normalizer) { return normalizer.get_mean (); });
162+ std::vector<std::vector<float >> variances;
163+ std::transform (normalizers.begin (), normalizers.end (), std::back_inserter (variances),
164+ [](const ObservationNormalizer &normalizer) { return normalizer.get_variance (); });
165+
166+ std::vector<float > mean_means;
167+ for (int i = 0 ; i < 3 ; ++i)
168+ {
169+ mean_means.push_back ((means[0 ][i] + means[1 ][i] + means[2 ][i]) / 3 );
170+ }
171+ std::vector<float > mean_variances;
172+ for (int i = 0 ; i < 3 ; ++i)
173+ {
174+ mean_variances.push_back ((variances[0 ][i] + variances[1 ][i] + variances[2 ][i]) / 3 );
175+ }
176+
177+ auto actual_mean_means = combined_normalizer.get_mean ();
178+ auto actual_mean_variances = combined_normalizer.get_variance ();
179+
180+ for (int i = 0 ; i < 3 ; ++i)
181+ {
182+ DOCTEST_CHECK (actual_mean_means[i] == doctest::Approx (mean_means[i]));
183+ DOCTEST_CHECK (actual_mean_variances[i] == doctest::Approx (actual_mean_variances[i]));
184+ }
185+ DOCTEST_CHECK (combined_normalizer.get_step_count () == 6 );
186+ }
187+ }
188+ }
0 commit comments