Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into interpolate_derivat…
Browse files Browse the repository at this point in the history
…ives
  • Loading branch information
thowell committed Feb 9, 2024
2 parents d7bf37e + 1ecabed commit c392c2a
Show file tree
Hide file tree
Showing 29 changed files with 1,424 additions and 38 deletions.
11 changes: 6 additions & 5 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@ jobs:
-DCMAKE_C_COMPILER:STRING=clang-12
-DCMAKE_CXX_COMPILER:STRING=clang++-12
-DMJPC_BUILD_GRPC_SERVICE:BOOL=ON
additional_targets: "agent_server direct_server filter_server"
tmpdir: "/tmp"
- os: macos-12
cmake_args: >-
-G Ninja
-DMJPC_BUILD_GRPC_SERVICE:BOOL=ON
additional_targets: "agent_server direct_server filter_server"
tmpdir: "/tmp"

name: "MuJoCo MPC on ${{ matrix.os }} ${{ matrix.additional_label }}"
Expand All @@ -43,9 +41,10 @@ jobs:
libxrandr-dev
libxi-dev
ninja-build
zlib1g-dev
- name: Prepare macOS
if: ${{ runner.os == 'macOS' }}
run: brew install ninja
run: brew install ninja zlib
- name: Prepare Windows
if: ${{ runner.os == 'Windows' }}
# Install llvm 16 manually, remove after
Expand All @@ -68,10 +67,12 @@ jobs:
$cmake_extra_args
- name: Build MuJoCo MPC
working-directory: build
run: cmake --build . --config=Release ${{ matrix.cmake_build_args }} --target mjpc agent_test agent_utilities_test cost_derivatives_test norm_test rollout_test threadpool_test trajectory_test direct_force_test direct_optimize_test direct_parameter_test direct_sensor_test direct_trajectory_test direct_utilities_test batch_filter_test batch_prior_test kalman_test unscented_test cubic_test gradient_planner_test gradient_test linear_test zero_test backward_pass_test ilqg_test robust_planner_test sampling_planner_test state_test task_test utilities_test ${{ matrix.additional_targets }}
run: cmake --build . --config=Release ${{ matrix.cmake_build_args }}
- name: Test MuJoCo MPC
working-directory: build
run: ctest -C Release --output-on-failure .
run: >
cd mjpc/test &&
ctest -C Release --output-on-failure .
- name: Notify team chat
shell: bash
env:
Expand Down
18 changes: 7 additions & 11 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,12 @@ findorfetch(
EXCLUDE_FROM_ALL
)

# TODO(nimrod): Update to the latest version of abseil, or use the one defined
# by MuJoCo, once grpc fix their build issues.
set(MUJOCO_DEP_VERSION_abseil
c8a2f92586fe9b4e1aff049108f5db8064924d8e # LTS 20230125.1
fb3621f4f897824c0dbe0615fa94543df6192f30 # LTS 20230802.1
CACHE STRING "Version of `abseil` to be fetched."
)

set(MUJOCO_DEP_VERSION_glfw3
7482de6071d21db77a7236155da44c172a7f6c9e # 3.3.8
CACHE STRING "Version of `glfw` to be fetched."
)

set(MJPC_DEP_VERSION_lodepng
b4ed2cd7ecf61d29076169b49199371456d4f90b
CACHE STRING "Version of `lodepng` to be fetched."
FORCE
)

