Skip to content
This repository was archived by the owner on Dec 28, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[submodule "lib/spdlog"]
path = lib/spdlog
path = example/lib/spdlog
url = git@github.com:gabime/spdlog.git
[submodule "lib/msgpack-c"]
path = example/lib/msgpack-c
Expand Down
18 changes: 11 additions & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ list(APPEND CPPCHECK_ARGS
--suppressions-list=${CMAKE_CURRENT_LIST_DIR}/CppCheckSuppressions.txt
-I ${CMAKE_CURRENT_LIST_DIR}/src
-I ${CMAKE_CURRENT_LIST_DIR}/include
-I ${CMAKE_CURRENT_LIST_DIR}/lib/spdlog/include
-I ${CMAKE_CURRENT_LIST_DIR}/example
${CMAKE_CURRENT_LIST_DIR}/src
${CMAKE_CURRENT_LIST_DIR}/example
Expand All @@ -42,8 +41,6 @@ find_package(Torch REQUIRED)
if (TORCH_CXX_FLAGS)
set(CMAKE_CXX_FLAGS ${TORCH_CXX_FLAGS})
endif()
## Spdlog
add_subdirectory(lib/spdlog)

# Define targets
add_library(cpprl STATIC "")
Expand All @@ -64,18 +61,25 @@ endif(MSVC)
set(CPPRL_INCLUDE_DIRS
include
src
lib/spdlog/include
${TORCH_INCLUDE_DIRS}
)
target_include_directories(cpprl PRIVATE ${CPPRL_INCLUDE_DIRS})
target_include_directories(cpprl_tests PRIVATE ${CPPRL_INCLUDE_DIRS})
if (CPPRL_BUILD_TESTS)
target_include_directories(cpprl_tests PRIVATE ${CPPRL_INCLUDE_DIRS})
endif(CPPRL_BUILD_TESTS)

# Linking
target_link_libraries(cpprl torch ${TORCH_LIBRARIES})
target_link_libraries(cpprl_tests torch ${TORCH_LIBRARIES})
target_link_libraries(cpprl torch ${TORCH_LIBRARIES})
if (CPPRL_BUILD_TESTS)
target_link_libraries(cpprl_tests torch ${TORCH_LIBRARIES})
endif(CPPRL_BUILD_TESTS)

# Example
add_subdirectory(example)
option(CPPRL_BUILD_EXAMPLE "Whether or not to build the CppRl Gym example" ON)
if (CPPRL_BUILD_EXAMPLE)
add_subdirectory(example)
endif(CPPRL_BUILD_EXAMPLE)

# Recurse into source tree
add_subdirectory(src)
4 changes: 4 additions & 0 deletions example/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ add_executable(gym_client gym_client.cpp communicator.cpp)
set(LIB_DIR ${CMAKE_CURRENT_LIST_DIR}/lib)
set(CPPZMQ_DIR ${LIB_DIR}/cppzmq)
set(MSGPACK_DIR ${LIB_DIR}/msgpack-c)
set(SPDLOG_DIR ${LIB_DIR}/spdlog)
set(ZMQ_DIR ${LIB_DIR}/libzmq)

# Spdlog
add_subdirectory(${SPDLOG_DIR})
# ZMQ
option(ZMQ_BUILD_TESTS "" OFF)
add_subdirectory(${ZMQ_DIR})
Expand All @@ -16,6 +19,7 @@ target_include_directories(gym_client
../lib/spdlog/include
${CPPZMQ_DIR}
${MSGPACK_DIR}/include
${SPDLOG_DIR}/include
${ZMQ_DIR}/include
)

Expand Down
1 change: 0 additions & 1 deletion src/algorithms/ppo.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include <chrono>
#include <memory>

#include <spdlog/spdlog.h>
#include <torch/torch.h>

#include "cpprl/algorithms/ppo.h"
Expand Down
8 changes: 3 additions & 5 deletions src/distributions/bernoulli.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include <ATen/core/Reduction.h>
#include <c10/util/ArrayRef.h>
#include <spdlog/spdlog.h>
#include <torch/torch.h>

#include "cpprl/distributions/bernoulli.h"
Expand All @@ -13,15 +12,14 @@ Bernoulli::Bernoulli(const torch::Tensor *probs,
{
if ((probs == nullptr) == (logits == nullptr))
{
spdlog::error("Either probs or logits is required, but not both");
throw std::exception();
throw std::runtime_error("Either probs or logits is required, but not both");
}

if (probs != nullptr)
{
if (probs->dim() < 1)
{
throw std::exception();
throw std::runtime_error("Probabilities tensor must have at least one dimension");
}
this->probs = *probs;
// 1.21e-7 is used as the epsilon to match PyTorch's Python results as closely
Expand All @@ -33,7 +31,7 @@ Bernoulli::Bernoulli(const torch::Tensor *probs,
{
if (logits->dim() < 1)
{
throw std::exception();
throw std::runtime_error("Logits tensor must have at least one dimension");
}
this->logits = *logits;
this->probs = torch::sigmoid(*logits);
Expand Down
8 changes: 3 additions & 5 deletions src/distributions/categorical.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include <c10/util/ArrayRef.h>
#include <spdlog/spdlog.h>
#include <torch/torch.h>

#include "cpprl/distributions/categorical.h"
Expand All @@ -12,15 +11,14 @@ Categorical::Categorical(const torch::Tensor *probs,
{
if ((probs == nullptr) == (logits == nullptr))
{
spdlog::error("Either probs or logits is required, but not both");
throw std::exception();
throw std::runtime_error("Either probs or logits is required, but not both");
}

if (probs != nullptr)
{
if (probs->dim() < 1)
{
throw std::exception();
throw std::runtime_error("Probabilities tensor must have at least one dimension");
}
this->probs = *probs / probs->sum(-1, true);
// 1.21e-7 is used as the epsilon to match PyTorch's Python results as closely
Expand All @@ -32,7 +30,7 @@ Categorical::Categorical(const torch::Tensor *probs,
{
if (logits->dim() < 1)
{
throw std::exception();
throw std::runtime_error("Logits tensor must have at least one dimension");
}
this->logits = *logits - logits->logsumexp(-1, true);
this->probs = torch::softmax(this->logits, -1);
Expand Down
6 changes: 1 addition & 5 deletions src/generators/feed_forward_generator.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include <algorithm>
#include <vector>

#include <spdlog/spdlog.h>
#include <torch/torch.h>

#include "cpprl/generators/feed_forward_generator.h"
Expand Down Expand Up @@ -44,10 +43,7 @@ MiniBatch FeedForwardGenerator::next()
{
if (index >= indices.size(0))
{
spdlog::error("No minibatches left in generator. Index {}, minibatch "
"count: {}.",
index, indices.size(0));
throw std::exception();
throw std::runtime_error("No minibatches left in generator.");
}

MiniBatch mini_batch;
Expand Down
6 changes: 1 addition & 5 deletions src/generators/recurrent_generator.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include <algorithm>
#include <vector>

#include <spdlog/spdlog.h>
#include <torch/torch.h>

#include "cpprl/generators/recurrent_generator.h"
Expand Down Expand Up @@ -49,10 +48,7 @@ MiniBatch RecurrentGenerator::next()
{
if (index >= indices.size(0))
{
spdlog::error("No minibatches left in generator. Index {}, minibatch "
"count: {}.",
index, indices.size(0));
throw std::exception();
throw std::runtime_error("No minibatches left in generator.");
}

MiniBatch mini_batch;
Expand Down
4 changes: 1 addition & 3 deletions src/model/policy.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#include <spdlog/spdlog.h>
#include <torch/torch.h>

#include "cpprl/model/policy.h"
Expand Down Expand Up @@ -35,8 +34,7 @@ PolicyImpl::PolicyImpl(ActionSpace action_space, std::shared_ptr<NNBase> base)
}
else
{
spdlog::error("Action space {} not supported", action_space.type);
throw std::exception();
throw std::runtime_error("Action space " + action_space.type + " not supported");
}
register_module("output", output_layer);
}
Expand Down
26 changes: 13 additions & 13 deletions src/storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#include <vector>

#include <c10/util/ArrayRef.h>
#include <spdlog/spdlog.h>
#include <torch/torch.h>

#include "cpprl/generators/feed_forward_generator.h"
Expand Down Expand Up @@ -97,13 +96,14 @@ std::unique_ptr<Generator> RolloutStorage::feed_forward_generator(
auto batch_size = num_processes * num_steps;
if (batch_size < num_mini_batch)
{
spdlog::error("PPO needs the number of processes ({}) * the number of "
"steps ({}) = {} to be greater than or equal to the number "
"of minibatches ({})",
num_processes,
num_steps,
num_mini_batch);
throw std::exception();
throw std::runtime_error("PPO needs the number of processes (" +
std::to_string(num_processes) +
") * the number of steps (" +
std::to_string(num_steps) + ") = " +
std::to_string(num_processes * num_steps) +
" to be greater than or equal to the number of minibatches (" +
std::to_string(num_mini_batch) +
")");
}
auto mini_batch_size = batch_size / num_mini_batch;
return std::make_unique<FeedForwardGenerator>(
Expand Down Expand Up @@ -143,11 +143,11 @@ std::unique_ptr<Generator> RolloutStorage::recurrent_generator(
auto num_processes = actions.size(1);
if (num_processes < num_mini_batch)
{
spdlog::error("PPO needs the number of processes ({}) to be greater than or"
" equal to the number of minibatches ({})",
num_processes,
num_mini_batch);
throw std::exception();
throw std::runtime_error("PPO needs the number of processes (" +
std::to_string(num_processes) +
") to be greater than or equal to the number of minibatches (" +
std::to_string(num_mini_batch) +
")");
}
return std::make_unique<RecurrentGenerator>(
num_processes,
Expand Down
12 changes: 2 additions & 10 deletions src/third_party/doctest.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
#define DOCTEST_CONFIG_IMPLEMENT
#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN

#include <spdlog/spdlog.h>

#include "third_party/doctest.h"

int main(int argc, char **argv)
{
spdlog::set_level(spdlog::level::off);
return doctest::Context(argc, argv).run();
}
#include "third_party/doctest.h"