Skip to content

Commit

Permalink
merge upstream - make error
Browse files Browse the repository at this point in the history
  • Loading branch information
alberthli committed Feb 9, 2024
2 parents 9c71bde + 1ecabed commit f6c5ce5
Show file tree
Hide file tree
Showing 68 changed files with 4,662 additions and 308 deletions.
11 changes: 6 additions & 5 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@ jobs:
-DCMAKE_C_COMPILER:STRING=clang-12
-DCMAKE_CXX_COMPILER:STRING=clang++-12
-DMJPC_BUILD_GRPC_SERVICE:BOOL=ON
additional_targets: "agent_server direct_server filter_server"
tmpdir: "/tmp"
- os: macos-12
cmake_args: >-
-G Ninja
-DMJPC_BUILD_GRPC_SERVICE:BOOL=ON
additional_targets: "agent_server direct_server filter_server"
tmpdir: "/tmp"

name: "MuJoCo MPC on ${{ matrix.os }} ${{ matrix.additional_label }}"
Expand All @@ -43,9 +41,10 @@ jobs:
libxrandr-dev
libxi-dev
ninja-build
zlib1g-dev
- name: Prepare macOS
if: ${{ runner.os == 'macOS' }}
run: brew install ninja
run: brew install ninja zlib
- name: Prepare Windows
if: ${{ runner.os == 'Windows' }}
# Install llvm 16 manually, remove after
Expand All @@ -68,10 +67,12 @@ jobs:
$cmake_extra_args
- name: Build MuJoCo MPC
working-directory: build
run: cmake --build . --config=Release ${{ matrix.cmake_build_args }} --target mjpc agent_test cost_derivatives_test norm_test rollout_test threadpool_test trajectory_test utilities_test direct_force_test direct_optimize_test direct_parameter_test direct_sensor_test direct_trajectory_test direct_utilities_test batch_filter_test batch_prior_test kalman_test unscented_test ${{ matrix.additional_targets }}
run: cmake --build . --config=Release ${{ matrix.cmake_build_args }}
- name: Test MuJoCo MPC
working-directory: build
run: ctest -C Release --output-on-failure .
run: >
cd mjpc/test &&
ctest -C Release --output-on-failure .
- name: Notify team chat
shell: bash
env:
Expand Down
18 changes: 7 additions & 11 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,12 @@ findorfetch(
EXCLUDE_FROM_ALL
)

# TODO(nimrod): Update to the latest version of abseil, or use the one defined
# by MuJoCo, once grpc fix their build issues.
set(MUJOCO_DEP_VERSION_abseil
c8a2f92586fe9b4e1aff049108f5db8064924d8e # LTS 20230125.1
fb3621f4f897824c0dbe0615fa94543df6192f30 # LTS 20230802.1
CACHE STRING "Version of `abseil` to be fetched."
)

set(MUJOCO_DEP_VERSION_glfw3
7482de6071d21db77a7236155da44c172a7f6c9e # 3.3.8
CACHE STRING "Version of `glfw` to be fetched."
)

set(MJPC_DEP_VERSION_lodepng
b4ed2cd7ecf61d29076169b49199371456d4f90b
CACHE STRING "Version of `lodepng` to be fetched."
FORCE
)