set(BUILD_SHARED_LIBS_OLD ${BUILD_SHARED_LIBS})
Expand Down Expand Up @@ -118,6 +111,9 @@ findorfetch(

set(ABSL_PROPAGATE_CXX_STD ON)
set(ABSL_BUILD_TESTING OFF)
# ABSL_ENABLE_INSTALL is needed for
# https://github.com/protocolbuffers/protobuf/issues/12185#issuecomment-1594685860
set(ABSL_ENABLE_INSTALL ON)
findorfetch(
USE_SYSTEM_PACKAGE
OFF
Expand Down
7 changes: 7 additions & 0 deletions mjpc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ add_library(
tasks/tasks.h
tasks/acrobot/acrobot.cc
tasks/acrobot/acrobot.h
tasks/bimanual/bimanual.cc
tasks/bimanual/bimanual.h
tasks/cartpole/cartpole.cc
tasks/cartpole/cartpole.h
tasks/cube/solve.cc
Expand Down Expand Up @@ -122,6 +124,8 @@ add_library(
direct/trajectory.h
direct/model_parameters.cc
direct/model_parameters.h
spline/spline.cc
spline/spline.h
app.cc
app.h
norm.cc
Expand All @@ -138,8 +142,11 @@ target_compile_definitions(libmjpc PRIVATE MJSIMULATE_STATIC)
target_link_libraries(
libmjpc
absl::any_invocable
absl::check
absl::flat_hash_map
absl::log
absl::random_random
absl::span
glfw
lodepng
mujoco::mujoco
Expand Down
9 changes: 5 additions & 4 deletions mjpc/grpc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ set(BUILD_SHARED_LIBS
CACHE INTERNAL "Build SHARED libraries"
)

find_package(ZLIB REQUIRED)
set(gRPC_ZLIB_PROVIDER "package" CACHE INTERNAL "")
set(ZLIB_BUILD_EXAMPLES OFF)

findorfetch(
USE_SYSTEM_PACKAGE
OFF
Expand All @@ -30,14 +34,11 @@ findorfetch(
GIT_REPO
https://github.com/grpc/grpc
GIT_TAG
v1.53.0
v1.60.1
TARGETS
gRPC
)

find_package(ZLIB REQUIRED)
set(gRPC_ZLIB_PROVIDER "package" CACHE INTERNAL "")
set(ZLIB_BUILD_EXAMPLES OFF)
set(_PROTOBUF_LIBPROTOBUF libprotobuf)
set(_REFLECTION grpc++_reflection)
set(_PROTOBUF_PROTOC $<TARGET_FILE:protoc>)
Expand Down
12 changes: 12 additions & 0 deletions mjpc/grpc/agent.proto
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ service Agent {
returns (GetTaskParametersResponse);
// Set cost weights.
rpc SetCostWeights(SetCostWeightsRequest) returns (SetCostWeightsResponse);
// Get cost term residuals.
rpc GetResiduals(GetResidualsRequest) returns (GetResidualsResponse);
// Get cost term values.
rpc GetCostValuesAndWeights(GetCostValuesAndWeightsRequest)
returns (GetCostValuesAndWeightsResponse);
Expand Down Expand Up @@ -113,6 +115,16 @@ message GetActionResponse {
repeated float action = 1 [packed = true];
}

message GetResidualsRequest {}

message Residual {
repeated double values = 1;
}

message GetResidualsResponse {
map<string, Residual> values = 1;
}

message GetCostValuesAndWeightsRequest {}

message ValueAndWeight {
Expand Down
12 changes: 12 additions & 0 deletions mjpc/grpc/agent_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ using ::agent::GetAllModesRequest;
using ::agent::GetAllModesResponse;
using ::agent::GetBestTrajectoryRequest;
using ::agent::GetBestTrajectoryResponse;
using ::agent::GetResidualsRequest;
using ::agent::GetResidualsResponse;
using ::agent::GetCostValuesAndWeightsRequest;
using ::agent::GetCostValuesAndWeightsResponse;
using ::agent::GetModeRequest;
Expand Down Expand Up @@ -181,6 +183,16 @@ grpc::Status AgentService::GetAction(grpc::ServerContext* context,
request, &agent_, model, rollout_data_.get(), &rollout_state_, response);
}

grpc::Status AgentService::GetResiduals(
grpc::ServerContext* context, const GetResidualsRequest* request,
GetResidualsResponse* response) {
if (!Initialized()) {
return {grpc::StatusCode::FAILED_PRECONDITION, "Init not called."};
}
return grpc_agent_util::GetResiduals(request, &agent_, model,
data_, response);
}

grpc::Status AgentService::GetCostValuesAndWeights(
grpc::ServerContext* context, const GetCostValuesAndWeightsRequest* request,
GetCostValuesAndWeightsResponse* response) {
Expand Down
5 changes: 5 additions & 0 deletions mjpc/grpc/agent_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ class AgentService final : public agent::Agent::Service {
const agent::GetActionRequest* request,
agent::GetActionResponse* response) override;

grpc::Status GetResiduals(
grpc::ServerContext* context,
const agent::GetResidualsRequest* request,
agent::GetResidualsResponse* response) override;

grpc::Status GetCostValuesAndWeights(
grpc::ServerContext* context,
const agent::GetCostValuesAndWeightsRequest* request,
Expand Down
12 changes: 12 additions & 0 deletions mjpc/grpc/agent_service_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,4 +360,16 @@ TEST_F(AgentServiceTest, GetAllModes_Works) {
EXPECT_EQ(response.mode_names()[0], "default_mode");
}

TEST_F(AgentServiceTest, GetResiduals_Works) {
RunAndCheckInit("Cartpole", nullptr);

grpc::ClientContext context;

agent::GetResidualsRequest request;
agent::GetResidualsResponse response;
grpc::Status status = stub->GetResiduals(&context, request, &response);

EXPECT_TRUE(status.ok());
}

} // namespace mjpc::agent_grpc
31 changes: 31 additions & 0 deletions mjpc/grpc/grpc_agent_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ using ::agent::GetActionRequest;
using ::agent::GetActionResponse;
using ::agent::GetAllModesRequest;
using ::agent::GetAllModesResponse;
using ::agent::GetResidualsRequest;
using ::agent::GetResidualsResponse;
using ::agent::GetCostValuesAndWeightsRequest;
using ::agent::GetCostValuesAndWeightsResponse;
using ::agent::GetModeRequest;
Expand All @@ -58,6 +60,7 @@ using ::agent::SetCostWeightsRequest;
using ::agent::SetModeRequest;
using ::agent::SetStateRequest;
using ::agent::SetTaskParametersRequest;
using ::agent::Residual;
using ::agent::ValueAndWeight;

grpc::Status GetState(const mjModel* model, const mjData* data,
Expand Down Expand Up @@ -226,6 +229,34 @@ grpc::Status GetAction(const GetActionRequest* request,
return grpc::Status::OK;
}

grpc::Status GetResiduals(
const GetResidualsRequest* request, const mjpc::Agent* agent,
const mjModel* model, mjData* data,
GetResidualsResponse* response) {
const mjModel* agent_model = agent->GetModel();
const mjpc::Task* task = agent->ActiveTask();
std::vector<double> residuals(task->num_residual, 0); // scratch space
task->Residual(model, data, residuals.data());
std::vector<int> dim_norm_residual = task->dim_norm_residual;

int residual_shift = 0;
for (int i = 0; i < task->num_term; i++) {
CHECK_EQ(agent_model->sensor_type[i], mjSENS_USER);
std::string_view sensor_name(agent_model->names +
agent_model->name_sensoradr[i]);

std::vector<double> sensor_residual_values(
residuals.begin() + residual_shift,
residuals.begin() + residual_shift + dim_norm_residual[i]);
Residual sensor_residual;
sensor_residual.mutable_values()->Assign(sensor_residual_values.begin(),
sensor_residual_values.end());
(*response->mutable_values())[sensor_name] = sensor_residual;
residual_shift += dim_norm_residual[i];
}
return grpc::Status::OK;
}

grpc::Status GetCostValuesAndWeights(
const GetCostValuesAndWeightsRequest* request, const mjpc::Agent* agent,
const mjModel* model, mjData* data,
Expand Down
4 changes: 4 additions & 0 deletions mjpc/grpc/grpc_agent_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ grpc::Status GetAction(const agent::GetActionRequest* request,
const mjModel* model, mjData* rollout_data,
mjpc::State* rollout_state,
agent::GetActionResponse* response);
grpc::Status GetResiduals(
const agent::GetResidualsRequest* request,
const mjpc::Agent* agent, const mjModel* model, mjData* data,
agent::GetResidualsResponse* response);
grpc::Status GetCostValuesAndWeights(
const agent::GetCostValuesAndWeightsRequest* request,
const mjpc::Agent* agent, const mjModel* model, mjData* data,
Expand Down
14 changes: 14 additions & 0 deletions mjpc/grpc/ui_agent_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ using ::agent::GetActionRequest;
using ::agent::GetActionResponse;
using ::agent::GetModeRequest;
using ::agent::GetModeResponse;
using ::agent::GetResidualsRequest;
using ::agent::GetResidualsResponse;
using ::agent::GetCostValuesAndWeightsRequest;
using ::agent::GetCostValuesAndWeightsResponse;
using ::agent::GetStateRequest;
Expand Down Expand Up @@ -125,6 +127,18 @@ grpc::Status UiAgentService::GetAction(grpc::ServerContext* context,
});
}

grpc::Status UiAgentService::GetResiduals(
grpc::ServerContext* context, const GetResidualsRequest* request,
GetResidualsResponse* response) {
return RunBeforeStep(
context, [request, response](mjpc::Agent* agent, const mjModel* model,
mjData* data) {
return grpc_agent_util::GetResiduals(request, agent, model,
data, response);
});
}


grpc::Status UiAgentService::GetCostValuesAndWeights(
grpc::ServerContext* context, const GetCostValuesAndWeightsRequest* request,
GetCostValuesAndWeightsResponse* response) {
Expand Down
5 changes: 5 additions & 0 deletions mjpc/grpc/ui_agent_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ class UiAgentService final : public agent::Agent::Service {
const agent::GetActionRequest* request,
agent::GetActionResponse* response) override;

grpc::Status GetResiduals(
grpc::ServerContext* context,
const agent::GetResidualsRequest* request,
agent::GetResidualsResponse* response) override;

grpc::Status GetCostValuesAndWeights(
grpc::ServerContext* context,
const agent::GetCostValuesAndWeightsRequest* request,
Expand Down
10 changes: 6 additions & 4 deletions mjpc/planners/sample_gradient/planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,20 @@

#include "mjpc/planners/sample_gradient/planner.h"

#include <absl/random/random.h>
#include <mujoco/mujoco.h>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <mutex>
#include <shared_mutex>

#include <absl/random/random.h>
#include <mujoco/mujoco.h>
#include "mjpc/array_safety.h"
#include "mjpc/planners/planner.h"
#include "mjpc/planners/policy.h"
#include "mjpc/planners/sampling/planner.h"
#include "mjpc/states/state.h"
#include "mjpc/task.h"
#include "mjpc/threadpool.h"
#include "mjpc/trajectory.h"
#include "mjpc/utilities.h"

Expand Down
Loading

0 comments on commit c392c2a

Please sign in to comment.