
[rllib] Memory leak in environment worker in multi-agent setup #9964

@sergeivolodin

Description


What is the problem?

When training in a multi-agent environment with multiple environment workers, the memory usage of the workers grows steadily and is not released after policy updates.


== Status ==
Memory usage on this node: 5.2/8.7 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/10 CPUs, 0/0 GPUs, 0.0/5.03 GiB heap, 0.0/1.71 GiB objects
Result logdir: /home/sergei/ray_results/dummy_run
Number of trials: 1 (1 ERROR)
+------------------------------------+----------+-------+--------+------------------+---------+----------+
| Trial name                         | status   | loc   |   iter |   total time (s) |      ts |   reward |
|------------------------------------+----------+-------+--------+------------------+---------+----------|
| PPO_DummyMultiAgentEnv_fbe89_00000 | ERROR    |       |     57 |          4463.08 | 1881000 |        0 |
+------------------------------------+----------+-------+--------+------------------+---------+----------+
Number of errored trials: 1
+------------------------------------+--------------+---------------------------------------------------------------------------------------------------+
| Trial name                         |   # failures | error file                                                                                        |
|------------------------------------+--------------+---------------------------------------------------------------------------------------------------|
| PPO_DummyMultiAgentEnv_fbe89_00000 |            1 | /home/sergei/ray_results/dummy_run/PPO_DummyMultiAgentEnv_0_2020-08-05_16-21-42pm5cl_th/error.txt |
+------------------------------------+--------------+---------------------------------------------------------------------------------------------------+
ray.exceptions.RayTaskError(RayOutOfMemoryError): ray::RolloutWorker.par_iter_next() (pid=7604, ip=192.168.175.153)
  File "python/ray/_raylet.pyx", line 408, in ray._raylet.execute_task
  File "/home/sergei/miniconda3/envs/fresh_ray/lib/python3.8/site-packages/ray/memory_monitor.py", line 137, in raise_if_low_memory
    raise RayOutOfMemoryError(
ray.memory_monitor.RayOutOfMemoryError: Heap memory usage for ray_RolloutWorker_7604 is 0.4395 / 0.4395 GiB limit

If no memory limit is set, the worker processes exhaust system memory and are killed. If the memory_per_worker limit is set, they exceed the limit and are killed by Ray's memory monitor instead.
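For reference, the growth can be confirmed outside of Ray by polling the resident set size of the rollout worker processes. Below is a minimal monitoring sketch using psutil; the process-title prefix "ray::RolloutWorker" and the polling interval are assumptions for illustration, not part of the original report.

import time

import psutil


def log_rollout_worker_rss(interval_s=5.0):
    """Print the RSS of every rollout worker process until interrupted."""
    while True:
        for proc in psutil.process_iter(["name", "memory_info"]):
            # Ray sets worker process titles to the actor they run, e.g.
            # "ray::RolloutWorker"; the exact title may vary by Ray version.
            name = proc.info["name"] or ""
            mem = proc.info["memory_info"]
            if name.startswith("ray::RolloutWorker") and mem is not None:
                print("pid=%d rss=%.1f MiB" % (proc.pid, mem.rss / 1024 ** 2))
        time.sleep(interval_s)


if __name__ == "__main__":
    log_rollout_worker_rss()

Running this alongside the training script should show each worker's RSS climbing every iteration until the memory monitor kills it.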

Ray version and other system information (Python version, TensorFlow version, OS):

Name                 Value
Ray version          ray==0.8.6
Python version       Python 3.8.5 (default, Aug 5 2020, 08:36:46)
TensorFlow version   tensorflow==2.3.0
OS                   Ubuntu 18.04.4 LTS (GNU/Linux 4.19.121-microsoft-standard x86_64) (same behavior on non-WSL systems as well)

Reproduction

Run the script below and measure memory consumption (a sketch that logs memory from inside the environment follows the script and checklist). If you remove the memory_per_worker limits, reproduction takes longer, because the workers will first try to consume all available system memory.

import numpy as np
import ray
import tensorflow as tf
from ray import tune
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from gym.spaces import Box
from ray.tune.registry import register_env


def dim_to_gym_box(dim, val=np.inf):
    """Create gym.Box with specified dimension."""
    high = np.full((dim,), fill_value=val)
    return Box(low=-high, high=high)


class DummyMultiAgentEnv(MultiAgentEnv):
    """Return zero observations."""

    def __init__(self, config):
        del config  # Unused
        super(DummyMultiAgentEnv, self).__init__()
        self.config = dict(act_dim=17, obs_dim=380, n_players=2, n_steps=1000)
        self.players = ["player_%d" % p for p in range(self.config['n_players'])]
        self.current_step = 0

    def _obs(self):
        return np.zeros((self.config['obs_dim'],))

    def reset(self):
        self.current_step = 0
        return {p: self._obs() for p in self.players}

    def step(self, action_dict):
        done = self.current_step >= self.config['n_steps']
        self.current_step += 1

        obs = {p: self._obs() for p in self.players}
        rew = {p: 0.0 for p in self.players}
        dones = {p: done for p in self.players + ["__all__"]}
        infos = {p: {} for p in self.players}

        return obs, rew, dones, infos

    @property
    def observation_space(self):
        return dim_to_gym_box(self.config['obs_dim'])

    @property
    def action_space(self):
        return dim_to_gym_box(self.config['act_dim'])


def create_env(config):
    """Create the dummy environment."""
    return DummyMultiAgentEnv(config)


env_name = "DummyMultiAgentEnv"
register_env(env_name, create_env)


def get_trainer_config(env_config, train_policies, num_workers=5, framework="tfe"):
    """Build configuration for 1 run."""

    # obtaining parameters from the environment
    env = create_env(env_config)
    act_space = env.action_space
    obs_space = env.observation_space
    players = env.players
    del env

    def get_policy_config(p):
        """Get policy configuration for a player."""
        return (PPOTFPolicy, obs_space, act_space, {
            "model": {
                "use_lstm": False,
                "fcnet_hiddens": [64, 64],
            },
            "framework": framework,
        })

    # creating policies
    policies = {p: get_policy_config(p) for p in players}

    # trainer config
    config = {
        "env": env_name, "env_config": env_config, "num_workers": num_workers,
        "multiagent": {"policy_mapping_fn": lambda x: x, "policies": policies,
                       "policies_to_train": train_policies},
        "framework": framework,
        "train_batch_size": 32768,
        "num_sgd_iter": 1,
        "sgd_minibatch_size": 32768,

        # 450 megabytes for each worker (enough for 1 iteration)
        "memory_per_worker": 450 * 1024 * 1024,
        "object_store_memory_per_worker": 128 * 1024 * 1024,
    }
    return config


def tune_run():
    ray.init(ignore_reinit_error=True)
    # train both defined policies (the env creates player_0 and player_1)
    config = get_trainer_config(train_policies=['player_0', 'player_1'], env_config={})
    return tune.run("PPO", config=config, verbose=True, name="dummy_run", num_samples=1)


if __name__ == '__main__':
    tune_run()

  • I have verified my script runs in a clean environment and reproduces the issue.
  • I have verified the issue also occurs with the latest wheels.
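For completeness, the growth is also visible from inside each rollout worker. The following is a hypothetical variant of the dummy environment (not part of the original script; it assumes it is added to the reproduction script above) that prints its own process RSS at every episode reset. Registering it in place of DummyMultiAgentEnv makes the per-episode growth appear in each worker's logs.

import os

import psutil


class MemoryLoggingDummyEnv(DummyMultiAgentEnv):
    """Same dummy env, but logs the worker's resident set size on reset."""

    def reset(self):
        rss_mib = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2
        print("[pid=%d] rss=%.1f MiB" % (os.getpid(), rss_mib))
        return super().reset()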

Labels: bug (Something that is supposed to be working; but isn't), P2 (Important issue, but not time-critical)