[rllib] Strange error thrown in DictFlatteningPreprocessor when trying to train with DQN algorithm on latest Nightly version of ray. #17568

Closed
@akshaygh0sh

Description

On the official 1.5.1 release, I did not see this error. My agent trained perfectly fine with the PPO algorithm before switching to nightly (although only on CPU, which is why I am now using the nightly version that fixed the Torch GPU error).

For reference, my agent runs in a custom environment with a custom network that uses an action mask. The environment was written with the PettingZoo API as an AECEnv (and passes the API test), then wrapped into an RLlib environment with the PettingZooEnv() wrapper. The observation returned at each step is a dictionary (gym.spaces.Dict) with keys "observation" and "action_mask" for each agent (only two agents in my case).
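
For concreteness, each agent's observation space looks roughly like the sketch below; the shapes here are placeholders, not the real ones from my environment:

```python
import numpy as np
from gym import spaces

# Sketch of the per-agent observation space described above.
# The shapes are hypothetical placeholders.
obs_space = spaces.Dict({
    "observation": spaces.Box(low=0, high=1, shape=(8, 8, 20), dtype=np.float32),
    "action_mask": spaces.Box(low=0, high=1, shape=(100,), dtype=np.int8),
})

# At each step the environment returns a plain dict matching this space:
obs = {
    "observation": np.zeros((8, 8, 20), dtype=np.float32),
    "action_mask": np.ones(100, dtype=np.int8),
}
```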

After installing the nightly version of Ray, the following error is thrown when I try to train with the DQN algorithm via tune.run() (if I run PPO instead, I get the same "AttributeError: 'numpy.ndarray' object has no attribute 'items'"):

```
2021-08-03 15:09:12,493 ERROR trial_runner.py:773 -- Trial DQN_ChessEnv_66744_00000: Error processing event.
Traceback (most recent call last):
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\tune\trial_runner.py", line 739, in _process_trial
    results = self.trial_executor.fetch_result(trial)
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\tune\ray_trial_executor.py", line 746, in fetch_result
    result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\_private\client_mode_hook.py", line 82, in wrapper
    return func(*args, **kwargs)
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\worker.py", line 1611, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AttributeError): ray::DQN.train() (pid=13988, ip=10.0.0.37)
  File "python\ray\_raylet.pyx", line 535, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 485, in ray._raylet.execute_task.function_executor
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\_private\function_manager.py", line 563, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\rllib\agents\trainer.py", line 651, in train
    raise e
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\rllib\agents\trainer.py", line 637, in train
    result = Trainable.train(self)
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\tune\trainable.py", line 237, in train
    result = self.step()
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\rllib\agents\trainer_template.py", line 193, in step
    res = next(self.train_exec_impl)
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\util\iter.py", line 756, in __next__
    return next(self.built_iterator)
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\util\iter.py", line 1075, in build_union
    item = next(it)
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\util\iter.py", line 756, in __next__
    return next(self.built_iterator)
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\rllib\execution\rollout_ops.py", line 75, in sampler
    yield workers.local_worker().sample()
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\rllib\evaluation\rollout_worker.py", line 747, in sample
    batches = [self.input_reader.next()]
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\rllib\evaluation\sampler.py", line 100, in next
    batches = [self.get_data()]
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\rllib\evaluation\sampler.py", line 230, in get_data
    item = next(self.rollout_provider)
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\rllib\evaluation\sampler.py", line 610, in _env_runner
    sample_collector=sample_collector,
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\rllib\evaluation\sampler.py", line 815, in _process_observations
    policy_id).transform(raw_obs)
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\rllib\models\preprocessors.py", line 259, in transform
    self.write(observation, array, 0)
  File "C:\Users\408aa\Anaconda3\envs\rl_env\lib\site-packages\ray\rllib\models\preprocessors.py", line 266, in write
    observation = OrderedDict(sorted(observation.items()))
AttributeError: 'numpy.ndarray' object has no attribute 'items'
```
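
If I read the traceback right, DictFlatteningPreprocessor.write() assumes the raw observation is still a dict, but it is being handed an already-flattened numpy array. The failing line can be reproduced in isolation (a minimal sketch, not RLlib code):

```python
from collections import OrderedDict

import numpy as np

# What write() apparently receives: an already-flattened array instead of
# the original {"observation": ..., "action_mask": ...} dict.
raw_obs = np.zeros(4, dtype=np.float32)

# Same operation as the failing line in preprocessors.py (line 266):
OrderedDict(sorted(raw_obs.items()))
# -> AttributeError: 'numpy.ndarray' object has no attribute 'items'
```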

For reference, I am running the latest nightly of Ray (for Python 3.7), specifically to get the Torch GPU fix from this pull request.

Could anyone tell me what exactly is happening here? Is it an issue with my environment? My config for the call to tune.run() is as follows:

```python
from ray.rllib.agents import dqn
from ray.rllib.env import PettingZooEnv

config = dqn.DEFAULT_CONFIG.copy()

# env_creator() builds the custom PettingZoo AECEnv described above.
test_env = PettingZooEnv(env_creator())
obs_space = test_env.observation_space
act_space = test_env.action_space

# Both agents share the same spaces; each agent id maps to a policy
# of the same name.
config["multiagent"] = {
    "policies": {
        "player_0": (None, obs_space, act_space, {}),
        "player_1": (None, obs_space, act_space, {}),
    },
    "policy_mapping_fn": lambda agent_id: agent_id,
}

config["num_gpus"] = 1
config["framework"] = "torch"
config["model"] = {
    "custom_model": "CustomNetwork",  # registered custom model with action masking
}
config["env"] = "CustomEnv"  # registered environment name
config["horizon"] = 150
config["rollout_fragment_length"] = 500
config["callbacks"] = AlgebraicNotation  # custom callbacks class
config["log_level"] = "INFO"
# No post-hidden layers and no dueling head, so the model's masked output
# is used directly as the Q-values.
config["hiddens"] = []
config["dueling"] = False
```

Labels: P0, enhancement, release-blocker, rllib
