This repository was archived by the owner on Dec 28, 2023. It is now read-only.

Merged
example/gym_client.cpp: 32 changes (19 additions, 13 deletions)
@@ -15,21 +15,24 @@ using namespace cpprl;
 
 // Algorithm hyperparameters
 const std::string algorithm = "PPO";
-const int batch_size = 40;
+const int batch_size = 512;
 const float clip_param = 0.2;
 const float discount_factor = 0.99;
-const float entropy_coef = 1e-3;
-const float learning_rate = 1e-3;
+const float entropy_coef = 1e-5;
+const float gae = 0.95;
+const float learning_rate = 3e-4;
+const int log_interval = 1;
 const int num_epoch = 3;
-const int num_mini_batch = 20;
+const int num_mini_batch = 32;
 const int reward_average_window_size = 10;
 const bool use_gae = true;
 const float value_loss_coef = 0.5;
 
 // Environment hyperparameters
-const std::string env_name = "LunarLander-v2";
-const int num_envs = 8;
+const float env_gamma = discount_factor; // Set to -1 to disable
+const std::string env_name = "BipedalWalker-v2";
+const int num_envs = 8;
+const float render_reward_threshold = 300;
 
 // Model hyperparameters
 const int hidden_size = 64;
@@ -117,7 +120,7 @@ int main(int argc, char *argv[])
         base = std::make_shared<CnnBase>(env_info->observation_space_shape[0], recurrent, hidden_size);
     }
     base->to(device);
-    ActionSpace space{"Discrete", env_info->action_space_shape};
+    ActionSpace space{env_info->action_space_type, env_info->action_space_shape};
     Policy policy(space, base);
     policy->to(device);
     RolloutStorage storage(batch_size, num_envs, env_info->observation_space_shape, space, hidden_size, device);
@@ -152,11 +155,14 @@ int main(int argc, char *argv[])
                                      storage.get_masks()[step]);
         }
         auto actions_tensor = act_result[1].cpu();
-        int64_t *actions_array = actions_tensor.data<int64_t>();
-        std::vector<std::vector<int>> actions(num_envs);
+        float *actions_array = actions_tensor.data<float>();
+        std::vector<std::vector<float>> actions(num_envs);
         for (int i = 0; i < num_envs; ++i)
         {
-            actions[i] = {static_cast<int>(actions_array[i])};
+            for (int j = 0; j < env_info->action_space_shape[0]; j++)
+            {
+                actions[i].push_back(actions_array[i * env_info->action_space_shape[0] + j]);
+            }
         }
 
         auto step_param = std::make_shared<StepParam>();
@@ -219,12 +225,12 @@ int main(int argc, char *argv[])
                                         storage.get_masks()[-1])
                              .detach();
         }
-        storage.compute_returns(next_value, use_gae, discount_factor, 0.9);
+        storage.compute_returns(next_value, use_gae, discount_factor, gae);
 
         auto update_data = algo->update(storage);
         storage.after_update();
 
-        if (update % 10 == 0 && update > 0)
+        if (update % log_interval == 0 && update > 0)
         {
             auto total_steps = (update + 1) * batch_size * num_envs;
             auto run_time = std::chrono::high_resolution_clock::now() - start_time;
@@ -241,7 +247,7 @@ int main(int argc, char *argv[])
                 float average_reward = std::accumulate(reward_history.begin(), reward_history.end(), 0);
                 average_reward /= episode_count < reward_average_window_size ? episode_count : reward_average_window_size;
                 spdlog::info("Reward: {}", average_reward);
-                render = average_reward > 180;
+                render = average_reward >= render_reward_threshold;
             }
         }
     }
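The rewritten inner loop unpacks a flat num_envs x action_dim tensor env-major, where the old code copied a single integer per env. A self-contained sketch of that indexing, with made-up numbers (illustration only, not part of the diff):

// Illustration only: env-major layout of a flattened action tensor,
// matching actions_array[i * action_dim + j] in the hunk above.
#include <cstdio>
#include <vector>

int main()
{
    const int num_envs = 2;
    const int action_dim = 4; // BipedalWalker-v2 has four torque outputs
    const float actions_array[] = {0.1f, -0.5f, 0.9f, 0.3f,   // env 0
                                   -0.2f, 0.7f, 0.0f, -1.0f}; // env 1
    std::vector<std::vector<float>> actions(num_envs);
    for (int i = 0; i < num_envs; ++i)
    {
        for (int j = 0; j < action_dim; ++j)
        {
            actions[i].push_back(actions_array[i * action_dim + j]);
        }
    }
    std::printf("env 1, dim 0: %.1f\n", actions[1][0]); // prints -0.2
    return 0;
}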
example/requests.h: 2 changes (1 addition, 1 deletion)
@@ -38,7 +38,7 @@ struct ResetParam
 
 struct StepParam
 {
-    std::vector<std::vector<int>> actions;
+    std::vector<std::vector<float>> actions;
     bool render;
     MSGPACK_DEFINE_MAP(actions, render);
 };
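Since MSGPACK_DEFINE_MAP serializes the struct's fields directly, switching the element type to float is the whole wire-format change for continuous actions. A hedged round-trip sketch (not from the PR; standard msgpack-c usage):

// Sketch: packing a StepParam for two envs with 2-dimensional actions.
#include <msgpack.hpp>
#include <sstream>
#include <vector>

struct StepParam
{
    std::vector<std::vector<float>> actions;
    bool render;
    MSGPACK_DEFINE_MAP(actions, render);
};

int main()
{
    StepParam param;
    param.actions = {{0.1f, -0.5f}, {0.9f, 0.3f}};
    param.render = false;

    std::stringstream buffer;
    msgpack::pack(buffer, param); // encodes {"actions": [[...]], "render": false}
    return 0;
}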
gym_server/server.py: 5 changes (2 additions, 3 deletions)
@@ -110,9 +110,8 @@ def step(self,
         if isinstance(self.env.action_space, gym.spaces.Discrete):
             actions = actions.squeeze(-1)
         observation, reward, done, info = self.env.step(actions)
-        if isinstance(self.env.action_space, gym.spaces.Discrete):
-            reward = np.expand_dims(reward, -1)
-            done = np.expand_dims(done, -1)
+        reward = np.expand_dims(reward, -1)
+        done = np.expand_dims(done, -1)
         if render:
             self.env.render()
         return observation, reward, done, info
include/cpprl/distributions/distribution.h: 2 changes (0 additions, 2 deletions)
@@ -10,8 +10,6 @@ class Distribution
     virtual ~Distribution() = 0;
 
     virtual torch::Tensor entropy() = 0;
-    virtual torch::Tensor get_logits() = 0;
-    virtual torch::Tensor get_probs() = 0;
     virtual torch::Tensor log_prob(torch::Tensor value) = 0;
     virtual torch::Tensor sample(c10::ArrayRef<int64_t> sample_shape = {}) = 0;
 };
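The two removed getters were Categorical-specific; a Normal has neither logits nor probabilities, so trimming them is what lets Normal implement this interface at all. Code written against the base class needs only the shared surface, as in this sketch (entropy_bonus is a hypothetical helper, not part of the library):

// Sketch: a hypothetical helper that relies only on the trimmed interface,
// so it works with Categorical and Normal distributions alike.
#include <torch/torch.h>
#include "cpprl/distributions/distribution.h"

torch::Tensor entropy_bonus(cpprl::Distribution &dist, float entropy_coef)
{
    return dist.entropy().mean() * entropy_coef;
}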
include/cpprl/distributions/normal.h: 28 additions (new file)
@@ -0,0 +1,28 @@
+#pragma once
+
+#include <c10/util/ArrayRef.h>
+#include <torch/torch.h>
+
+#include "cpprl/distributions/distribution.h"
+
+namespace cpprl
+{
+class Normal : public Distribution
+{
+  private:
+    torch::Tensor loc, scale;
+    std::vector<int64_t> batch_shape, event_shape;
+
+    std::vector<int64_t> extended_shape(c10::ArrayRef<int64_t> sample_shape);
+
+  public:
+    Normal(const torch::Tensor loc, const torch::Tensor scale);
+
+    torch::Tensor entropy();
+    torch::Tensor log_prob(torch::Tensor value);
+    torch::Tensor sample(c10::ArrayRef<int64_t> sample_shape = {});
+
+    inline torch::Tensor get_loc() { return loc; }
+    inline torch::Tensor get_scale() { return scale; }
+};
+}
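A usage sketch for the new class, mirroring the unit tests in normal.cpp further down (the values are illustrative; the shapes are the point):

// Sketch: one 3-dimensional Gaussian per environment, batch of 2.
#include <torch/torch.h>
#include "cpprl/distributions/normal.h"

int main()
{
    auto locs = torch::zeros({2, 3});
    auto scales = torch::ones({2, 3}); // broadcast against locs in the ctor
    cpprl::Normal dist(locs, scales);

    auto action = dist.sample();            // shape {2, 3}
    auto many = dist.sample({20});          // shape {20, 2, 3}
    auto log_probs = dist.log_prob(action); // elementwise, shape {2, 3}
    auto entropies = dist.entropy();        // summed over dim -1, shape {2}
    return 0;
}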
include/cpprl/model/output_layers.h: 12 additions
@@ -30,4 +30,16 @@ class CategoricalOutput
 
     std::unique_ptr<Distribution> forward(torch::Tensor x);
 };
+
+class NormalOutput : public OutputLayer
+{
+  private:
+    nn::Linear linear_loc;
+    torch::Tensor scale_log;
+
+  public:
+    NormalOutput(unsigned int num_inputs, unsigned int num_outputs);
+
+    std::unique_ptr<Distribution> forward(torch::Tensor x);
+};
 }
include/cpprl/model/policy.h: 4 changes (2 additions, 2 deletions)
@@ -7,18 +7,18 @@
 
#include "cpprl/model/nn_base.h"
#include "cpprl/model/output_layers.h"
#include "cpprl/spaces.h"

using namespace torch;

namespace cpprl
{
class ActionSpace;

class PolicyImpl : public nn::Module
{
private:
std::shared_ptr<NNBase> base;
std::shared_ptr<OutputLayer> output_layer;
ActionSpace action_space;

std::vector<torch::Tensor> forward_gru(torch::Tensor x,
torch::Tensor hxs,
src/distributions/CMakeLists.txt: 2 additions
@@ -1,11 +1,13 @@
 target_sources(cpprl
     PRIVATE
         ${CMAKE_CURRENT_LIST_DIR}/categorical.cpp
+        ${CMAKE_CURRENT_LIST_DIR}/normal.cpp
 )
 
 if (CPPRL_BUILD_TESTS)
     target_sources(cpprl_tests
         PRIVATE
             ${CMAKE_CURRENT_LIST_DIR}/categorical.cpp
+            ${CMAKE_CURRENT_LIST_DIR}/normal.cpp
 )
 endif (CPPRL_BUILD_TESTS)
src/distributions/normal.cpp: 127 additions (new file)
@@ -0,0 +1,127 @@
+#define _USE_MATH_DEFINES
+#include <math.h>
+#include <cmath>
+#include <limits>
+
+#include <c10/util/ArrayRef.h>
+#include <torch/torch.h>
+
+#include "cpprl/distributions/normal.h"
+#include "third_party/doctest.h"
+
+namespace cpprl
+{
+Normal::Normal(const torch::Tensor loc,
+               const torch::Tensor scale)
+{
+    auto broadcasted_tensors = torch::broadcast_tensors({loc, scale});
+    this->loc = broadcasted_tensors[0];
+    this->scale = broadcasted_tensors[1];
+    batch_shape = this->loc.sizes().vec();
+    event_shape = {};
+}
+
+torch::Tensor Normal::entropy()
+{
+    return (0.5 + 0.5 * std::log(2 * M_PI) + torch::log(scale)).sum(-1);
+}
+
+std::vector<int64_t> Normal::extended_shape(c10::ArrayRef<int64_t> sample_shape)
+{
+    std::vector<int64_t> output_shape;
+    output_shape.insert(output_shape.end(),
+                        sample_shape.begin(),
+                        sample_shape.end());
+    output_shape.insert(output_shape.end(),
+                        batch_shape.begin(),
+                        batch_shape.end());
+    output_shape.insert(output_shape.end(),
+                        event_shape.begin(),
+                        event_shape.end());
+    return output_shape;
+}
+
+torch::Tensor Normal::log_prob(torch::Tensor value)
+{
+    auto variance = scale.pow(2);
+    auto log_scale = scale.log();
+    return (-(value - loc).pow(2) /
+                (2 * variance) -
+            log_scale -
+            std::log(std::sqrt(2 * M_PI)));
+}
+
+torch::Tensor Normal::sample(c10::ArrayRef<int64_t> sample_shape)
+{
+    auto shape = extended_shape(sample_shape);
+    auto no_grad_guard = torch::NoGradGuard();
+    return torch::normal(loc.expand(shape), scale.expand(shape));
+}
+
+TEST_CASE("Normal")
+{
+    float locs_array[] = {0, 1, 2, 3, 4, 5};
+    float scales_array[] = {5, 4, 3, 2, 1, 0};
+    auto locs = torch::from_blob(locs_array, {2, 3});
+    auto scales = torch::from_blob(scales_array, {2, 3});
+    auto dist = Normal(locs, scales);
+
+    SUBCASE("Sampled tensors have correct shape")
+    {
+        CHECK(dist.sample().sizes().vec() == std::vector<int64_t>{2, 3});
+        CHECK(dist.sample({20}).sizes().vec() == std::vector<int64_t>{20, 2, 3});
+        CHECK(dist.sample({2, 20}).sizes().vec() == std::vector<int64_t>{2, 20, 2, 3});
+        CHECK(dist.sample({1, 2, 3, 4, 5}).sizes().vec() == std::vector<int64_t>{1, 2, 3, 4, 5, 2, 3});
+    }
+
+    SUBCASE("entropy()")
+    {
+        auto entropies = dist.entropy();
+
+        SUBCASE("Returns correct values")
+        {
+            INFO("Entropies: \n"
+                 << entropies);
+
+            CHECK(entropies[0].item().toDouble() ==
+                  doctest::Approx(8.3512).epsilon(1e-3));
+            CHECK(entropies[1].item().toDouble() ==
+                  -std::numeric_limits<float>::infinity());
+        }
+
+        SUBCASE("Output tensor is the correct size")
+        {
+            CHECK(entropies.sizes().vec() == std::vector<int64_t>{2});
+        }
+    }
+
+    SUBCASE("log_prob()")
+    {
+        float actions[2][3] = {{0, 1, 2},
+                               {0, 1, 2}};
+        auto actions_tensor = torch::from_blob(actions, {2, 3});
+        auto log_probs = dist.log_prob(actions_tensor);
+
+        INFO(log_probs << "\n");
+        SUBCASE("Returns correct values")
+        {
+            CHECK(log_probs[0][0].item().toDouble() ==
+                  doctest::Approx(-2.5284).epsilon(1e-3));
+            CHECK(log_probs[0][1].item().toDouble() ==
+                  doctest::Approx(-2.3052).epsilon(1e-3));
+            CHECK(log_probs[0][2].item().toDouble() ==
+                  doctest::Approx(-2.0176).epsilon(1e-3));
+            CHECK(log_probs[1][0].item().toDouble() ==
+                  doctest::Approx(-2.7371).epsilon(1e-3));
+            CHECK(log_probs[1][1].item().toDouble() ==
+                  doctest::Approx(-5.4189).epsilon(1e-3));
+            CHECK(std::isnan(log_probs[1][2].item().toDouble()));
+        }
+
+        SUBCASE("Output tensor is correct size")
+        {
+            CHECK(log_probs.sizes().vec() == std::vector<int64_t>{2, 3});
+        }
+    }
+}
+}
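For reference, log_prob and entropy above implement the standard Gaussian formulas:

\log p(x) = -\frac{(x - \mu)^2}{2\sigma^2} - \log \sigma - \log \sqrt{2\pi}

H = \sum_i \left( \frac{1}{2} + \frac{1}{2} \log(2\pi) + \log \sigma_i \right)

The last test element pins down the degenerate case: with scale 0, log(scale) is negative infinity, so the entropy check expects -inf, and the log-probability of a point away from the mean evaluates to -inf + inf = NaN, which the isnan check asserts.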
src/model/nn_base.cpp: 2 changes (1 addition, 1 deletion)
@@ -25,7 +25,7 @@ NNBase::NNBase(bool recurrent,
 
 // Do not use.
 //
-// Instantiate a subclass and use their's instead
+// Instantiate a subclass and use theirs instead
 std::vector<torch::Tensor> NNBase::forward(torch::Tensor /*inputs*/,
                                            torch::Tensor /*hxs*/,
                                            torch::Tensor /*masks*/)
src/model/output_layers.cpp: 36 changes (35 additions, 1 deletion)
@@ -6,6 +6,7 @@
#include "cpprl/model/model_utils.h"
#include "cpprl/distributions/distribution.h"
#include "cpprl/distributions/categorical.h"
#include "cpprl/distributions/normal.h"
#include "third_party/doctest.h"

using namespace torch;
@@ -22,11 +23,26 @@ CategoricalOutput::CategoricalOutput(unsigned int num_inputs,
 
 std::unique_ptr<Distribution> CategoricalOutput::forward(torch::Tensor x)
 {
-    auto y = x;
     x = linear(x);
     return std::make_unique<Categorical>(nullptr, &x);
 }
 
+NormalOutput::NormalOutput(unsigned int num_inputs,
+                           unsigned int num_outputs)
+    : linear_loc(num_inputs, num_outputs)
+{
+    register_module("linear_loc", linear_loc);
+    scale_log = register_parameter("scale_log", torch::zeros({num_outputs}));
+    init_weights(linear_loc->named_parameters(), 1, 0);
+}
+
+std::unique_ptr<Distribution> NormalOutput::forward(torch::Tensor x)
+{
+    auto loc = linear_loc(x);
+    auto scale = scale_log.exp();
+    return std::make_unique<Normal>(loc, scale);
+}
+
 TEST_CASE("CategoricalOutput")
 {
     auto output_layer = CategoricalOutput(3, 5);
@@ -44,4 +60,22 @@ TEST_CASE("CategoricalOutput")
         CHECK(output.sizes().vec() == std::vector<int64_t>{2});
     }
 }
+
+TEST_CASE("NormalOutput")
+{
+    auto output_layer = NormalOutput(3, 5);
+
+    SUBCASE("Output distribution has correct output shape")
+    {
+        float input_array[2][3] = {{0, 1, 2}, {3, 4, 5}};
+        auto input_tensor = torch::from_blob(input_array,
+                                             {2, 3},
+                                             TensorOptions(torch::kFloat));
+        auto dist = output_layer.forward(input_tensor);
+
+        auto output = dist->sample();
+
+        CHECK(output.sizes().vec() == std::vector<int64_t>{2, 5});
+    }
+}
 }
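With both heads in place, the policy can pick an output layer from the action-space type string the server reports (gym's names are "Discrete" and "Box"). The real dispatch lives in policy.cpp, which this diff does not show; the following is only a sketch, and the ActionSpace field names in it are assumptions:

// Sketch only: plausible head selection by action-space type.
// ActionSpace field names (type, shape) are assumed, not from this diff.
std::shared_ptr<OutputLayer> make_output_layer(const ActionSpace &space,
                                               unsigned int hidden_size)
{
    if (space.type == "Discrete")
    {
        // One logit per discrete action -> Categorical distribution.
        return std::make_shared<CategoricalOutput>(hidden_size, space.shape[0]);
    }
    // Continuous ("Box"): one mean per action dimension. NormalOutput keeps a
    // state-independent scale_log parameter and exponentiates it, so the
    // standard deviation stays positive however the optimizer moves it.
    return std::make_shared<NormalOutput>(hidden_size, space.shape[0]);
}

Parameterizing the standard deviation in log space is a common trick: gradients can move scale_log freely while exp() guarantees a valid, positive scale.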