66
77namespace SingularityTrainer
88{
9- ObservationNormalizer::ObservationNormalizer (int size, float clip)
10- : clip(clip),
11- rms (size) {}
12-
13- ObservationNormalizer::ObservationNormalizer (const std::vector<float > &means,
14- const std::vector<float > &variances,
15- float clip)
16- : clip(clip),
17- rms(means, variances){}
18-
19- ObservationNormalizer::ObservationNormalizer (const std::vector<ObservationNormalizer> &others)
20- : clip(0 ),
21- rms(1 )
9+ ObservationNormalizerImpl::ObservationNormalizerImpl (int size, float clip)
10+ : clip(register_buffer( " clip" , torch::full({ 1 }, clip, torch:: kFloat )) ),
11+ rms (register_module( " rms " , RunningMeanStd( size)) ) {}
12+
13+ ObservationNormalizerImpl::ObservationNormalizerImpl (const std::vector<float > &means,
14+ const std::vector<float > &variances,
15+ float clip)
16+ : clip(register_buffer( " clip" , torch::full({ 1 }, clip, torch:: kFloat )) ),
17+ rms (register_module( " rms " , RunningMeanStd( means, variances))) {}
18+
19+ ObservationNormalizerImpl::ObservationNormalizerImpl (const std::vector<ObservationNormalizer> &others)
20+ : clip(register_buffer( " clip " , torch::zeros({ 1 }, torch:: kFloat )) ),
21+ rms (register_module( " rms " , RunningMeanStd( 1 )) )
2222{
2323 // Calculate mean clip
2424 for (const auto &other : others)
2525 {
26- clip += other. get_clip_value ();
26+ clip += other-> get_clip_value ();
2727 }
28- clip /= others.size ();
28+ clip[ 0 ] = clip[ 0 ] / static_cast < float >( others.size () );
2929
3030 // Calculate mean mean
31- std::vector<float > mean_means (others[0 ]. get_mean ().size (), 0 );
31+ std::vector<float > mean_means (others[0 ]-> get_mean ().size (), 0 );
3232 for (const auto &other : others)
3333 {
34- auto other_mean = other. get_mean ();
34+ auto other_mean = other-> get_mean ();
3535 for (unsigned int i = 0 ; i < mean_means.size (); ++i)
3636 {
3737 mean_means[i] += other_mean[i];
@@ -43,10 +43,10 @@ ObservationNormalizer::ObservationNormalizer(const std::vector<ObservationNormal
4343 }
4444
4545 // Calculate mean variances
46- std::vector<float > mean_variances (others[0 ]. get_variance ().size (), 0 );
46+ std::vector<float > mean_variances (others[0 ]-> get_variance ().size (), 0 );
4747 for (const auto &other : others)
4848 {
49- auto other_variances = other. get_variance ();
49+ auto other_variances = other-> get_variance ();
5050 for (unsigned int i = 0 ; i < mean_variances.size (); ++i)
5151 {
5252 mean_variances[i] += other_variances[i];
@@ -61,33 +61,33 @@ ObservationNormalizer::ObservationNormalizer(const std::vector<ObservationNormal
6161
6262 int total_count = std::accumulate (others.begin (), others.end (), 0 ,
6363 [](int accumulator, const ObservationNormalizer &other) {
64- return accumulator + other. get_step_count ();
64+ return accumulator + other-> get_step_count ();
6565 });
66- rms. set_count (total_count);
66+ rms-> set_count (total_count);
6767}
6868
69- torch::Tensor ObservationNormalizer ::process_observation (torch::Tensor observation)
69+ torch::Tensor ObservationNormalizerImpl ::process_observation (torch::Tensor observation)
7070{
71- auto normalized_obs = (observation - rms. get_mean ()) /
72- torch::sqrt (rms. get_variance () + 1e-8 );
73- return torch::clamp (normalized_obs, -clip, clip);
71+ auto normalized_obs = (observation - rms-> get_mean ()) /
72+ torch::sqrt (rms-> get_variance () + 1e-8 );
73+ return torch::clamp (normalized_obs, -clip. item () , clip. item () );
7474}
7575
76- std::vector<float > ObservationNormalizer ::get_mean () const
76+ std::vector<float > ObservationNormalizerImpl ::get_mean () const
7777{
78- auto mean = rms. get_mean ();
78+ auto mean = rms-> get_mean ();
7979 return std::vector<float >(mean.data <float >(), mean.data <float >() + mean.numel ());
8080}
8181
82- std::vector<float > ObservationNormalizer ::get_variance () const
82+ std::vector<float > ObservationNormalizerImpl ::get_variance () const
8383{
84- auto variance = rms. get_variance ();
84+ auto variance = rms-> get_variance ();
8585 return std::vector<float >(variance.data <float >(), variance.data <float >() + variance.numel ());
8686}
8787
88- void ObservationNormalizer ::update (torch::Tensor observations)
88+ void ObservationNormalizerImpl ::update (torch::Tensor observations)
8989{
90- rms. update (observations);
90+ rms-> update (observations);
9191}
9292
9393TEST_CASE (" ObservationNormalizer" )
@@ -97,7 +97,7 @@ TEST_CASE("ObservationNormalizer")
9797 ObservationNormalizer normalizer (7 , 1 );
9898 float observation_array[] = {-1000 , -100 , -10 , 0 , 10 , 100 , 1000 };
9999 auto observation = torch::from_blob (observation_array, {7 });
100- auto processed_observation = normalizer. process_observation (observation);
100+ auto processed_observation = normalizer-> process_observation (observation);
101101
102102 auto has_too_large_values = (processed_observation > 1 ).any ().item ().toBool ();
103103 auto has_too_small_values = (processed_observation < -1 ).any ().item ().toBool ();
@@ -116,10 +116,10 @@ TEST_CASE("ObservationNormalizer")
116116 auto obs_2 = torch::from_blob (obs_2_array, {5 });
117117 auto obs_3 = torch::from_blob (obs_3_array, {5 });
118118
119- normalizer. update (obs_1);
120- normalizer. update (obs_2);
121- normalizer. update (obs_3);
122- auto processed_observation = normalizer. process_observation (obs_3);
119+ normalizer-> update (obs_1);
120+ normalizer-> update (obs_2);
121+ normalizer-> update (obs_3);
122+ auto processed_observation = normalizer-> process_observation (obs_3);
123123
124124 DOCTEST_CHECK (processed_observation[0 ].item ().toFloat () == doctest::Approx (1.26008659 ));
125125 DOCTEST_CHECK (processed_observation[1 ].item ().toFloat () == doctest::Approx (0.70712887 ));
@@ -130,10 +130,10 @@ TEST_CASE("ObservationNormalizer")
130130
131131 SUBCASE (" Loads mean and variance from constructor correctly" )
132132 {
133- ObservationNormalizer normalizer ({1 , 2 , 3 }, {4 , 5 , 6 });
133+ ObservationNormalizer normalizer (std::vector< float >( {1 , 2 , 3 }), std::vector< float >( {4 , 5 , 6 }) );
134134
135- auto mean = normalizer. get_mean ();
136- auto variance = normalizer. get_variance ();
135+ auto mean = normalizer-> get_mean ();
136+ auto variance = normalizer-> get_variance ();
137137 DOCTEST_CHECK (mean[0 ] == doctest::Approx (1 ));
138138 DOCTEST_CHECK (mean[1 ] == doctest::Approx (2 ));
139139 DOCTEST_CHECK (mean[2 ] == doctest::Approx (3 ));
@@ -150,18 +150,18 @@ TEST_CASE("ObservationNormalizer")
150150 normalizers.push_back (ObservationNormalizer (3 ));
151151 for (int j = 0 ; j <= i; ++j)
152152 {
153- normalizers[i]. update (torch::rand ({3 }));
153+ normalizers[i]-> update (torch::rand ({3 }));
154154 }
155155 }
156156
157157 ObservationNormalizer combined_normalizer (normalizers);
158158
159159 std::vector<std::vector<float >> means;
160160 std::transform (normalizers.begin (), normalizers.end (), std::back_inserter (means),
161- [](const ObservationNormalizer &normalizer) { return normalizer. get_mean (); });
161+ [](const ObservationNormalizer &normalizer) { return normalizer-> get_mean (); });
162162 std::vector<std::vector<float >> variances;
163163 std::transform (normalizers.begin (), normalizers.end (), std::back_inserter (variances),
164- [](const ObservationNormalizer &normalizer) { return normalizer. get_variance (); });
164+ [](const ObservationNormalizer &normalizer) { return normalizer-> get_variance (); });
165165
166166 std::vector<float > mean_means;
167167 for (int i = 0 ; i < 3 ; ++i)
@@ -174,15 +174,15 @@ TEST_CASE("ObservationNormalizer")
174174 mean_variances.push_back ((variances[0 ][i] + variances[1 ][i] + variances[2 ][i]) / 3 );
175175 }
176176
177- auto actual_mean_means = combined_normalizer. get_mean ();
178- auto actual_mean_variances = combined_normalizer. get_variance ();
177+ auto actual_mean_means = combined_normalizer-> get_mean ();
178+ auto actual_mean_variances = combined_normalizer-> get_variance ();
179179
180180 for (int i = 0 ; i < 3 ; ++i)
181181 {
182182 DOCTEST_CHECK (actual_mean_means[i] == doctest::Approx (mean_means[i]));
183183 DOCTEST_CHECK (actual_mean_variances[i] == doctest::Approx (actual_mean_variances[i]));
184184 }
185- DOCTEST_CHECK (combined_normalizer. get_step_count () == 6 );
185+ DOCTEST_CHECK (combined_normalizer-> get_step_count () == 6 );
186186 }
187187}
188188}
0 commit comments