Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
ExecutableWorld, designed to also work with BatchWorld (#170)
Browse files Browse the repository at this point in the history
* small

* exec world

* small

* blah

* mm

* index

* index

* index

* small batch fixes

* small batch fixes
  • Loading branch information
jaseweston authored Jun 27, 2017
1 parent 89b89cf commit 5f1d3ac
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 7 deletions.
2 changes: 2 additions & 0 deletions parlai/agents/repeat_label/repeat_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def __init__(self, opt, shared=None):

def act(self):
obs = self.observation
if obs is None:
return { 'text': "Nothing to repeat yet." }
reply = {}
reply['id'] = self.getID()
if ('labels' in obs and obs['labels'] is not None
Expand Down
76 changes: 69 additions & 7 deletions parlai/core/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def shutdown(self):

class MultiAgentDialogWorld(World):
"""Basic world where each agent gets a turn in a round-robin fashion,
recieving as input the actions of all other agents since that agent last
receiving as input the actions of all other agents since that agent last
acted.
"""
def __init__(self, opt, agents=None, shared=None):
Expand Down Expand Up @@ -315,6 +315,53 @@ def shutdown(self):
a.shutdown()


class ExecutableWorld(MultiAgentDialogWorld):
"""A world where messages from agents can be interpreted as _actions_ in the
world which result in changes in the environment (are executed). Hence a grounded
simulation can be implemented rather than just dialogue between agents.
"""
def __init__(self, opt, agents=None, shared=None):
super().__init__(opt, agents, shared)
self.init_world()

def init_world(self):
"""An executable world class should implement this function, otherwise
the actions do not do anything (and it is the same as MultiAgentDialogWorld).
"""
pass

def execute(self, agent, act):
"""An executable world class should implement this function, otherwise
the actions do not do anything (and it is the same as MultiAgentDialogWorld).
"""
pass

def observe(self, agent, act):
"""An executable world class should implement this function, otherwise
the observations for each agent are just the messages from other agents
and not confitioned on the world at all (and it is thus the same as
MultiAgentDialogWorld). """
if agent.id == act['id']:
return None
else:
return act

def parley(self):
"""For each agent: act, execute and observe actions in world
"""
acts = self.acts
for index, agent in enumerate(self.agents):
# The agent acts.
acts[index] = agent.act()
# We execute this action in the world.
self.execute(agent, acts[index])
# All agents (might) observe the results.
for other_agent in self.agents:
obs = self.observe(other_agent, acts[index])
if obs is not None:
other_agent.observe(obs)


class MultiWorld(World):
"""Container for a set of worlds where each world gets a turn
in a round-robin fashion. The same user_agents are placed in each,
Expand Down Expand Up @@ -457,7 +504,7 @@ class BatchWorld(World):
"""Creates a separate world for each item in the batch, sharing
the parameters for each.
The underlying world(s) it is batching can be either ``DialogPartnerWorld``,
``MultiAgentWorld`` or ``MultiWorld``.
``MultiAgentWorld``, ``ExecutableWorld`` or ``MultiWorld``.
"""

def __init__(self, opt, world):
Expand All @@ -481,11 +528,20 @@ def __next__(self):
if self.epoch_done():
raise StopIteration()

def batch_observe(self, index, batch_actions):
def batch_observe(self, index, batch_actions, index_acting):
batch_observations = []
for i, w in enumerate(self.worlds):
agents = w.get_agents()
observation = agents[index].observe(validate(batch_actions[i]))
observation = None
if hasattr(w, 'observe'):
# The world has its own observe function, which the action
# first goes through (agents receive messages via the world,
# not from each other).
observation = w.observe(agents[index], validate(batch_actions[i]))
else:
if index == index_acting: return None # don't observe yourself talking
observation = validate(batch_actions[i])
observation = agents[index].observe(observation)
if observation is None:
raise ValueError('Agents should return what they observed.')
batch_observations.append(observation)
Expand Down Expand Up @@ -523,11 +579,17 @@ def parley(self):
w.parley_init()

for index in range(num_agents):
# The agent acts.
batch_act = self.batch_act(index, batch_observations[index])
# We possibly execute this action in the world.
for i, w in enumerate(self.worlds):
if hasattr(w, 'execute'):
w.execute(w.agents[i], batch_act[i])
# All agents (might) observe the results.
for other_index in range(num_agents):
if index != other_index:
batch_observations[other_index] = (
self.batch_observe(other_index, batch_act))
obs = self.batch_observe(other_index, batch_act, index)
if obs is not None:
batch_observations[other_index] = obs

def display(self):
s = ("[--batchsize " + str(len(self.worlds)) + "--]\n")
Expand Down

0 comments on commit 5f1d3ac

Please sign in to comment.