Skip to content

Commit

Permalink
Fix the demonstrate experiment and backport
Browse files Browse the repository at this point in the history
  • Loading branch information
ntoxeg committed May 25, 2022
1 parent a22d636 commit 9254cb7
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 36 deletions.
56 changes: 40 additions & 16 deletions experiments/zelda/demonstrate/nars_zelda.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,14 @@ def abs_to_rel(avatar, op):
return ["^rotate_right"] * dor + ["^move_forwards"]


def last_avatar_event(history: list[dict]) -> Optional[dict]:
"""Return the last avatar event in the history"""
for event in reversed(history):
if event["SourceObjectName"] == "avatar":
return event
return None


def object_reached(obj_type: str, env_state: dict, info: dict) -> bool:
"""Check if an object has been reached
Expand All @@ -337,16 +345,14 @@ def object_reached(obj_type: str, env_state: dict, info: dict) -> bool:
if len(history) == 0:
return False

last_event = history[-1]
if (
last_event["SourceObjectName"] == "avatar"
and last_event["DestinationObjectName"] == obj_type
):
send_input(
agent.process,
nal_now(f"<{ext(last_event['DestinationObjectName'])} --> [reached]>"),
)
return True
last_avent = last_avatar_event(history)
if last_avent is not None:
if last_avent["DestinationObjectName"] == obj_type:
send_input(
agent.process,
nal_now(f"<{ext(last_avent['DestinationObjectName'])} --> [reached]>"),
)
return True

return False

Expand Down Expand Up @@ -389,15 +395,27 @@ def demo_reach_key(symbol: str) -> None:
gym_actions = agent.determine_actions(
{"executions": [{"operator": action, "arguments": []}]}
)
_, _, done, _ = agent.env.step(gym_actions[0])
send_observation(agent.process, agent.env.get_state()) # type: ignore
_, reward, done, info = agent.env.step(gym_actions[0])
agent.observe()

env_state = agent.env.get_state() # type: ignore
env_state["reward"] = reward

satisfied_goals = [g.satisfied(env_state, info) for g in goals]
for g, sat in zip(goals, satisfied_goals):
if sat:
print(f"{g.symbol} satisfied.")
send_input(agent.process, nal_now(g.symbol))
get_raw_output(agent.process)

env.render(observer="global") # type: ignore
if g.symbol == key_goal_sym:
agent.has_key = True

agent.env.render(observer="global") # type: ignore
sleep(1)
if done:
env.reset()
agent.reset()
demo_reach_key(symbol)
send_input(agent.process, nal_now(symbol))


def make_loc_goal(process: pexpect.spawn, pos, goal_symbol):
Expand Down Expand Up @@ -461,9 +479,15 @@ def key_check(_, info) -> bool:
think_ticks=10,
background_knowledge=background_knowledge,
)

# DEMONSTRATE
agent.reset()
print("Demonstration: reaching a key...")
demo_reach_key(KEY_GOAL.symbol)

total_reward = 0.0
episode_reward = 0.0
tb_writer = SummaryWriter(comment="-nars-zelda")
tb_writer = SummaryWriter(comment="-nars-zelda-demonstrate")
done = False
# TRAINING LOOP
for episode in range(NUM_EPISODES):
Expand Down
36 changes: 16 additions & 20 deletions experiments/zelda/main/nars_zelda.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,35 +318,31 @@ def abs_to_rel(avatar, op):
return ["^rotate_right"] * dor + ["^move_forwards"]


def last_avatar_event(history: list[dict]) -> Optional[dict]:
"""Return the last avatar event in the history"""
for event in reversed(history):
if event["SourceObjectName"] == "avatar":
return event
return None


def object_reached(obj_type: str, env_state: dict, info: dict) -> bool:
"""Check if an object has been reached
Assumes that if the object does not exist, then it must have been reached.
"""
# try:
# avatar = next(obj for obj in env_state["Objects"] if obj["Name"] == "avatar")
# except StopIteration:
# ic("No avatar found. Goal unsatisfiable.")
# return False
# try:
# target = next(obj for obj in env_state["Objects"] if obj["Name"] == obj_type)
# except StopIteration:
# return True
# return avatar["Location"] == target["Location"]
history = info["History"]
if len(history) == 0:
return False

last_event = history[-1]
if (
last_event["SourceObjectName"] == "avatar"
and last_event["DestinationObjectName"] == obj_type
):
send_input(
agent.process,
nal_now(f"<{ext(last_event['DestinationObjectName'])} --> [reached]>"),
)
return True
last_avent = last_avatar_event(history)
if last_avent is not None:
if last_avent["DestinationObjectName"] == obj_type:
send_input(
agent.process,
nal_now(f"<{ext(last_avent['DestinationObjectName'])} --> [reached]>"),
)
return True

return False

Expand Down

0 comments on commit 9254cb7

Please sign in to comment.