set(BUILD_SHARED_LIBS_OLD ${BUILD_SHARED_LIBS})
Expand Down Expand Up @@ -118,6 +111,9 @@ findorfetch(

set(ABSL_PROPAGATE_CXX_STD ON)
set(ABSL_BUILD_TESTING OFF)
# ABSL_ENABLE_INSTALL is needed for
# https://github.com/protocolbuffers/protobuf/issues/12185#issuecomment-1594685860
set(ABSL_ENABLE_INSTALL ON)
findorfetch(
USE_SYSTEM_PACKAGE
OFF
Expand Down
10 changes: 5 additions & 5 deletions cmake/MujocoLinkOptions.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function(get_mujoco_extra_link_options OUTPUT_VAR)
set(EXTRA_LINK_OPTIONS)

if(WIN32)
set(CMAKE_REQUIRED_LINK_OPTIONS "-fuse-ld=lld-link")
set(CMAKE_REQUIRED_FLAGS "-fuse-ld=lld-link")
check_c_source_compiles("int main() {}" SUPPORTS_LLD)
if(SUPPORTS_LLD)
set(EXTRA_LINK_OPTIONS
Expand All @@ -34,24 +34,24 @@ function(get_mujoco_extra_link_options OUTPUT_VAR)
)
endif()
else()
set(CMAKE_REQUIRED_LINK_OPTIONS "-fuse-ld=lld")
set(CMAKE_REQUIRED_FLAGS "-fuse-ld=lld")
check_c_source_compiles("int main() {}" SUPPORTS_LLD)
if(SUPPORTS_LLD)
set(EXTRA_LINK_OPTIONS ${EXTRA_LINK_OPTIONS} -fuse-ld=lld)
else()
set(CMAKE_REQUIRED_LINK_OPTIONS "-fuse-ld=gold")
set(CMAKE_REQUIRED_FLAGS "-fuse-ld=gold")
check_c_source_compiles("int main() {}" SUPPORTS_GOLD)
if(SUPPORTS_GOLD)
set(EXTRA_LINK_OPTIONS ${EXTRA_LINK_OPTIONS} -fuse-ld=gold)
endif()
endif()

set(CMAKE_REQUIRED_LINK_OPTIONS ${EXTRA_LINK_OPTIONS} "-Wl,--gc-sections")
set(CMAKE_REQUIRED_FLAGS ${EXTRA_LINK_OPTIONS} "-Wl,--gc-sections")
check_c_source_compiles("int main() {}" SUPPORTS_GC_SECTIONS)
if(SUPPORTS_GC_SECTIONS)
set(EXTRA_LINK_OPTIONS ${EXTRA_LINK_OPTIONS} -Wl,--gc-sections)
else()
set(CMAKE_REQUIRED_LINK_OPTIONS ${EXTRA_LINK_OPTIONS} "-Wl,-dead_strip")
set(CMAKE_REQUIRED_FLAGS ${EXTRA_LINK_OPTIONS} "-Wl,-dead_strip")
check_c_source_compiles("int main() {}" SUPPORTS_DEAD_STRIP)
if(SUPPORTS_DEAD_STRIP)
set(EXTRA_LINK_OPTIONS ${EXTRA_LINK_OPTIONS} -Wl,-dead_strip)
Expand Down
31 changes: 31 additions & 0 deletions mjpc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ add_library(
tasks/acrobot/acrobot.h
tasks/allegro/allegro.cc
tasks/allegro/allegro.h
tasks/bimanual/bimanual.cc
tasks/bimanual/bimanual.h
tasks/cartpole/cartpole.cc
tasks/cartpole/cartpole.h
tasks/cube/solve.cc
Expand Down Expand Up @@ -85,6 +87,8 @@ add_library(
planners/cross_entropy/planner.h
planners/robust/robust_planner.cc
planners/robust/robust_planner.h
planners/sample_gradient/planner.cc
planners/sample_gradient/planner.h
planners/sampling/planner.cc
planners/sampling/planner.h
planners/sampling/policy.cc
Expand Down Expand Up @@ -122,6 +126,8 @@ add_library(
direct/trajectory.h
direct/model_parameters.cc
direct/model_parameters.h
spline/spline.cc
spline/spline.h
app.cc
app.h
norm.cc
Expand All @@ -138,8 +144,11 @@ target_compile_definitions(libmjpc PRIVATE MJSIMULATE_STATIC)
target_link_libraries(
libmjpc
absl::any_invocable
absl::check
absl::flat_hash_map
absl::log
absl::random_random
absl::span
glfw
lodepng
mujoco::mujoco
Expand Down Expand Up @@ -177,6 +186,28 @@ if(APPLE)
target_link_libraries(mjpc "-framework Cocoa")
endif()

add_executable(
testspeed
testspeed_app.cc
testspeed.h
testspeed.cc
)
target_link_libraries(
testspeed
absl::flags
absl::flags_parse
absl::random_random
absl::strings
libmjpc
mujoco::mujoco
threadpool
Threads::Threads
)
target_include_directories(testspeed PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..)
target_compile_options(testspeed PUBLIC ${MJPC_COMPILE_OPTIONS})
target_link_options(testspeed PRIVATE ${MJPC_LINK_OPTIONS})
target_compile_definitions(testspeed PRIVATE MJSIMULATE_STATIC)

add_subdirectory(tasks)

if(BUILD_TESTING AND MJPC_BUILD_TESTS)
Expand Down
13 changes: 13 additions & 0 deletions mjpc/agent.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,19 @@ void Agent::Initialize(const mjModel* model) {
if (model_) mj_deleteModel(model_);
model_ = mj_copyModel(nullptr, model); // agent's copy of model

// check for limits on all actuators
int num_missing = 0;
for (int i = 0; i < model_->nu; i++) {
if (!model_->actuator_ctrllimited[i]) {
num_missing++;
printf("%s (actuator %i) missing limits\n",
model_->names + model_->name_actuatoradr[i], i);
}
}
if (num_missing > 0) {
mju_error("Ctrl limits required for all actuators.\n");
}

// planner
planner_ = GetNumberOrDefault(0, model, "agent_planner");

Expand Down
1 change: 1 addition & 0 deletions mjpc/agent.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ class Agent {
mjpc::Planner& ActivePlanner() const { return *planners_[planner_]; }
mjpc::Estimator& ActiveEstimator() const { return *estimators_[estimator_]; }
int ActiveEstimatorIndex() const { return estimator_; }
double ComputeTime() const { return agent_compute_time_; }
Task* ActiveTask() const { return tasks_[active_task_id_].get(); }
// a residual function that can be used from trajectory rollouts. must only
// be used from trajectory rollout threads (no locking).
Expand Down
9 changes: 4 additions & 5 deletions mjpc/direct/direct.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MJPC_DIRECT_OPTIMIZER_H_
#define MJPC_DIRECT_OPTIMIZER_H_
#ifndef MJPC_DIRECT_DIRECT_H_
#define MJPC_DIRECT_DIRECT_H_

#include <memory>
#include <mutex>
#include <string>
#include <vector>

Expand Down Expand Up @@ -65,7 +64,7 @@ class Direct {
: model_parameters_(LoadModelParameters()), pool_(num_threads) {}

// constructor
Direct(const mjModel* model, int length = 3, int max_history = 0);
explicit Direct(const mjModel* model, int length = 3, int max_history = 0);

// destructor
virtual ~Direct() {
Expand Down Expand Up @@ -510,4 +509,4 @@ std::string StatusString(int code);

} // namespace mjpc

#endif // MJPC_DIRECT_OPTIMIZER_H_
#endif // MJPC_DIRECT_DIRECT_H_
1 change: 0 additions & 1 deletion mjpc/estimators/batch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "mjpc/array_safety.h"
#include "mjpc/estimators/estimator.h"
#include "mjpc/direct/direct.h"
#include "mjpc/norm.h"
#include "mjpc/threadpool.h"
#include "mjpc/utilities.h"

Expand Down
9 changes: 5 additions & 4 deletions mjpc/grpc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ set(BUILD_SHARED_LIBS
CACHE INTERNAL "Build SHARED libraries"
)

find_package(ZLIB REQUIRED)
set(gRPC_ZLIB_PROVIDER "package" CACHE INTERNAL "")
set(ZLIB_BUILD_EXAMPLES OFF)

findorfetch(
USE_SYSTEM_PACKAGE
OFF
Expand All @@ -30,14 +34,11 @@ findorfetch(
GIT_REPO
https://github.com/grpc/grpc
GIT_TAG
v1.53.0
v1.60.1
TARGETS
gRPC
)

find_package(ZLIB REQUIRED)
set(gRPC_ZLIB_PROVIDER "package" CACHE INTERNAL "")
set(ZLIB_BUILD_EXAMPLES OFF)
set(_PROTOBUF_LIBPROTOBUF libprotobuf)
set(_REFLECTION grpc++_reflection)
set(_PROTOBUF_PROTOC $<TARGET_FILE:protoc>)
Expand Down
12 changes: 12 additions & 0 deletions mjpc/grpc/agent.proto
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ service Agent {
returns (GetTaskParametersResponse);
// Set cost weights.
rpc SetCostWeights(SetCostWeightsRequest) returns (SetCostWeightsResponse);
// Get cost term residuals.
rpc GetResiduals(GetResidualsRequest) returns (GetResidualsResponse);
// Get cost term values.
rpc GetCostValuesAndWeights(GetCostValuesAndWeightsRequest)
returns (GetCostValuesAndWeightsResponse);
Expand Down Expand Up @@ -113,6 +115,16 @@ message GetActionResponse {
repeated float action = 1 [packed = true];
}

message GetResidualsRequest {}

message Residual {
repeated double values = 1;
}

message GetResidualsResponse {
map<string, Residual> values = 1;
}

message GetCostValuesAndWeightsRequest {}

message ValueAndWeight {
Expand Down
17 changes: 17 additions & 0 deletions mjpc/grpc/agent_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ using ::agent::GetAllModesRequest;
using ::agent::GetAllModesResponse;
using ::agent::GetBestTrajectoryRequest;
using ::agent::GetBestTrajectoryResponse;
using ::agent::GetResidualsRequest;
using ::agent::GetResidualsResponse;
using ::agent::GetCostValuesAndWeightsRequest;
using ::agent::GetCostValuesAndWeightsResponse;
using ::agent::GetModeRequest;
Expand Down Expand Up @@ -118,6 +120,11 @@ grpc::Status AgentService::Init(grpc::ServerContext* context,
model = mj_copyModel(nullptr, agent_model);
data_ = mj_makeData(model);
rollout_data_.reset(mj_makeData(model));
int home_id = mj_name2id(model, mjOBJ_KEY, "home");
if (home_id >= 0) {
mj_resetDataKeyframe(model, data_, home_id);
mj_resetDataKeyframe(model, rollout_data_.get(), home_id);
}
mjcb_sensor = residual_sensor_callback;

agent_.SetState(data_);
Expand Down Expand Up @@ -176,6 +183,16 @@ grpc::Status AgentService::GetAction(grpc::ServerContext* context,
request, &agent_, model, rollout_data_.get(), &rollout_state_, response);
}

grpc::Status AgentService::GetResiduals(
grpc::ServerContext* context, const GetResidualsRequest* request,
GetResidualsResponse* response) {
if (!Initialized()) {
return {grpc::StatusCode::FAILED_PRECONDITION, "Init not called."};
}
return grpc_agent_util::GetResiduals(request, &agent_, model,
data_, response);
}

grpc::Status AgentService::GetCostValuesAndWeights(
grpc::ServerContext* context, const GetCostValuesAndWeightsRequest* request,
GetCostValuesAndWeightsResponse* response) {
Expand Down
5 changes: 5 additions & 0 deletions mjpc/grpc/agent_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ class AgentService final : public agent::Agent::Service {
const agent::GetActionRequest* request,
agent::GetActionResponse* response) override;

grpc::Status GetResiduals(
grpc::ServerContext* context,
const agent::GetResidualsRequest* request,
agent::GetResidualsResponse* response) override;

grpc::Status GetCostValuesAndWeights(
grpc::ServerContext* context,
const agent::GetCostValuesAndWeightsRequest* request,
Expand Down
12 changes: 12 additions & 0 deletions mjpc/grpc/agent_service_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,4 +360,16 @@ TEST_F(AgentServiceTest, GetAllModes_Works) {
EXPECT_EQ(response.mode_names()[0], "default_mode");
}

TEST_F(AgentServiceTest, GetResiduals_Works) {
RunAndCheckInit("Cartpole", nullptr);

grpc::ClientContext context;

agent::GetResidualsRequest request;
agent::GetResidualsResponse response;
grpc::Status status = stub->GetResiduals(&context, request, &response);

EXPECT_TRUE(status.ok());
}

} // namespace mjpc::agent_grpc
Loading

0 comments on commit f6c5ce5

Please sign in to comment.