Multiagent model using concatenated observations #1416

Merged (19 commits) on Jan 19, 2018
5 changes: 5 additions & 0 deletions python/ray/rllib/examples/__init__.py
@@ -0,0 +1,5 @@
# flake8: noqa
from ray.rllib.examples.multiagent_mountaincar_env \
import MultiAgentMountainCarEnv
from ray.rllib.examples.multiagent_pendulum_env \
import MultiAgentPendulumEnv
56 changes: 56 additions & 0 deletions python/ray/rllib/examples/multiagent_mountaincar.py
@@ -0,0 +1,56 @@
""" Multiagent mountain car. Each agent outputs an action which
is summed to form the total action. This is a discrete
multiagent example
"""

import gym
from gym.envs.registration import register

import ray
import ray.rllib.ppo as ppo
from ray.tune.registry import get_registry, register_env

env_name = "MultiAgentMountainCarEnv"

env_version_num = 0
env_name = env_name + '-v' + str(env_version_num)


def pass_params_to_gym(env_name):
global env_version_num

register(
id=env_name,
entry_point='ray.rllib.examples:' + "MultiAgentMountainCarEnv",
max_episode_steps=200,
kwargs={}
)


def create_env(env_config):
pass_params_to_gym(env_name)
env = gym.envs.make(env_name)
return env


if __name__ == '__main__':
register_env(env_name, lambda env_config: create_env(env_config))
config = ppo.DEFAULT_CONFIG.copy()
horizon = 200
num_cpus = 2
ray.init(num_cpus=num_cpus, redirect_output=False)
config["num_workers"] = num_cpus
config["timesteps_per_batch"] = 100
config["num_sgd_iter"] = 10
config["gamma"] = 0.999
config["horizon"] = horizon
config["use_gae"] = True
config["model"].update({"fcnet_hiddens": [256, 256]})
options = {"multiagent_obs_shapes": [2, 2],
"multiagent_act_shapes": [3, 3],
"multiagent_shared_model": False,
"multiagent_fcnet_hiddens": [[32, 32]] * 2}
config["model"].update({"custom_options": options})
alg = ppo.PPOAgent(env=env_name, registry=get_registry(), config=config)
for i in range(1):
alg.train()
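
For orientation (not part of this diff): with `multiagent_obs_shapes = [2, 2]` above, the model sees one flat length-4 observation and routes a 2-dim slice to each agent's sub-network. The helper below is a hypothetical, plain-NumPy sketch of that bookkeeping; the PR's actual splitting is done by the Reshaper utility in `ray.rllib.utils.reshaper`.

```python
import numpy as np

def split_concat_obs(concat_obs, obs_shapes):
    # Hypothetical helper: split a flat, concatenated observation into
    # per-agent slices, one slice per entry of multiagent_obs_shapes.
    slices, start = [], 0
    for size in obs_shapes:
        slices.append(concat_obs[start:start + size])
        start += size
    return slices

# Two mountain-car agents, each observing [position, velocity].
concat_obs = np.array([-0.5, 0.01, -0.5, 0.01])
per_agent = split_concat_obs(concat_obs, [2, 2])
assert [len(o) for o in per_agent] == [2, 2]
```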
52 changes: 52 additions & 0 deletions python/ray/rllib/examples/multiagent_mountaincar_env.py
@@ -0,0 +1,52 @@
import math
from gym.spaces import Box, Tuple, Discrete
import numpy as np
from gym.envs.classic_control.mountain_car import MountainCarEnv

"""
Multiagent mountain car that sums and then
averages its actions to produce the velocity
"""


class MultiAgentMountainCarEnv(MountainCarEnv):
def __init__(self):
self.min_position = -1.2
self.max_position = 0.6
self.max_speed = 0.07
self.goal_position = 0.5

self.low = np.array([self.min_position, -self.max_speed])
self.high = np.array([self.max_position, self.max_speed])

self.viewer = None

self.action_space = [Discrete(3) for _ in range(2)]
self.observation_space = Tuple(tuple(Box(self.low, self.high)
for _ in range(2)))

self._seed()
self.reset()

def _step(self, action):
summed_act = 0.5 * np.sum(action)

position, velocity = self.state
velocity += (summed_act - 1) * 0.001
velocity += math.cos(3 * position) * (-0.0025)
velocity = np.clip(velocity, -self.max_speed, self.max_speed)
position += velocity
position = np.clip(position, self.min_position, self.max_position)
if (position == self.min_position and velocity < 0):
velocity = 0

done = bool(position >= self.goal_position)

reward = position

self.state = (position, velocity)
return [np.array(self.state) for _ in range(2)], reward, done, {}

def _reset(self):
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
return [np.array(self.state) for _ in range(2)]
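
As a sanity check on the action combination in `_step` above (standalone sketch, not part of the diff): each agent picks an action in {0, 1, 2}, so `0.5 * np.sum(action)` lies in {0, 0.5, ..., 2} and `(summed_act - 1)` plays the role the usual `(action - 1)` term plays in the single-agent mountain car.

```python
import math
import numpy as np

MAX_SPEED, MIN_POS, MAX_POS = 0.07, -1.2, 0.6

def combined_step(position, velocity, actions):
    # Average the two discrete actions and apply the standard
    # mountain-car dynamics to that mean action.
    summed_act = 0.5 * np.sum(actions)
    velocity += (summed_act - 1) * 0.001 + math.cos(3 * position) * (-0.0025)
    velocity = float(np.clip(velocity, -MAX_SPEED, MAX_SPEED))
    position = float(np.clip(position + velocity, MIN_POS, MAX_POS))
    return position, velocity

print(combined_step(-0.5, 0.0, [2, 2]))  # both agents push right
print(combined_step(-0.5, 0.0, [2, 0]))  # agents cancel: only gravity acts
```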
56 changes: 56 additions & 0 deletions python/ray/rllib/examples/multiagent_pendulum.py
@@ -0,0 +1,56 @@
""" Run script for multiagent pendulum env. Each agent outputs a
torque which is summed to form the total torque. This is a
continuous multiagent example
"""

import gym
from gym.envs.registration import register

import ray
import ray.rllib.ppo as ppo
from ray.tune.registry import get_registry, register_env

env_name = "MultiAgentPendulumEnv"

env_version_num = 0
env_name = env_name + '-v' + str(env_version_num)


def pass_params_to_gym(env_name):
global env_version_num

register(
id=env_name,
entry_point='ray.rllib.examples:' + "MultiAgentPendulumEnv",
max_episode_steps=100,
kwargs={}
)


def create_env(env_config):
pass_params_to_gym(env_name)
env = gym.envs.make(env_name)
return env


if __name__ == '__main__':
register_env(env_name, lambda env_config: create_env(env_config))
config = ppo.DEFAULT_CONFIG.copy()
horizon = 100
num_cpus = 2
ray.init(num_cpus=num_cpus, redirect_output=False)
config["num_workers"] = num_cpus
config["timesteps_per_batch"] = 100
config["num_sgd_iter"] = 10
config["gamma"] = 0.999
config["horizon"] = horizon
config["use_gae"] = True
config["model"].update({"fcnet_hiddens": [256, 256]})
options = {"multiagent_obs_shapes": [3, 3],
"multiagent_act_shapes": [1, 1],
"multiagent_shared_model": True,
"multiagent_fcnet_hiddens": [[32, 32]] * 2}
config["model"].update({"custom_options": options})
alg = ppo.PPOAgent(env=env_name, registry=get_registry(), config=config)
for i in range(1):
alg.train()
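
For comparison with the mountain-car script, this run sets `multiagent_shared_model: True`, so both agents share one [32, 32] sub-network, and `multiagent_obs_shapes = [3, 3]` because each agent observes [cos(theta), sin(theta), thetadot]. A standalone sketch (not part of the diff) of how the two bounded torques combine in the environment:

```python
import numpy as np

MAX_TORQUE = 2.0

def combined_torque(per_agent_torques):
    # Each agent's action space is Box(-MAX_TORQUE / 2, MAX_TORQUE / 2),
    # so the sum already lies in [-MAX_TORQUE, MAX_TORQUE]; the clip
    # mirrors the safeguard in MultiAgentPendulumEnv._step.
    return float(np.clip(np.sum(per_agent_torques), -MAX_TORQUE, MAX_TORQUE))

print(combined_torque([0.7, -0.2]))  # partial cancellation -> 0.5
print(combined_torque([1.0, 1.0]))   # both saturate -> full torque 2.0
```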
70 changes: 70 additions & 0 deletions python/ray/rllib/examples/multiagent_pendulum_env.py
@@ -0,0 +1,70 @@
from gym.spaces import Box, Tuple
from gym.utils import seeding
from gym.envs.classic_control.pendulum import PendulumEnv
import numpy as np

"""
Multiagent pendulum that sums its torques to generate an action
"""


class MultiAgentPendulumEnv(PendulumEnv):
metadata = {
'render.modes': ['human', 'rgb_array'],
'video.frames_per_second': 30
}

def __init__(self):
self.max_speed = 8
self.max_torque = 2.
self.dt = .05
self.viewer = None

high = np.array([1., 1., self.max_speed])
self.action_space = [Box(low=-self.max_torque / 2,
high=self.max_torque / 2, shape=(1,))
for _ in range(2)]
self.observation_space = Tuple(tuple(Box(low=-high, high=high)
for _ in range(2)))

self._seed()

def _seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]

def _step(self, u):
th, thdot = self.state # th := theta

summed_u = np.sum(u)
g = 10.
m = 1.
length = 1.
dt = self.dt

summed_u = np.clip(summed_u, -self.max_torque, self.max_torque)
self.last_u = summed_u # for rendering
costs = self.angle_normalize(th) ** 2 + .1 * thdot ** 2 + \
.001 * (summed_u ** 2)

newthdot = thdot + (-3 * g / (2 * length) * np.sin(th + np.pi) +
3. / (m * length ** 2) * summed_u) * dt
newth = th + newthdot * dt
newthdot = np.clip(newthdot, -self.max_speed, self.max_speed)

self.state = np.array([newth, newthdot])
return self._get_obs(), -costs, False, {}

def _reset(self):
high = np.array([np.pi, 1])
self.state = self.np_random.uniform(low=-high, high=high)
self.last_u = None
return self._get_obs()

def _get_obs(self):
theta, thetadot = self.state
return [np.array([np.cos(theta), np.sin(theta), thetadot])
for _ in range(2)]

def angle_normalize(self, x):
return (((x + np.pi) % (2 * np.pi)) - np.pi)
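
A standalone check (not part of the diff) of the reward shaping above: the per-step cost penalizes the wrapped angle error, the angular velocity, and the total torque, and `angle_normalize` maps any angle into [-pi, pi).

```python
import numpy as np

def angle_normalize(x):
    # Wrap an angle into [-pi, pi), matching the method above.
    return ((x + np.pi) % (2 * np.pi)) - np.pi

def step_cost(th, thdot, summed_u):
    # Same quadratic penalty as in MultiAgentPendulumEnv._step.
    return angle_normalize(th) ** 2 + 0.1 * thdot ** 2 + 0.001 * summed_u ** 2

assert abs(angle_normalize(3 * np.pi) - (-np.pi)) < 1e-9  # 3*pi wraps to -pi
print(step_cost(np.pi, 0.0, 0.0))  # hanging straight down costs pi**2 per step
```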
4 changes: 3 additions & 1 deletion python/ray/rllib/models/__init__.py
@@ -5,8 +5,10 @@
from ray.rllib.models.fcnet import FullyConnectedNetwork
from ray.rllib.models.convnet import ConvolutionalNetwork
from ray.rllib.models.lstm import LSTM
from ray.rllib.models.multiagentfcnet import MultiAgentFullyConnectedNetwork


__all__ = ["ActionDistribution", "ActionDistribution", "Categorical",
"DiagGaussian", "Deterministic", "ModelCatalog", "Model",
"FullyConnectedNetwork", "ConvolutionalNetwork", "LSTM"]
"FullyConnectedNetwork", "ConvolutionalNetwork", "LSTM",
"MultiAgentFullyConnectedNetwork"]
47 changes: 47 additions & 0 deletions python/ray/rllib/models/action_dist.py
@@ -4,6 +4,7 @@

import tensorflow as tf
import numpy as np
from ray.rllib.utils.reshaper import Reshaper


class ActionDistribution(object):
@@ -109,3 +110,49 @@ class Deterministic(ActionDistribution):

def sample(self):
return self.inputs


class MultiActionDistribution(ActionDistribution):
"""Action distribution that operates for list of actions.

Args:
inputs (Tensor list): A list of tensors from which to compute samples.
"""
def __init__(self, inputs, action_space, child_distributions):
# Instantiate each child distribution on its slice of the inputs.
self.reshaper = Reshaper(action_space)
split_inputs = self.reshaper.split_tensor(inputs)
child_list = []
for i, distribution in enumerate(child_distributions):
child_list.append(distribution(split_inputs[i]))
self.child_distributions = child_list

def logp(self, x):
"""The log-likelihood of the action distribution."""
split_list = self.reshaper.split_tensor(x)
for i, distribution in enumerate(self.child_distributions):
# Remove extra categorical dimension
if isinstance(distribution, Categorical):
split_list[i] = tf.squeeze(split_list[i], axis=-1)
log_list = np.asarray([distribution.logp(split_x) for
distribution, split_x in
zip(self.child_distributions, split_list)])
return np.sum(log_list)

def kl(self, other):
"""The KL-divergence between two action distributions."""
kl_list = np.asarray([distribution.kl(other_distribution) for
distribution, other_distribution in
zip(self.child_distributions,
other.child_distributions)])
return np.sum(kl_list)

def entropy(self):
"""The entropy of the action distribution."""
entropy_list = np.array([s.entropy() for s in
self.child_distributions])
return np.sum(entropy_list)

def sample(self):
"""Draw a sample from the action distribution."""
return [[s.sample() for s in self.child_distributions]]
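
The reductions above assume the per-agent action heads are independent, so the joint log-likelihood, KL, and entropy are plain sums over the child distributions. A NumPy-only sketch of that factorization (`gaussian_logp` is a hypothetical stand-in for one child DiagGaussian; not part of the diff):

```python
import numpy as np

def gaussian_logp(x, mean, std):
    # Log-density of a univariate Gaussian, standing in for one
    # child distribution's logp().
    return (-0.5 * ((x - mean) / std) ** 2
            - np.log(std) - 0.5 * np.log(2 * np.pi))

logp_agent_1 = gaussian_logp(0.3, mean=0.0, std=1.0)
logp_agent_2 = gaussian_logp(-0.1, mean=0.2, std=0.5)
# Independent heads factorize: the joint log-likelihood is the sum of the
# children's log-likelihoods, the same reduction MultiActionDistribution.logp
# performs over its child distributions.
joint_logp = logp_agent_1 + logp_agent_2
print(joint_logp)
```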