From 9254cb78129391f8f173e321a720d0add7cff2b2 Mon Sep 17 00:00:00 2001 From: Adrian Borucki Date: Wed, 25 May 2022 14:56:03 +0200 Subject: [PATCH] Fix the demonstrate experiment and backport --- experiments/zelda/demonstrate/nars_zelda.py | 56 +++++++++++++++------ experiments/zelda/main/nars_zelda.py | 36 ++++++------- 2 files changed, 56 insertions(+), 36 deletions(-) diff --git a/experiments/zelda/demonstrate/nars_zelda.py b/experiments/zelda/demonstrate/nars_zelda.py index 151714a..7ca3c9a 100644 --- a/experiments/zelda/demonstrate/nars_zelda.py +++ b/experiments/zelda/demonstrate/nars_zelda.py @@ -318,6 +318,14 @@ def abs_to_rel(avatar, op): return ["^rotate_right"] * dor + ["^move_forwards"] +def last_avatar_event(history: list[dict]) -> Optional[dict]: + """Return the last avatar event in the history""" + for event in reversed(history): + if event["SourceObjectName"] == "avatar": + return event + return None + + def object_reached(obj_type: str, env_state: dict, info: dict) -> bool: """Check if an object has been reached @@ -337,16 +345,14 @@ def object_reached(obj_type: str, env_state: dict, info: dict) -> bool: if len(history) == 0: return False - last_event = history[-1] - if ( - last_event["SourceObjectName"] == "avatar" - and last_event["DestinationObjectName"] == obj_type - ): - send_input( - agent.process, - nal_now(f"<{ext(last_event['DestinationObjectName'])} --> [reached]>"), - ) - return True + last_avent = last_avatar_event(history) + if last_avent is not None: + if last_avent["DestinationObjectName"] == obj_type: + send_input( + agent.process, + nal_now(f"<{ext(last_avent['DestinationObjectName'])} --> [reached]>"), + ) + return True return False @@ -389,15 +395,27 @@ def demo_reach_key(symbol: str) -> None: gym_actions = agent.determine_actions( {"executions": [{"operator": action, "arguments": []}]} ) - _, _, done, _ = agent.env.step(gym_actions[0]) - send_observation(agent.process, agent.env.get_state()) # type: ignore + _, reward, done, info = agent.env.step(gym_actions[0]) + agent.observe() + + env_state = agent.env.get_state() # type: ignore + env_state["reward"] = reward + + satisfied_goals = [g.satisfied(env_state, info) for g in goals] + for g, sat in zip(goals, satisfied_goals): + if sat: + print(f"{g.symbol} satisfied.") + send_input(agent.process, nal_now(g.symbol)) + get_raw_output(agent.process) - env.render(observer="global") # type: ignore + if g.symbol == key_goal_sym: + agent.has_key = True + + agent.env.render(observer="global") # type: ignore sleep(1) if done: - env.reset() + agent.reset() demo_reach_key(symbol) - send_input(agent.process, nal_now(symbol)) def make_loc_goal(process: pexpect.spawn, pos, goal_symbol): @@ -461,9 +479,15 @@ def key_check(_, info) -> bool: think_ticks=10, background_knowledge=background_knowledge, ) + + # DEMONSTRATE + agent.reset() + print("Demonstration: reaching a key...") + demo_reach_key(KEY_GOAL.symbol) + total_reward = 0.0 episode_reward = 0.0 - tb_writer = SummaryWriter(comment="-nars-zelda") + tb_writer = SummaryWriter(comment="-nars-zelda-demonstrate") done = False # TRAINING LOOP for episode in range(NUM_EPISODES): diff --git a/experiments/zelda/main/nars_zelda.py b/experiments/zelda/main/nars_zelda.py index 151714a..29de7ee 100644 --- a/experiments/zelda/main/nars_zelda.py +++ b/experiments/zelda/main/nars_zelda.py @@ -318,35 +318,31 @@ def abs_to_rel(avatar, op): return ["^rotate_right"] * dor + ["^move_forwards"] +def last_avatar_event(history: list[dict]) -> Optional[dict]: + """Return the last avatar event in the history""" + for event in reversed(history): + if event["SourceObjectName"] == "avatar": + return event + return None + + def object_reached(obj_type: str, env_state: dict, info: dict) -> bool: """Check if an object has been reached Assumes that if the object does not exist, then it must have been reached. """ - # try: - # avatar = next(obj for obj in env_state["Objects"] if obj["Name"] == "avatar") - # except StopIteration: - # ic("No avatar found. Goal unsatisfiable.") - # return False - # try: - # target = next(obj for obj in env_state["Objects"] if obj["Name"] == obj_type) - # except StopIteration: - # return True - # return avatar["Location"] == target["Location"] history = info["History"] if len(history) == 0: return False - last_event = history[-1] - if ( - last_event["SourceObjectName"] == "avatar" - and last_event["DestinationObjectName"] == obj_type - ): - send_input( - agent.process, - nal_now(f"<{ext(last_event['DestinationObjectName'])} --> [reached]>"), - ) - return True + last_avent = last_avatar_event(history) + if last_avent is not None: + if last_avent["DestinationObjectName"] == obj_type: + send_input( + agent.process, + nal_now(f"<{ext(last_avent['DestinationObjectName'])} --> [reached]>"), + ) + return True return False