Skip to content

Commit

Permalink
Add multi agent support in rollout.py (ray-project#4114)
Browse files Browse the repository at this point in the history
  • Loading branch information
antoine-galataud authored and ericl committed Mar 2, 2019
1 parent 48f6cd3 commit 8288deb
Showing 1 changed file with 47 additions and 20 deletions.
67 changes: 47 additions & 20 deletions python/ray/rllib/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ def run(args, parser):
if not config:
# Load configuration from file
config_dir = os.path.dirname(args.checkpoint)
config_path = os.path.join(config_dir, "params.json")
config_path = os.path.join(config_dir, "params.pkl")
if not os.path.exists(config_path):
config_path = os.path.join(config_dir, "../params.json")
config_path = os.path.join(config_dir, "../params.pkl")
if not os.path.exists(config_path):
raise ValueError(
"Could not find params.json in either the checkpoint dir or "
"Could not find params.pkl in either the checkpoint dir or "
"its parent directory.")
with open(config_path) as f:
config = json.load(f)
with open(config_path, 'rb') as f:
config = pickle.load(f)
if "num_workers" in config:
config["num_workers"] = min(2, config["num_workers"])

Expand All @@ -102,18 +102,18 @@ def run(args, parser):
def rollout(agent, env_name, num_steps, out=None, no_render=True):
if hasattr(agent, "local_evaluator"):
env = agent.local_evaluator.env
multiagent = agent.local_evaluator.multiagent
if multiagent:
policy_agent_mapping = agent.config["multiagent"][
"policy_mapping_fn"]
mapping_cache = {}
policy_map = agent.local_evaluator.policy_map
state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
use_lstm = {p: len(s) > 0 for p, s in state_init.items()}
else:
env = gym.make(env_name)

if hasattr(agent, "local_evaluator"):
state_init = agent.local_evaluator.policy_map[
"default"].get_initial_state()
else:
state_init = []
if state_init:
use_lstm = True
else:
use_lstm = False
multiagent = False
use_lstm = {'default': False}

if out is not None:
rollouts = []
Expand All @@ -125,13 +125,39 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
done = False
reward_total = 0.0
while not done and steps < (num_steps or steps + 1):
if use_lstm:
action, state_init, logits = agent.compute_action(
state, state=state_init)
if multiagent:
action_dict = {}
for agent_id in state.keys():
a_state = state[agent_id]
if a_state is not None:
policy_id = mapping_cache.setdefault(
agent_id, policy_agent_mapping(agent_id))
p_use_lstm = use_lstm[policy_id]
if p_use_lstm:
a_action, p_state_init, _ = agent.compute_action(
a_state,
state=state_init[policy_id],
policy_id=policy_id)
state_init[policy_id] = p_state_init
else:
a_action = agent.compute_action(
a_state, policy_id=policy_id)
action_dict[agent_id] = a_action
action = action_dict
else:
action = agent.compute_action(state)
if use_lstm["default"]:
action, state_init, _ = agent.compute_action(
state, state=state_init)
else:
action = agent.compute_action(state)

next_state, reward, done, _ = env.step(action)
reward_total += reward

if multiagent:
done = done["__all__"]
reward_total += sum(reward.values())
else:
reward_total += reward
if not no_render:
env.render()
if out is not None:
Expand All @@ -141,6 +167,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
if out is not None:
rollouts.append(rollout)
print("Episode reward", reward_total)

if out is not None:
pickle.dump(rollouts, open(out, "wb"))

Expand Down

0 comments on commit 8288deb

Please sign in to comment.