Skip to content

Commit

Permalink
PID code and Update Readme (facebookresearch#165)
Browse files Browse the repository at this point in the history
* clean PID implementation

* minor text changes

* make batch friendly, add tests

* lint tests

* make tests deterministic

* fix docstring

* add colab to readme

* clean PR for mbrl-lib
  • Loading branch information
Nathan Lambert authored Sep 12, 2022
1 parent 87008d0 commit 7dbca52
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 0 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ installation and are specific to models of type
We are planning to extend this in the future; if you have useful suggestions
don't hesitate to raise an issue or submit a pull request!

## Advanced Examples
MBRL-Lib can be used for a wide range of research projects in model-based reinforcement learning.
The following projects are currently supported:
* [Trajectory-based Dynamics Model](https://arxiv.org/abs/2012.09156) Training [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/natolambert/mbrl-lib-dev/blob/main/notebooks/traj_based_model.ipynb)

## Documentation
Please check out our **[documentation](https://facebookresearch.github.io/mbrl-lib/)**
and don't hesitate to raise issues or contribute if anything is unclear!
Expand Down
1 change: 1 addition & 0 deletions mbrl/planning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .core import Agent, RandomAgent, complete_agent_cfg, load_agent
from .linear_feedback import PIDAgent
from .trajectory_opt import (
CEMOptimizer,
ICEMOptimizer,
Expand Down
122 changes: 122 additions & 0 deletions mbrl/planning/linear_feedback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import numpy as np

from .core import Agent


class PIDAgent(Agent):
    """Agent that acts via a set of independent PID controllers.

    One proportional-integral-derivative (PID) controller is kept per
    controlled state. A broad history of the PID controller can be found here:
    https://en.wikipedia.org/wiki/PID_controller.

    The update is discrete with an implicit unit time step: the integral term
    is the running sum of all errors seen since the last reset, and the
    derivative term is the difference between the current and previous error.

    Args:
        k_p (np.ndarray): proportional control coefficients, one per
            controlled state. Accepts shape (N,) or (N, 1).
        k_i (np.ndarray): integral control coefficients, same layout as ``k_p``.
        k_d (np.ndarray): derivative control coefficients, same layout as ``k_p``.
        target (np.ndarray): setpoint for each controlled state, same layout
            as ``k_p``.
        state_mapping (np.ndarray, optional): indices of the state vector to
            apply the PID control to. E.g. for a system with states
            [angle, angle_vel, position, position_vel], a state_mapping of
            [1, 3] applies the PID to the angle_vel and position_vel
            variables. Defaults to the first N states. It is only used when an
            observation has more rows than there are controlled states.
        batch_dim (int, optional): number of samples to compute actions for
            simultaneously. ``None`` is treated as 1.
    """

    def __init__(
        self,
        k_p: np.ndarray,
        k_i: np.ndarray,
        k_d: np.ndarray,
        target: np.ndarray,
        state_mapping: Optional[np.ndarray] = None,
        batch_dim: Optional[int] = 1,
    ):
        super().__init__()

        # Coerce to flat 1-D arrays so that lists and the documented (N, 1)
        # column layout behave exactly like the plain (N,) layout.
        k_p, k_i, k_d, target = (
            np.asarray(values).reshape(-1) for values in (k_p, k_i, k_d, target)
        )
        assert (
            len(k_p) == len(k_i) == len(k_d) == len(target)
        ), "k_p, k_i, k_d and target must all have the same length"
        self.n_dof = len(k_p)

        # State mapping defaults to first N states
        if state_mapping is not None:
            state_mapping = np.asarray(state_mapping)
            assert len(state_mapping) == len(
                target
            ), "state_mapping must have one index per controlled state"
            self.state_mapping = state_mapping
        else:
            self.state_mapping = np.arange(0, self.n_dof)

        # The annotation allows None; fall back to a single sample rather than
        # failing inside np.zeros.
        self.batch_dim = 1 if batch_dim is None else batch_dim

        self._prev_error = np.zeros((self.n_dof, self.batch_dim))
        self._cum_error = np.zeros((self.n_dof, self.batch_dim))

        # Gains and setpoint are tiled to (n_dof, batch_dim) so that every
        # operation in ``act`` is a plain element-wise one.
        self.k_p = np.repeat(k_p[:, np.newaxis], self.batch_dim, axis=1)
        self.k_i = np.repeat(k_i[:, np.newaxis], self.batch_dim, axis=1)
        self.k_d = np.repeat(k_d[:, np.newaxis], self.batch_dim, axis=1)
        self.target = np.repeat(target[:, np.newaxis], self.batch_dim, axis=1)

    def act(self, obs: np.ndarray, **_kwargs) -> np.ndarray:
        """Issues an action given an observation.

        Computes one PID step for an observation (or batch of observations)
        and updates the stored integral and previous-error terms.

        Args:
            obs (np.ndarray): the observation for which the action is needed,
                either of shape (N,), N x 1 or N x B, where N is the state dim
                and B is the batch size.

        Returns:
            (np.ndarray): the action outputted from the PID, of shape
            n_dof x batch_dim.

        Raises:
            ValueError: if the observation batch does not match ``batch_dim``.
        """
        obs = np.atleast_1d(obs)
        if obs.ndim == 1:
            obs = np.expand_dims(obs, -1)

        # Observations that already have at most n_dof rows are assumed to
        # contain only the controlled states and are used as they are.
        if len(obs) > self.n_dof:
            pos = obs[self.state_mapping]
        else:
            pos = obs

        error = self.target - pos
        if error.shape != self._cum_error.shape:
            # Fail with a clear message before touching the integral term;
            # the in-place accumulation below cannot broadcast its output.
            raise ValueError(
                f"Observation with shape {obs.shape} yields an error term of "
                f"shape {error.shape}, but the agent was built for shape "
                f"{self._cum_error.shape} (n_dof x batch_dim)."
            )
        self._cum_error += error

        P_value = np.multiply(self.k_p, error)
        I_value = np.multiply(self.k_i, self._cum_error)
        D_value = np.multiply(self.k_d, (error - self._prev_error))
        self._prev_error = error
        action = P_value + I_value + D_value
        return action

    def reset(self):
        """
        Reset internal errors.

        Both the previous error and the cumulative (integral) error are set
        back to zero arrays of shape n_dof x batch_dim.
        """
        self._prev_error = np.zeros((self.n_dof, self.batch_dim))
        self._cum_error = np.zeros((self.n_dof, self.batch_dim))

    def get_errors(self):
        """Returns the (previous error, cumulative error) pair of internal arrays."""
        return self._prev_error, self._cum_error

    def _get_P(self):
        return self.k_p

    def _get_I(self):
        return self.k_i

    def _get_D(self):
        return self.k_d

    def _get_targets(self):
        return self.target

    def get_parameters(self):
        """
        Returns the parameters of the PID agent concatenated.

        The tiled k_p, k_i, k_d and target arrays are stacked in that order
        and flattened, so each coefficient appears ``batch_dim`` times.

        Returns:
            (np.ndarray): the parameters, of length 4 * n_dof * batch_dim.
        """
        return np.stack(
            (self._get_P(), self._get_I(), self._get_D(), self._get_targets())
        ).flatten()
81 changes: 81 additions & 0 deletions tests/core/test_planning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import pytest

import mbrl.planning as planning


def create_pid_agent(dim, state_mapping=None, batch_dim=1):
    """Build a PIDAgent with unit gains and an all-zero setpoint.

    Args:
        dim: number of controlled states.
        state_mapping: forwarded unchanged to ``PIDAgent``.
        batch_dim: forwarded unchanged to ``PIDAgent``.

    Returns:
        The constructed ``planning.PIDAgent``.
    """
    # Each gain gets its own freshly allocated array of ones.
    unit_gains = {name: np.ones(dim) for name in ("k_p", "k_i", "k_d")}
    return planning.PIDAgent(
        target=np.zeros(dim),
        state_mapping=state_mapping,
        batch_dim=batch_dim,
        **unit_gains,
    )


def test_pid_agent_one_dim():
    """
    This test covers the creation of PID agents in the most basic form.
    """
    pid = create_pid_agent(dim=1)
    pid.reset()
    init_obs = np.array([2.2408932])
    act = pid.act(init_obs)

    # check action computation: with unit gains, a zero target and freshly
    # reset errors, the P, I and D terms each equal the error (-obs).
    assert act.shape == (1, 1)
    # An absolute tolerance is used because a 10% relative one would let
    # substantial regressions through.
    assert act.item() == pytest.approx(-6.722, abs=5e-3)

    # check reset: every entry must be zero, not merely the sum, since
    # non-zero errors of opposite sign would cancel in a sum.
    pid.reset()
    prev_error, cum_error = pid.get_errors()
    assert not np.any(prev_error)
    assert not np.any(cum_error)


def test_pid_agent_multi_dim():
    """
    This test covers regular updates for the multi-dim PID agent.
    """
    pid = create_pid_agent(dim=2, state_mapping=np.array([1, 3]))
    init_obs = np.array([0.95008842, -0.15135721, -0.10321885, 0.4105985])
    act1 = pid.act(init_obs)
    next_obs = np.array([0.14404357, 1.45427351, 0.76103773, 0.12167502])
    act2 = pid.act(next_obs)

    # Actions come back as (n_dof, 1) columns; flatten them before comparing
    # against the flat list of expected values.
    total = act1 + act2
    assert total.shape == (2, 1)
    # An absolute tolerance is used because a 10% relative one would let
    # substantial regressions through.
    assert np.ravel(total) == pytest.approx([-3.908, -1.596], abs=5e-3)

    # check reset: every entry must be zero, not merely the sum, since
    # non-zero errors of opposite sign would cancel in a sum.
    pid.reset()
    prev_error, cum_error = pid.get_errors()
    assert not np.any(prev_error)
    assert not np.any(cum_error)


def test_pid_agent_batch(batch_dim=5):
    """
    Tests the agent for batch-mode computation of actions.
    """
    pid = create_pid_agent(dim=2, state_mapping=np.array([1, 3]), batch_dim=batch_dim)

    # Observations are laid out as (state_dim, batch): 4 states x 5 samples.
    init_obs = np.array(
        [
            [0.95008842, -0.15135721, -0.10321885, 0.4105985, 0.14404357],
            [1.45427351, 0.76103773, 0.12167502, 0.44386323, 0.33367433],
            [1.49407907, -0.20515826, 0.3130677, -0.85409574, -2.55298982],
            [0.6536186, 0.8644362, -0.74216502, 2.26975462, -1.45436567],
        ]
    )
    act1 = pid.act(init_obs)
    next_obs = np.array(
        [
            [0.04575852, -0.18718385, 1.53277921, 1.46935877, 0.15494743],
            [0.37816252, -0.88778575, -1.98079647, -0.34791215, 0.15634897],
            [1.23029068, 1.20237985, -0.38732682, -0.30230275, -1.04855297],
            [-1.42001794, -1.70627019, 1.9507754, -0.50965218, -0.4380743],
        ]
    )
    act2 = pid.act(next_obs)

    total = act1 + act2
    assert total.shape == (2, init_obs.shape[1])
    # Row 0 of the action corresponds to the first mapped state (index 1).
    # An absolute tolerance is used because a 10% relative one would let
    # substantial regressions through.
    assert total[0] == pytest.approx(
        [-5.497, 0.380, 5.577, -0.287, -1.470], abs=5e-3
    )

    # check reset: every entry must be zero, not merely the sum, since
    # non-zero errors of opposite sign would cancel in a sum.
    pid.reset()
    prev_error, cum_error = pid.get_errors()
    assert not np.any(prev_error)
    assert not np.any(cum_error)

0 comments on commit 7dbca52

Please sign in to comment.