Skip to content

Commit

Permalink
Add API for MJPC cost residuals.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 604774074
Change-Id: I2b4463daf24db079591b58fb56dc3c1eb42ca5f9
  • Loading branch information
copybara-github committed Feb 6, 2024
1 parent d81174c commit e28230c
Show file tree
Hide file tree
Showing 10 changed files with 104 additions and 0 deletions.
12 changes: 12 additions & 0 deletions mjpc/grpc/agent.proto
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ service Agent {
returns (GetTaskParametersResponse);
// Set cost weights.
rpc SetCostWeights(SetCostWeightsRequest) returns (SetCostWeightsResponse);
// Get cost term residuals.
rpc GetResiduals(GetResidualsRequest) returns (GetResidualsResponse);
// Get cost term values.
rpc GetCostValuesAndWeights(GetCostValuesAndWeightsRequest)
returns (GetCostValuesAndWeightsResponse);
Expand Down Expand Up @@ -113,6 +115,16 @@ message GetActionResponse {
repeated float action = 1 [packed = true];
}

message GetResidualsRequest {}

message Residual {
repeated double values = 1;
}

message GetResidualsResponse {
map<string, Residual> values = 1;
}

message GetCostValuesAndWeightsRequest {}

message ValueAndWeight {
Expand Down
12 changes: 12 additions & 0 deletions mjpc/grpc/agent_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ using ::agent::GetAllModesRequest;
using ::agent::GetAllModesResponse;
using ::agent::GetBestTrajectoryRequest;
using ::agent::GetBestTrajectoryResponse;
using ::agent::GetResidualsRequest;
using ::agent::GetResidualsResponse;
using ::agent::GetCostValuesAndWeightsRequest;
using ::agent::GetCostValuesAndWeightsResponse;
using ::agent::GetModeRequest;
Expand Down Expand Up @@ -181,6 +183,16 @@ grpc::Status AgentService::GetAction(grpc::ServerContext* context,
request, &agent_, model, rollout_data_.get(), &rollout_state_, response);
}

grpc::Status AgentService::GetResiduals(
grpc::ServerContext* context, const GetResidualsRequest* request,
GetResidualsResponse* response) {
if (!Initialized()) {
return {grpc::StatusCode::FAILED_PRECONDITION, "Init not called."};
}
return grpc_agent_util::GetResiduals(request, &agent_, model,
data_, response);
}

grpc::Status AgentService::GetCostValuesAndWeights(
grpc::ServerContext* context, const GetCostValuesAndWeightsRequest* request,
GetCostValuesAndWeightsResponse* response) {
Expand Down
5 changes: 5 additions & 0 deletions mjpc/grpc/agent_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ class AgentService final : public agent::Agent::Service {
const agent::GetActionRequest* request,
agent::GetActionResponse* response) override;

grpc::Status GetResiduals(
grpc::ServerContext* context,
const agent::GetResidualsRequest* request,
agent::GetResidualsResponse* response) override;

grpc::Status GetCostValuesAndWeights(
grpc::ServerContext* context,
const agent::GetCostValuesAndWeightsRequest* request,
Expand Down
12 changes: 12 additions & 0 deletions mjpc/grpc/agent_service_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,4 +360,16 @@ TEST_F(AgentServiceTest, GetAllModes_Works) {
EXPECT_EQ(response.mode_names()[0], "default_mode");
}

TEST_F(AgentServiceTest, GetResiduals_Works) {
RunAndCheckInit("Cartpole", nullptr);

grpc::ClientContext context;

agent::GetResidualsRequest request;
agent::GetResidualsResponse response;
grpc::Status status = stub->GetResiduals(&context, request, &response);

EXPECT_TRUE(status.ok());
}

} // namespace mjpc::agent_grpc
31 changes: 31 additions & 0 deletions mjpc/grpc/grpc_agent_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ using ::agent::GetActionRequest;
using ::agent::GetActionResponse;
using ::agent::GetAllModesRequest;
using ::agent::GetAllModesResponse;
using ::agent::GetResidualsRequest;
using ::agent::GetResidualsResponse;
using ::agent::GetCostValuesAndWeightsRequest;
using ::agent::GetCostValuesAndWeightsResponse;
using ::agent::GetModeRequest;
Expand All @@ -58,6 +60,7 @@ using ::agent::SetCostWeightsRequest;
using ::agent::SetModeRequest;
using ::agent::SetStateRequest;
using ::agent::SetTaskParametersRequest;
using ::agent::Residual;
using ::agent::ValueAndWeight;

grpc::Status GetState(const mjModel* model, const mjData* data,
Expand Down Expand Up @@ -226,6 +229,34 @@ grpc::Status GetAction(const GetActionRequest* request,
return grpc::Status::OK;
}

grpc::Status GetResiduals(
const GetResidualsRequest* request, const mjpc::Agent* agent,
const mjModel* model, mjData* data,
GetResidualsResponse* response) {
const mjModel* agent_model = agent->GetModel();
const mjpc::Task* task = agent->ActiveTask();
std::vector<double> residuals(task->num_residual, 0); // scratch space
task->Residual(model, data, residuals.data());
std::vector<int> dim_norm_residual = task->dim_norm_residual;

int residual_shift = 0;
for (int i = 0; i < task->num_term; i++) {
CHECK_EQ(agent_model->sensor_type[i], mjSENS_USER);
std::string_view sensor_name(agent_model->names +
agent_model->name_sensoradr[i]);

std::vector<double> sensor_residual_values(
residuals.begin() + residual_shift,
residuals.begin() + residual_shift + dim_norm_residual[i]);
Residual sensor_residual;
sensor_residual.mutable_values()->Assign(sensor_residual_values.begin(),
sensor_residual_values.end());
(*response->mutable_values())[sensor_name] = sensor_residual;
residual_shift += dim_norm_residual[i];
}
return grpc::Status::OK;
}

grpc::Status GetCostValuesAndWeights(
const GetCostValuesAndWeightsRequest* request, const mjpc::Agent* agent,
const mjModel* model, mjData* data,
Expand Down
4 changes: 4 additions & 0 deletions mjpc/grpc/grpc_agent_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ grpc::Status GetAction(const agent::GetActionRequest* request,
const mjModel* model, mjData* rollout_data,
mjpc::State* rollout_state,
agent::GetActionResponse* response);
grpc::Status GetResiduals(
const agent::GetResidualsRequest* request,
const mjpc::Agent* agent, const mjModel* model, mjData* data,
agent::GetResidualsResponse* response);
grpc::Status GetCostValuesAndWeights(
const agent::GetCostValuesAndWeightsRequest* request,
const mjpc::Agent* agent, const mjModel* model, mjData* data,
Expand Down
14 changes: 14 additions & 0 deletions mjpc/grpc/ui_agent_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ using ::agent::GetActionRequest;
using ::agent::GetActionResponse;
using ::agent::GetModeRequest;
using ::agent::GetModeResponse;
using ::agent::GetResidualsRequest;
using ::agent::GetResidualsResponse;
using ::agent::GetCostValuesAndWeightsRequest;
using ::agent::GetCostValuesAndWeightsResponse;
using ::agent::GetStateRequest;
Expand Down Expand Up @@ -125,6 +127,18 @@ grpc::Status UiAgentService::GetAction(grpc::ServerContext* context,
});
}

grpc::Status UiAgentService::GetResiduals(
grpc::ServerContext* context, const GetResidualsRequest* request,
GetResidualsResponse* response) {
return RunBeforeStep(
context, [request, response](mjpc::Agent* agent, const mjModel* model,
mjData* data) {
return grpc_agent_util::GetResiduals(request, agent, model,
data, response);
});
}


grpc::Status UiAgentService::GetCostValuesAndWeights(
grpc::ServerContext* context, const GetCostValuesAndWeightsRequest* request,
GetCostValuesAndWeightsResponse* response) {
Expand Down
5 changes: 5 additions & 0 deletions mjpc/grpc/ui_agent_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ class UiAgentService final : public agent::Agent::Service {
const agent::GetActionRequest* request,
agent::GetActionResponse* response) override;

grpc::Status GetResiduals(
grpc::ServerContext* context,
const agent::GetResidualsRequest* request,
agent::GetResidualsResponse* response) override;

grpc::Status GetCostValuesAndWeights(
grpc::ServerContext* context,
const agent::GetCostValuesAndWeightsRequest* request,
Expand Down
5 changes: 5 additions & 0 deletions python/mujoco_mpc/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,11 @@ def get_cost_term_values(self) -> dict[str, float]:
for name, value_weight in terms.values_weights.items()
}

def get_residuals(self) -> dict[str, Sequence[float]]:
residuals = self.stub.GetResiduals(agent_pb2.GetResidualsRequest())
return {name: residual.values
for name, residual in residuals.values.items()}

def planner_step(self):
"""Send a planner request."""
planner_step_request = agent_pb2.PlannerStepRequest()
Expand Down
4 changes: 4 additions & 0 deletions python/mujoco_mpc/agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,10 @@ def test_get_cost_weights(self):
terms = list(terms_dict.values())
self.assertFalse(np.any(np.isclose(terms, 0, rtol=0, atol=1e-4)))

residuals_dict = agent.get_residuals()
residuals = list(residuals_dict.values())
self.assertFalse(np.any(np.isclose(residuals, 0, rtol=0, atol=1e-4)))

def test_set_state_with_lists(self):
model_path = (
pathlib.Path(__file__).parent.parent.parent
Expand Down

0 comments on commit e28230c

Please sign in to comment.